当前位置: 首页 > news >正文

万州做网站多少钱扬州网站建设 开元

万州做网站多少钱,扬州网站建设 开元,华企立方网站,微信生活门户网站源码前言 ChatGPT出来后的两年多#xff0c;也是我疯狂写博的两年多(年初deepseek更引爆了下)#xff0c;比如从创业起步时的15年到后来22年之间 每年2-6篇的#xff0c;干到了23年30篇、24年65篇、25年前两月18篇#xff0c;成了我在大模型和具身的原始技术积累 如今一转眼…前言  ChatGPT出来后的两年多也是我疯狂写博的两年多(年初deepseek更引爆了下)比如从创业起步时的15年到后来22年之间 每年2-6篇的干到了23年30篇、24年65篇、25年前两月18篇成了我在大模型和具身的原始技术积累 如今一转眼已到25年3月初时光走得太快近期和团队接了好几个大客户订单使得3月起 不得不全力加速落地自己也得每天抠paper、搞代码 虽然今年可能没法像去年24年那样干65篇不过我还是争取保持月月更新 一方面有些文章是之前既定计划中的比如如此文《π0开源了且推出自回归版π0-FAST——打造机器人动作专用的高效Tokenizer比扩散π0的训练速度快5倍但效果相当》最后所说的对π0源码的解读 「至于什么是π0详见此文《π0——用于通用机器人控制的VLA模型一套框架控制7种机械臂(基于PaliGemma和流匹配的3B模型)》」二方面我司「七月在线」在做一系列工厂落地场景的过程中我们也希望团结到可以和我们一块做的朋友而若想团结便需要对外分享我们每个季度在重点做的业务场景 比如过去一周我把lerobot、reflect vlm、π0的仿真环境都在我自己本地电脑上跑了下(过程中GitHub copilot这种AI编程工具在环境的安装上帮了我很大的忙——各种环境 只要几句命令直接帮我装好真心不错) 如此硬着头皮冥思苦想、摸索了好几天随后使得我自己知道怎么带队完成『太多工厂希望实现的一个生产线任务』了3月初先仿真训练2-3个月内部署到真机 当然了也不单纯只是「这几天的想」就能想出来的​这几天之前 有把过去一年当三年用的具身技术积累有一年多来和同事们 如姚博士以及朋友们许多的讨论有去年十几个工厂对我们的支持与信任 我们正在不断壮大队伍 有我司内部同事亦有我带的北理、中南等985的具身研究生及一块合作开发的朋友很快会把多个生产线任务并行开发起来且无论哪个项目都是不断长期迭代的故过程中少不了科研层面的突破欢迎更多伙伴加入我们(全、兼、实习皆可有意者敬请私我)和我们一块开发 话休絮烦本文便按照如下图所示的源码结构重点解读一下π的整个源码 π0的源码结构非常清晰、可读性高不愧是成熟的商业化公司是我司七月的学习榜样之一我身边的很多朋友目前都在做π0的微调及二次开发相信本文无论对我身边的朋友还是对更多人的学习与工作都会起到比较大的提升 目录 前言  第一部分 examples、packages、scripts等结构的分析 1.1 examples 各种机器人平台的示例实现 1.2 packages 1.3 scripts包含数据处理、模型训练/推理的多个脚本 1.3.1 __init__.py 1.3.2 compute_norm_stats.py计算数据的归一化统计信息 1.3.3 serve_policy.py启动策略服务用于模型推理 1.3.4 train_test.py训练和测试模型 1.3.5 train.py训练模型 1.3.6 scripts/docker 第二部分 核心模块src下models的全面分析与解读 2.1 models/pi0.py的实现 2.1.1 make_attn_mask注意力掩码生成函数 2.1.2 posemb_sincos位置编码函数 2.1.3 class Pi0Config含inputs_spec、get_freeze_filter 2.1.3.1 模型配置参数的定义 2.1.3.2 inputs_spec定义了π0模型本身接收的输入数据格式​编辑 2.1.3.3 get_freeze_filter针对是否LoRA的处理 2.1.4 class Pi0初始化、特征嵌入、损失函数、推理(去噪生成动作) 2.1.4.1 初始化方法 __init__ 2.1.4.2 特征嵌入方法embed_prefix(图像和文本输入)、embed_suffix(状态和动作信息)​编辑 2.1.4.3 损失函数 compute_loss 2.1.4.4 推理函数 sample_actions基于扩散模型逆向采样生成机器人动作序列 第一部分 examples、packages、scripts等结构的分析 1.1 examples 各种机器人平台的示例实现 根据π0对应examples模块的结构 其涉及以下模块 aloha_real/真实机器人ALOHA的示例aloha_sim/ALOHA模拟器的示例droid/DROID机器人的示例libero/LIBERO基准测试的示例simple_client/简单客户端的示例ur5/UR5机器人的示例inference.ipynb推理示例的Jupyter Notebookpolicy_records.ipynb策略记录示例的Jupyter Notebook 1.2 packages 该模块的目录结构如下 1.3 scripts包含数据处理、模型训练/推理的多个脚本 根据下图 可知scripts 目录包含多个 Python 脚本这些脚本用于数据处理、模型训练和服务部署等任务每个脚本通常对应一个特定的功能或任务 __init__.pycompute_norm_stats.py: 计算数据的归一化统计信息serve_policy.py: 启动策略服务提供模型推理接口train_test.py: 训练和测试模型train.py: 训练模型 1.3.1 __init__.py 1.3.2 compute_norm_stats.py计算数据的归一化统计信息 1.3.3 serve_policy.py启动策略服务用于模型推理 在这个代码片段中首先导入了一些必要的模块和库包括 policy、policy_config、websocket_policy_server 和 config这些模块来自 openpi 项目 from openpi.policies import policy as _policy # 导入 openpi.policies.policy 模块并重命名为 _policy from openpi.policies import policy_config as _policy_config # 导入 openpi.policies.policy_config 模块并重命名为 _policy_config from openpi.serving import websocket_policy_server # 导入 openpi.serving.websocket_policy_server 模块 from openpi.training import config as _config # 导入 openpi.training.config 模块并重命名为 _config 接下来定义了一个枚举类 EnvMode它表示支持的环境类型包括 ALOHA、ALOHA_SIM、DROID 和 LIBERO class EnvMode(enum.Enum):支持的环境。ALOHA aloha # ALOHA 环境ALOHA_SIM aloha_sim # ALOHA 模拟环境DROID droid # DROID 环境LIBERO libero # LIBERO 环境 然后定义了几个数据类 Checkpoint 类用于从训练好的检查点加载策略包含两个字段config训练配置名称和 dir检查点目录 Default 类表示使用默认策略 Args 类定义了脚本的参数包括环境类型、默认提示、端口、是否记录策略行为以及如何加载策略接下来定义了一个字典 DEFAULT_CHECKPOINT它为每个环境类型指定了默认的检查点配置 # 每个环境应使用的默认检查点 DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] {EnvMode.ALOHA: Checkpoint(configpi0_aloha,dirs3://openpi-assets/checkpoints/pi0_base,),EnvMode.ALOHA_SIM: Checkpoint(configpi0_aloha_sim,dirs3://openpi-assets/checkpoints/pi0_aloha_sim,),EnvMode.DROID: Checkpoint(configpi0_fast_droid,dirs3://openpi-assets/checkpoints/pi0_fast_droid,),EnvMode.LIBERO: Checkpoint(configpi0_fast_libero,dirs3://openpi-assets/checkpoints/pi0_fast_libero,), } create_default_policy 函数根据环境类型创建默认策略如果环境类型不支持则抛出异常 def create_default_policy(env: EnvMode, *, default_prompt: str | None None) - _policy.Policy:为给定环境创建默认策略 if checkpoint : DEFAULT_CHECKPOINT.get(env): # 获取环境对应的默认检查点return _policy_config.create_trained_policy(_config.get_config(checkpoint.config), checkpoint.dir, default_promptdefault_prompt) # 创建训练好的策略raise ValueError(fUnsupported environment mode: {env}) # 如果环境不支持抛出异常 create_policy 函数根据传入的参数创建策略如果参数中指定了检查点则从检查点加载策略否则使用默认策略 def create_policy(args: Args) - _policy.Policy:根据给定的参数创建策略 match args.policy: # 匹配策略类型case Checkpoint(): # 如果是 Checkpoint 类型return _policy_config.create_trained_policy(_config.get_config(args.policy.config), args.policy.dir, default_promptargs.default_prompt) # 创建训练好的策略case Default(): # 如果是 Default 类型return create_default_policy(args.env, default_promptargs.default_prompt) # 创建默认策略 main 函数是脚本的入口点它首先调用 create_policy 函数创建策略然后记录策略的元数据 def main(args: Args) - None:policy create_policy(args) # 创建策略policy_metadata policy.metadata # 获取策略的元数据 如果参数中指定了记录策略行为则使用 PolicyRecorder 包装策略 # 记录策略的行为if args.record:# 使用 PolicyRecorder 记录策略行为policy _policy.PolicyRecorder(policy, policy_records) 接着获取主机名和本地 IP 地址 hostname socket.gethostname() # 获取主机名local_ip socket.gethostbyname(hostname) # 获取本地 IP 地址logging.info(Creating server (host: %s, ip: %s), hostname, local_ip) # 记录服务器创建信息 并创建一个 WebSocket 服务器来提供策略服务最后调用 serve_forever 方法启动服务器 server websocket_policy_server.WebsocketPolicyServer(policypolicy,host0.0.0.0,portargs.port,metadatapolicy_metadata,) # 创建 WebSocket 策略服务器server.serve_forever() # 启动服务器永远运行 在脚本的最后使用 logging 模块配置日志记录并调用 main 函数启动脚本参数通过 tyro.cli 解析 1.3.4 train_test.py训练和测试模型 1.3.5 train.py训练模型 1.3.6 scripts/docker 好的下面是对 openpi-main/scripts/docker 目录的详细分析。这个目录通包含与 Docker 相关的脚本和配置文件用于构建和管理 Docker 容器具体而言包含以下文件和子目录 主要文件和功能如下所示 docker/compose.ymldocker/install_docker_ubuntu22.shdocker/install_nvidia_container_toolkit.shdocker/serve_policy.Dockerfile // 待更 第二部分 核心模块src下models的全面分析与解读 接下来我们来看核心src下的各个模块 首先是其中的src/openpi/models 2.1 models/pi0.py的实现 它结合了多模态输入图像和文本来生成机器人动作序列。下面是对代码的详细解析 2.1.1 make_attn_mask注意力掩码生成函数 这个函数生成transformer中使用的注意力掩码控制 token 之间的注意力流动方式 def make_attn_mask(input_mask, mask_ar):从big_vision项目改编的注意力掩码生成函数Token可以关注那些累积mask_ar小于等于自己的有效输入token。这样mask_ar bool[?B, N]可用于设置几种类型的注意力例如[[1 1 1 1 1 1]]: 纯因果注意力。[[0 0 0 1 1 1]]: 前缀语言模型注意力。前3个token之间可以互相关注后3个token有因果注意力。第一个条目也可以是1不改变行为。[[1 0 1 0 1 0 0 1 0 0]]: 4个块之间的因果注意力。一个块的token可以关注所有之前的块和同一块内的所有token。参数:input_mask: bool[B, N] 如果是输入的一部分则为true如果是填充则为falsemask_ar: bool[?B, N] 如果前面的token不能依赖于它则为true如果它共享与前一个token相同的注意力掩码则为false# 将mask_ar广播到与input_mask相同的形状mask_ar jnp.broadcast_to(mask_ar, input_mask.shape) # 计算mask_ar在序列维度上的累积和cumsum jnp.cumsum(mask_ar, axis1) # 创建注意力掩码当目标位置的累积值查询位置的累积值时允许注意力流动attn_mask cumsum[:, None, :] cumsum[:, :, None] # 创建有效掩码只有有效的输入位置之间才能有注意力valid_mask input_mask[:, None, :] * input_mask[:, :, None] # 结合注意力掩码和有效掩码return jnp.logical_and(attn_mask, valid_mask) 它支持多种注意力模式 纯因果注意力每个 token 只能关注自己和之前的 token前缀语言模型注意力允许前缀内部自由注意后缀部分使用因果注意力块状因果注意力在块内自由注意块之间是因果的 2.1.2 posemb_sincos位置编码函数 使用正弦余弦函数实现位置编码 def posemb_sincos(pos: at.Real[at.Array, Any], embedding_dim: int, min_period: float, max_period: float ) - at.Float[at.Array, fb {embedding_dim}]:计算标量位置的正弦余弦位置嵌入向量if embedding_dim % 2 ! 0: # 检查嵌入维度是否为偶数raise ValueError(fembedding_dim ({embedding_dim}) must be divisible by 2)fraction jnp.linspace(0.0, 1.0, embedding_dim // 2) # 创建均匀分布的分数值period min_period * (max_period / min_period) ** fraction # 计算周期值对数空间中均匀分布sinusoid_input jnp.einsum(i,j-ij,pos,1.0 / period * 2 * jnp.pi, # 计算角频率precisionjax.lax.Precision.HIGHEST, # 使用最高精度进行计算)# 连接sin和cos值形成完整的位置编码return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis-1) 2.1.3 class Pi0Config含inputs_spec、get_freeze_filter 2.1.3.1 模型配置参数的定义 首先这个类定义了模型的配置参数比如PaLI-Gemma 变体gemma_2b class Pi0Config(_model.BaseModelConfig):dtype: str bfloat16 # 设置数据类型为bfloat16paligemma_variant: _gemma.Variant gemma_2b # 设置PaLI-Gemma变体为2B参数版本action_expert_variant: _gemma.Variant gemma_300m # 设置动作专家变体为300M参数版本# 设置模型特定的默认值action_dim: int 32 # 设置动作维度为32action_horizon: int 50 # 设置动作序列长度为50步max_token_len: int 48 # 设置最大token长度为48 2.1.3.2 inputs_spec定义了π0模型本身接收的输入数据格式 其次通过inputs_spec函数定义了π0模型本身接收的输入数据格式函数采用关键字参数 batch_size默认为1返回一个包含观察规格和动作规格的元组 def inputs_spec(self, *, batch_size: int 1) - Tuple[Type[_model.Observation], Type[_model.Actions]] 其支持多种输入比如 视觉输入(三个不同视角的RGB图像)、语言输入(分词后的文本prompt)、状态输入(当前机器人状态)输出上 则是一个时序动作序列(包含50个连续的动作向量每个动作向量有32个维度可能对应关节角度或其他控制信号) 具体而言该函数先 创建图像规格 image_spec jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32) 其中的 [batch_size, *_model.IMAGE_RESOLUTION, 3] 定义了图像张量的形状比如   批次大小   图像分辨率从 _model.IMAGE_RESOLUTION 获取可能是如 [224, 224] 这样的值   3 个颜色通道 (RGB)jnp.float32 指定了数据类型为 32 位浮点数 创建图像掩码规格 image_mask_spec jax.ShapeDtypeStruct([batch_size], jnp.bool_) 其定义了图像掩码规格每个批次中的每个图像都有一个布尔值这个掩码用于指示哪些图像是有效的True或无效的False 创建观察规格包含视觉输入、机器人状态、指令输入 at.disable_typechecking() 临时禁用类型检查可能是因为这里创建的是类型规格而不是实际的数据且观察规格包含多个组件 多视角图像 base_0_rgb: 机器人底座/身体视角的RGB图像 left_wrist_0_rgb: 左手腕视角的RGB图像 right_wrist_0_rgb: 右手腕视角的RGB图像 with at.disable_typechecking():observation_spec _model.Observation(images{base_0_rgb: image_spec,left_wrist_0_rgb: image_spec,right_wrist_0_rgb: image_spec,}, 图像掩码 对应每个视角图像的有效性掩码机器人状态 形状为 [batch_size, self.action_dim] 的浮点数张量 self.action_dim 默认为32表示状态向量的维度 statejax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32), 分词后的文本prompt 形状为 [batch_size, self.max_token_len] 的整数张量 self.max_token_len 默认为48表示最大token数量 数据类型为 jnp.int32表示token ID提示掩码 与分词提示相同形状的布尔张量用于指示哪些位置有有效的token statejax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),tokenized_promptjax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),tokenized_prompt_maskjax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),) 创建动作规格 action_spec jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32) 其定义了动作数据的形状和类型 batch_size: 批次大小self.action_horizon: 动作序列长度默认为50 self.action_dim: 每个动作的维度默认为32jnp.float32 指定了数据类型为32位浮点数 然后返回 return observation_spec, action_spec 2.1.3.3 get_freeze_filter针对是否LoRA的处理 此外该配置类还实现了get_freeze_filter这个函数作用是如果选择LoRA微调(冻结原始预训练模型的参数只更新新添加的低秩适应层参数)则需要对模型中的某些参数做冻结 三种可能的情况 只对 PaLI-Gemma 使用 LoRA冻结 Gemma 参数但排除动作专家参数只对动作专家使用 LoRA冻结动作专家参数对两者都使用 LoRA冻结两者的基础参数 如此可以选择性地微调模型的特定部分(语言部分或动作预测部分 具体而言 首先定义函数 def get_freeze_filter(self) - nnx.filterlib.Filter:返回基于模型配置的冻结过滤器 其次初始化变量 filters [] # 初始化过滤器列表has_lora False # 初始化LoRA标志 接着创建参数过滤器 # 匹配所有LLM参数的正则表达式用于选择 Gemma 语言模型的参数gemma_params_filter nnx_utils.PathRegex(.*llm.*) # 匹配动作专家参数的正则表达式action_expert_params_filter nnx_utils.PathRegex(.*llm.*_1.*) 接下来是对PaLI-Gemma变体的处理 # 如果PaLI-Gemma使用LoRAif lora in self.paligemma_variant:filters.append(gemma_params_filter, # 添加Gemma参数过滤器)if lora not in self.action_expert_variant:# 如果只冻结Gemma参数排除动作专家参数filters.append(nnx.Not(action_expert_params_filter),)has_lora True 再下来是对动作专家变体的处理 elif lora in self.action_expert_variant:# 如果动作专家使用LoRAfilters.append(action_expert_params_filter,)has_lora True 2.1.4 class Pi0初始化、特征嵌入、损失函数、推理(去噪生成动作) 核心模型类继承自 _model.BaseModel实现了 多模态输入处理 处理多视角图像基础视角、左手腕视角、右手腕视角 处理文本提示如指令 处理机器人当前状态扩散过程 训练时将干净动作添加噪声让模型学习去噪 推理时从纯噪声开始逐步降噪生成动作序列注意力机制 使用精心设计的注意力掩码控制信息流动 前缀图像和文本内部使用全注意力 后缀状态和动作使用特殊的注意力模式 2.1.4.1 初始化方法 __init__ class Pi0(_model.BaseModel):def __init__(self, config: Pi0Config, rngs: nnx.Rngs):# 初始化基类super().__init__(config.action_dim, config.action_horizon, config.max_token_len)# 获取PaLI-Gemma和动作专家配置paligemma_config _gemma.get_config(config.paligemma_variant)action_expert_config _gemma.get_config(config.action_expert_variant) 其组合了多个核心组件 一个是PaLI-Gemma 模型结合了 Gemma 语言模型和 SigLIP 视觉模型 先是对语言模型的初始化 # 创建并初始化语言模型# TODO: 用NNX重写Gemma目前使用桥接llm nnx_bridge.ToNNX(_gemma.Module(configs[paligemma_config, action_expert_config], # 配置两个Gemma模型embed_dtypeconfig.dtype, # 设置嵌入数据类型))llm.lazy_init(rngsrngs, methodinit) # 延迟初始化LLM 然后是对视觉模型的初始化 # 创建并初始化图像模型img nnx_bridge.ToNNX(_siglip.Module(num_classespaligemma_config.width, # 设置图像特征维度与语言模型宽度相匹配variantSo400m/14, # 使用400M参数SigLIP模型pool_typenone, # 不使用池化保留所有图像标记scanTrue, # 启用扫描优化dtype_mmconfig.dtype, # 设置矩阵乘法数据类型))# 使用假观察中的图像初始化图像模型img.lazy_init(next(iter(config.fake_obs().images.values())), trainFalse, rngsrngs) 最后把语言模型和视觉模型组合成PaLI-Gemma多模态模型 # 组合LLM和图像模型为PaLI-Gemma多模态模型self.PaliGemma nnx.Dict(llmllm, imgimg) 另一个是线性投影层用于 状态投影 # 状态投影层将机器人状态投影到模型维度self.state_proj nnx.Linear(config.action_dim, action_expert_config.width, rngsrngs) 动作投影 # 动作输入投影层将动作投影到模型维度self.action_in_proj nnx.Linear(config.action_dim, action_expert_config.width, rngsrngs) 时间-动作混合等 # 动作-时间MLP输入层将连接的动作和时间特征投影到模型维度self.action_time_mlp_in nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngsrngs)# 动作-时间MLP输出层self.action_time_mlp_out nnx.Linear(action_expert_config.width, action_expert_config.width, rngsrngs)# 动作输出投影层将模型输出投影回动作维度self.action_out_proj nnx.Linear(action_expert_config.width, config.action_dim, rngsrngs) 2.1.4.2 特征嵌入方法embed_prefix(图像和文本输入)、embed_suffix(状态和动作信息) embed_prefix处理图像和文本输入(图像通过SigLip模型编码文本通过Gemma LLM编码)创建前缀 token皆为双向注意力用ar_mask false表示embed_suffix处理机器人状态信息、噪声化的动作信息(状态和噪声动作经过线性投影和MLP处理)创建后缀 token 其中 状态为单个token和第一个动作token均设置为单向注意力用ar_mask true表示 其余动作tokens之间设置为双向注意力用ar_mask false表示 对于前者embed_prefix def embed_prefix(self, obs: _model.Observation) - Tuple[at.Float[at.Array, Any], at.Bool[at.Array, Any], at.Bool[at.Array, Any]]:嵌入前缀部分图像和文本input_mask [] # 初始化输入掩码列表ar_mask [] # 初始化自回归掩码列表tokens [] # 初始化token列表 其工作流程为 图像处理说白了就是把图像token化 使用SigLip视觉模型处理每个图像生成图像tokens序列 # 嵌入图像for name in obs.images:# 通过图像模型获取图像tokenimage_tokens, _ self.PaliGemma.img(obs.images[name], trainFalse)tokens.append(image_tokens) # 添加图像token 图像掩码扩展 将图像掩码扩展到与图像tokens相同的序列长度使用einops.repeat进行形状变换这些掩码会指示哪些图像是有效的而哪些是填充的 # 重复图像掩码以匹配token维度input_mask.append(einops.repeat(obs.image_masks[name],b - b s, # 调整形状批次维度保持不变添加序列维度simage_tokens.shape[1], # 序列长度等于图像token数)) 自回归掩码设置 设置图像tokens之间的注意力为双向(False表示双向注意力)原因在于图像内容通常是非时序性的数据 # 图像token之间互相关注非自回归ar_mask [False] * image_tokens.shape[1] 文本处理 使用LLM模型对文本输入tokenized_inputs进行嵌入 # 添加语言即分词后的输入if obs.tokenized_prompt is not None:# 通过语言模型嵌入分词后的提示tokenized_inputs self.PaliGemma.llm(obs.tokenized_prompt, methodembed)tokens.append(tokenized_inputs) # 添加文本tokeninput_mask.append(obs.tokenized_prompt_mask) # 添加提示掩码 且同样设置为双向注意力相当于语言token可以关注图像token图像token反过来亦可关注语言token最终实现多模态融合 # 图像和语言输入之间完全关注非自回归ar_mask [False] * tokenized_inputs.shape[1] 最后连接所有token和掩码其中包含了   多模态信息的融合表示tokens——图像token和语言token   以及指示哪些token是有效信息的input_mask   和如何在这些token之间进行注意力计算规则的ar_mask——相当于控制信息流动的方向 # 连接所有token和掩码tokens jnp.concatenate(tokens, axis1) # 在序列维度上连接tokeninput_mask jnp.concatenate(input_mask, axis1) # 在序列维度上连接输入掩码ar_mask jnp.array(ar_mask) # 转换自回归掩码为数组return tokens, input_mask, ar_mask # 返回token、输入掩码和自回归掩码 顺便再回顾下此图 对于后者embed_suffix def embed_suffix(self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, Any]) - Tuple[at.Float[at.Array, Any], at.Bool[at.Array, Any], at.Bool[at.Array, Any]]:嵌入后缀部分状态和动作input_mask [] # 初始化输入掩码列表ar_mask [] # 初始化自回归掩码列表tokens [] # 初始化token列表 其工作流程为 状态处理 将状态信息投影到embedding空间 # 添加单个状态tokenstate_token self.state_proj(obs.state)[:, None, :] # 投影状态并添加序列维度tokens.append(state_token) # 添加状态token# 添加状态掩码全为1表示这个状态token是有效的input_mask.append(jnp.ones((obs.state.shape[0], 1), dtypejnp.bool_)) 并设置为单向注意力(True)表明图像和语言输入不能关注状态信息因为image/language do not attend to state or actions # 图像/语言输入不关注状态或动作自回归ar_mask [True] 时间步嵌入使用正弦-余弦位置编码生成时间步嵌入 # 使用正弦余弦位置编码嵌入时间步敏感度范围为[0, 1]time_emb posemb_sincos(timestep, self.action_in_proj.out_features, min_period4e-3, max_period4.0) 动作和时间信息融合 # 混合时间步动作信息使用MLPaction_tokens self.action_in_proj(noisy_actions) # 投影带噪声的动作# 重复时间嵌入以匹配动作序列长度time_tokens einops.repeat(time_emb, b emb - b s emb, sself.action_horizon)# 连接动作和时间tokenaction_time_tokens jnp.concatenate([action_tokens, time_tokens], axis-1) MLP处理 使用两层MLP和swish激活函数对「动作和时间的组合表示」进行非线性变换以进一步融合动作和时间信息 # 通过MLP处理action_time_tokens self.action_time_mlp_in(action_time_tokens) # 输入层action_time_tokens nnx.swish(action_time_tokens) # Swish激活函数action_time_tokens self.action_time_mlp_out(action_time_tokens) # 输出层 注意力掩码设置 第一个动作token设置为单向注意力「上面说过了的单向注意力用ar_mask true表示」其余动作tokens之间设置为双向注意力 # 添加动作时间tokentokens.append(action_time_tokens)# 添加掩码全为1表示所有动作token都是有效的input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtypejnp.bool_)) # 图像/语言/状态输入不关注动作token动作第一个是自回归的——单向其余不是——双向ar_mask [True] ([False] * (self.action_horizon - 1)) 最后连接所有token和掩码 # 连接所有token和掩码tokens jnp.concatenate(tokens, axis1) # 在序列维度上连接tokeninput_mask jnp.concatenate(input_mask, axis1) # 在序列维度上连接输入掩码ar_mask jnp.array(ar_mask) # 转换自回归掩码为数组return tokens, input_mask, ar_mask # 返回token、输入掩码和自回归掩码 2.1.4.3 损失函数 compute_loss 实现了扩散模型的训练损失计算 对输入观察进行预处理其中 preprocess_rng用于观察预处理(比如图像增强等) noise_rng用于生成噪声 time_rng用于从beta分布采样时间步 def compute_loss(self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool False) - at.Float[at.Array, Any]:计算扩散模型的损失函数# 分割随机数生成器为三部分用于不同的随机操作preprocess_rng, noise_rng, time_rng jax.random.split(rng, 3) 生成随机噪声并采样时间点 t # 获取动作的批次形状batch_shape actions.shape[:-2]# 生成与动作相同形状的高斯噪声noise jax.random.normal(noise_rng, actions.shape)# 从Beta分布采样时间点范围为[0.001, 1]Beta(1.5, 1)偏向较低的值time jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 0.001# 扩展时间维度以匹配动作形状time_expanded time[..., None, None] 创建带噪动作序列 x_t相当于x_t是噪声化的动作随着时间从0到1原始动作逐渐加噪变为纯噪声 而u_t代表所加的真实噪声而咱们就是要预测所添加的噪声(而所添加的噪声即等于加满噪声的动作 - 原始动作) # 创建带噪声的动作t*noise (1-t)*actionsx_t time_expanded * noise (1 - time_expanded) * actions# 计算真实噪声减去动作的差异这是模型需要预测的目标u_t noise - actions 扩散策略diffusion policy的灵感来源于图像生成中的扩散模型DDPM通过逐步去除噪声来生成目标数据(比如机器人的动作序列)如果对DDPM原理不太明白的详见此文《图像生成发展起源从VAE、扩散模型DDPM、DDIM到DETR、ViT、Swin transformer》嵌入前缀和后缀 # 一次性前向传递前缀后缀# 嵌入前缀图像和文本prefix_tokens, prefix_mask, prefix_ar_mask self.embed_prefix(observation)# 嵌入后缀状态和带噪声的动作suffix_tokens, suffix_mask, suffix_ar_mask self.embed_suffix(observation, x_t, time) 构建注意力掩码和位置编码 根据下图 可得 # 连接掩码通过链接前缀和后缀的掩码从而创建完整的输入掩码input_mask jnp.concatenate([prefix_mask, suffix_mask], axis1)ar_mask jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis0)# 创建注意力掩码make_attn_mask从而控制不同token之间的可见性attn_mask make_attn_mask(input_mask, ar_mask)# 计算位置编码positions jnp.cumsum(input_mask, axis1) - 1 模型前向传播即使用PaliGemma进行推理处理前缀和后缀token 当然了输出中我们只关注与后缀相关的部分因为其中包含了我们想要的动作预测的部分 # 通过PaLI-Gemma模型处理token_, suffix_out self.PaliGemma.llm([prefix_tokens, suffix_tokens], maskattn_mask, positionspositions) 预测噪声v_t # 将模型输出投影回动作空间v_t self.action_out_proj(suffix_out[:, -self.action_horizon :]) 计算预测噪声与实际噪声间的均方误差 # 返回预测噪声和真实噪声之间的均方误差return jnp.mean(jnp.square(v_t - u_t), axis-1) 2.1.4.4 推理函数 sample_actions基于扩散模型逆向采样生成机器人动作序列 sample_actions函数是Pi0模型的核心推理方法实现了基于扩散模型的逆向采样过程——说白了 就是去噪它从纯噪声开始通过多步骤逐渐去噪最终生成符合条件分布的机器人动作序列 函数的核心是一个基于while循环的迭代过程每一步都使用训练好的神经网络预测从当前噪声化动作到目标动作的方向——从噪声到目标的方向 代表速度场毕竟咱们去噪的方向得对 不然就去歪了 总之这个函数将观察数据图像和可选的文本提示转换为具体的动作轨迹是模型部署时的主要接口简言之其包含以下流程 首先从纯噪声开始 (t1)通过重复迭代降噪步骤逐步将噪声转化为有意义的动作序列使用KV缓存优化推理速度实现了一个迭代降噪过程最终返回完全降噪后的动作序列 x_0 具体而言包含如下步骤 第一初始化 首先函数对输入观察数据进行预处理包括标准化图像大小等操作 def sample_actions(self,rng: at.KeyArrayLike, # 随机数生成器observation: _model.Observation, # 观察输入包含图像和文本等*,num_steps: int 10, # 扩散过程的步数默认为10步 ) - _model.Actions: # 返回生成的动作序列# 对观察数据进行预处理不进行训练时的数据增强observation _model.preprocess_observation(None, observation, trainFalse) 然后设置时间步长dt为负值因为我们是从t1向t0方向演化生成初始随机噪声作为起点且时间上约定t1是噪声t0是目标分布这是扩散文献中常见的约定不过与Pi0论文相反 # 注意这里使用扩散模型文献中更常见的约定t1是噪声t0是目标分布# 这与pi0论文相反dt -1.0 / num_steps # 计算时间步长从1到0batch_size observation.state.shape[0] # 获取批次大小# 生成初始噪声形状为[批次大小, 动作序列长度, 动作维度]noise jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim)) 第二Key-Value缓存初始化(预计算并存储前缀表示减少冗余计算) 处理观察数据得到前缀表示和相关掩码 # 首先通过前缀的前向传递填充KV缓存# 获取前缀的token表示和掩码prefix_tokens, prefix_mask, prefix_ar_mask self.embed_prefix(observation)# 创建前缀的注意力掩码prefix_attn_mask make_attn_mask(prefix_mask, prefix_ar_mask)# 计算位置编码positions jnp.cumsum(prefix_mask, axis1) - 1 然后使用PaliGemma语言模型进行一次前向传递生成Key-Value缓存kv_cache——这是一个性能优化因为前缀部分在整个采样过程中保持不变预先计算并缓存它们的表示可以避免重复计算 # 进行前向传递获取KV缓存_, kv_cache self.PaliGemma.llm([prefix_tokens, None], maskprefix_attn_mask, positionspositions) 第三通过step函数构建注意力掩码系统并让PaliGemma做推理 核心迭代通过 jax.lax.while_loop 实现 根据源码 可知该class Pi0(_model.BaseModel)类的最后两行是 # 使用while循环进行迭代采样从t1噪声开始x_0, _ jax.lax.while_loop(cond, step, (noise, 1.0))# 返回最终的去噪结果生成的动作序列return x_0 具体而言包含 step 函数和 cond 函数其中step 函数是每次迭代的核心 首先step函数通过 embed_suffix 处理当前状态包括状态信息嵌入、噪声化动作、时间步编码 def step(carry):定义单步去噪函数x_t, time carry # carry数组包含当前状态和时间# 将时间广播到批次维度并嵌入后缀状态和动作suffix_tokens, suffix_mask, suffix_ar_mask self.embed_suffix(observation, x_t, jnp.broadcast_to(time, batch_size)) 其次构建复杂的注意力掩码系统处理前缀-后缀之间的注意力关系——这个复杂的掩码系统允许后缀token包括状态和动作有选择地关注前缀token图像和文本实现了条件生成具体而言其构建了三层注意力掩码 后缀内部注意力掩码控制后缀token状态和动作之间的注意力关系 # 创建后缀内部的注意力掩码形状为(批次, 后缀长度, 后缀长度)suffix_attn_mask make_attn_mask(suffix_mask, suffix_ar_mask) 前缀-后缀注意力掩码控制后缀token如何关注前缀token图像和文本输入 # 创建后缀对前缀的注意力掩码形状为(批次, 后缀长度, 前缀长度)prefix_attn_mask einops.repeat(prefix_mask, b p - b s p, ssuffix_tokens.shape[1]) 完整注意力掩码将前两个掩码组合形成完整的注意力控制机制 # 组合掩码形状为(批次, 后缀长度, 前缀长度后缀长度)# 控制后缀token生成查询如何关注完整序列生成键和值full_attn_mask jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis-1) 当然了过程中还做了形状检查确保张量维度正确 # 验证掩码形状正确assert full_attn_mask.shape (batch_size,suffix_tokens.shape[1],prefix_tokens.shape[1] suffix_tokens.shape[1],) 接着计算位置编码为后缀token计算其在完整序列中的位置这对于Transformer模型理解序列顺序很重要 # 计算后缀token的位置编码positions jnp.sum(prefix_mask, axis-1)[:, None] jnp.cumsum(suffix_mask, axis-1) - 1 之后模型推理使用PaliGemma语言模型进行推理利用缓存的前缀信息kv_cache提高效率 # 使用KV缓存进行高效的前向传递(prefix_out, suffix_out), _ self.PaliGemma.llm([None, suffix_tokens], maskfull_attn_mask, positionspositions, kv_cachekv_cache)# 且确保前缀输出为None因为使用了KV缓存assert prefix_out is None 第四step函数中做最后的速度预测与动作更新(去噪) 在每一步中模型预测速度场 v_t从噪声到目标的方向并通过类欧拉法更新动作表示——使用简单而有效的欧拉方法进行轨迹采样 具体而言 一方面提取模型输出并预测速度场v_t——相当于本质是通过PaliGemma模型预测去噪方向 v_t # 预测噪声v_t self.action_out_proj(suffix_out[:, -self.action_horizon :]) 二方面使用欧拉法更新动作状态和时间步 # 使用欧拉方法更新状态和时间return x_t dt * v_t, time dt 至于cond函数确定何时停止迭代通过检查时间是否接近零(当然要考虑浮点精读可能存在的误差) def cond(carry):定义循环终止条件x_t, time carry# 考虑浮点误差当时间接近0时停止return time -dt / 2 // 待更
http://www.yingshimen.cn/news/129375/

相关文章:

  • 前端做项目有哪些网站百度智能云windows系统服务器建站
  • 网站运营难吗一条龙网站建设
  • 做一元云购网站电商网站的建设的主要目的
  • 长宁怎么做网站优化好网站管理程序
  • 阳江网站建设 公司价格建设网站都需要准备什么
  • 长沙长沙建设网站桂林北站到象鼻山景区怎么坐车
  • 好动词做的网站能行吗土木工程网官网登录
  • 正规网站建设空间哪个好建设 网站协议范本
  • 上海专业网站建设服ps网站切图教程
  • php网站文件夹结构关键的近义词
  • 网站咋建立pc做网站服务器
  • 网站规划书市场分析怎么推销自己的网站
  • 网站建设 统一标准体系小程序推广计划怎么做
  • 企业网站推广方案手机wordpress加载图片慢
  • 集团网站建设计划表百度wordpress 抓取
  • 嘉兴房地产网站建设建e网app
  • 有找专业做淘宝网站的美工郑州电力高等专科学校怎么样
  • 有没有和小孩做的网站公司网站建设中心
  • 福州cms模板建站德赞网站建设网站制作
  • 企业网络营销分析seo伪原创工具
  • 徐州营销网站建设报价wordpress邮箱头像
  • 网站 php .netword和the wordpress
  • ps为什么做不了视频网站蚌埠市网站建设公司
  • 百度推广必须做手机网站吗制作微信小程序怎么赚钱
  • 北京大龙建设集团有限公司网站首页wordpress国外主题加载慢
  • 建设学校网站天津网络推广网站建设公司
  • dede网站地图怎么做网站域名是啥
  • 做商业地产常用的网站logo设计网站排行榜
  • 免费网站制作成品网站开发工程师的职务
  • 批量 网站标题北京国税局网站做票种核定时