DiT4DiT 昇腾NPU适配简记

发表于 16 小时前  10 次阅读


文章目录

DiT4DiT 是港科广团队在 StarVLA 基础上提出的视觉-动作模型(VAM),将视频生成 Diffusion Transformer 与 Flow Matching 动作预测结合,支持机械臂灵巧操作与人形机器人全身控制。原始代码基于 CUDA 生态,本文记录将其适配到华为昇腾 Atlas 800T A3 NPU 的过程。

一、项目关联:DiT4DiT 与 StarVLA

两个开源仓库的关系如下:

StarVLADiT4DiT
定位VLA 基础框架基于 StarVLA 的扩展
基座模型Qwen3-VL-4B-InstructCosmos-Predict2.5-2B
动作预测标准 diffusion headFlow-matching action head
视频模型视频 diffusion + VAE 编解码
华为适配✅ 已有 NPU 版(DrivingSDK)❌ 暂无
仓库地址gitcode.com/Ascend/DrivingSDKgithub.com/Mondo-Robotics/DiT4DiT

DiT4DiT 复用了 StarVLA 的 dataloader(LeRobot 格式)、accelerate + DeepSpeed 训练框架以及 DiT-B 动作头,但将视觉编码器从 Qwen3-VL 替换为 Cosmos-Predict2.5-2B(NVIDIA 的视频生成扩散模型),并增加了 Flow Matching 动作预测 + 视频辅助损失。

适配思路:参照华为 DrivingSDK 中已跑通的 StarVLA NPU 版(starvla.patch),对 DiT4DiT 做相同的 NPU 移植操作。

二、环境准备

2.1 软件版本

软件版本
CANN9.0.0
PyTorch2.7.1
torch_npu随 CANN 附赠
DeepSpeed0.16.9
accelerate1.12.0
transformers4.57.6
diffusersgit main(Cosmos2.5 支持)

2.2 模型权重下载

huggingface-cli download nvidia/Cosmos-Predict2.5-2B \
  --revision diffusers/base/post-trained \
  --local-dir playground/Pretrained_models/Cosmos-Predict2.5-2B

2.3 LD_LIBRARY_PATH

conda 环境下 torch_npu 的 .so 文件不会自动加入动态库搜索路径,需在启动脚本中显式设置:

export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH

也可写入 conda activate 钩子实现自动生效:

mkdir -p /opt/conda/envs/torch2.7.1/etc/conda/activate.d
cat > /opt/conda/envs/torch2.7.1/etc/conda/activate.d/ld_path.sh << 'EOF'
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH
EOF

三、代码适配

适配只需修改两个文件,核心改动点如下:

3.1 train.py(训练入口)

  • torch_npu + patcher 导入:在 import torch.distributed 后插入
  • autocast 替换torch.autocast("cuda")torch.autocast("npu")
  • AdamW 融合优化:加 fused=True
  • patcher 初始化__main__ 入口处调用 default_patcher_builder.build().__enter__()
# === 改动 1: import 区域 ===
import torch
import torch.distributed as dist

import warnings

try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
except ImportError as e:
    warnings.warn(f"Failed to import torch_npu or its submodule: {e}", ImportWarning)

try:
    from mx_driving.patcher import default_patcher_builder
except ImportError as e:
    warnings.warn(f"Failed to import from mx_driving.patcher: {e}", ImportWarning)

warnings.filterwarnings("ignore", category=DeprecationWarning, module="pandas")

# === 改动 2: AdamW fused=True ===
optimizer = torch.optim.AdamW(
    param_groups,
    ...
    fused=True,
)

# === 改动 3: autocast cuda → npu ===
# torch.autocast("cuda", dtype=torch.bfloat16)
torch.autocast("npu", dtype=torch.bfloat16)

# === 改动 4: __main__ 入口 ===
if __name__ == "__main__":
    patcher = default_patcher_builder.build()
    patcher.__enter__()

3.2 DiT4DiT.py(模型框架)

该文件有 4 处 torch.autocast("cuda"),全部替换为 torch.autocast("npu")

  • forward() 中 backbone 前向(bf16)
  • forward() 中 action 前向(fp32)
  • predict_action() 中 backbone 前向(bf16)
  • predict_action() 中 action 解码(fp32)

3.3 适配总结

改动项原因
torch_npu import + transfer_to_npu替换 torch.cuda.* 为 torch.npu.*
default_patcher_buildermx_driving 的 NPU monkey patch(mmcv/torch/numpy)
autocast("npu")NPU 混合精度(对应 CUDA 的 amp)
AdamW fused=TrueNPU 融合优化器(减少 kernel launch 开销)
LD_LIBRARY_PATHlibc10.so、libtorch_npu.so 动态库路径

注意:DiT4DiT 使用的 default_patcher_builder API 与 StarVLA 原始 starvla.patch 中的 Patcher().add(TransformersNPU) 不同,原因是新版本 mx_driving 将 Patcher 从无参构造改为了 PatcherBuilder 模式,TransformersNPU 类也已移除(功能整合进了 default_patcher_builder)。

四、8卡启动脚本

完整的 train_8p.sh:

#!/bin/bash
export WANDB_MODE=offline
export WANDB_OFFLINE=true
export TASK_QUEUE_ENABLE=2
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
#export TORCH_HCCL_ZERO_COPY=1  # A2 需注释
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH

num_processes=8
max_train_steps=150000
per_device_batch_size=4
data_mix=robotwin
base_model=./playground/Pretrained_models/Cosmos-Predict2.5-2B
config_yaml=./DiT4DiT/config/robotwin/dit4dit_robotwin.yaml
Framework_name=DiT4DiT
freeze_module_list=''
run_root_dir=./results/Checkpoints
action_video_freq_ratio=1

# ... 参数解析(略)

accelerate launch \
  --config_file DiT4DiT/config/deepseeds/deepspeed_zero2.yaml \
  --num_processes ${num_processes} \
  DiT4DiT/training/train.py \
  --config_yaml ${config_yaml} \
  --framework.name ${Framework_name} \
  --framework.cosmos25.base_model ${base_model} \
  --datasets.vla_data.per_device_batch_size ${per_device_batch_size} \
  --datasets.vla_data.data_mix ${data_mix} \
  --trainer.max_train_steps ${max_train_steps} \
  --run_root_dir ${run_root_dir} \
  --run_id ${run_id} \
  2>&1 | tee ${LOG_FILE}

使用方式:

bash train_8p.sh --num_processes=8 --max_train_steps=100 --per_device_batch_size=1

五、训练配置(Robotwin)

新建配置文件 DiT4DiT/config/robotwin/dit4dit_robotwin.yaml,关键参数:

  • base_model: Cosmos-Predict2.5-2B(视频生成扩散模型)
  • data_mix: robotwin(复用 StarVLA 数据集)
  • action_dim/state_dim: 14(双臂 7 关节 × 2)
  • video_delta_indices: [0-7](8 帧视频序列)
  • training: joint(动作 + 视频联合训练)
  • future_loss_type: flow_matching(潜空间流匹配损失)
  • future_action_window_size: 7

六、踩坑记录

6.1 num_processes 必须匹配物理卡数

train_8p.sh 名带 8p,但实际使用 --num_processes=16 会导致 rank 0 和 rank 8 分配到同一张物理卡,HCCL 报错:same physical device ID0 as the rank8。必须保持 --num_processes=8

6.2 mx_driving.patcher API 变更

旧版 StarVLA starvla.patch 使用 Patcher().add(TransformersNPU).apply(),新版 mx_driving 改为了 default_patcher_builder.build().__enter__()TransformersNPU 类已移除,功能整合进了 default_patcher_builder 预设的 mmcv/torch/numpy/mmdet 全套 patch。

6.3 libc10.so / libtorch_npu.so 未自动发现

交互式 python shell 中 import mx_driving 会报 libc10.so: cannot open shared object file,需手动设置 LD_LIBRARY_PATH。训练脚本中 torch_npu 通过 transfer_to_npu 在运行时会正确设置路径,但排查问题时需注意。

6.4 Cosmos-Predict2.5-2B 显存占用

Cosmos2.5 包含 VAE + Transformer + Text Encoder 三部分,完整加载约需 10GB+ 显存。配合 DeepSpeed ZeRO-2 可将优化器状态分片到 8 张卡,但首卡仍需承载完整模型权重(约 4.5B 参数)。实际训练需根据卡显存(A3 单卡约 64GB HBM)调整 per_device_batch_size。

简记。

本站文章基于国际协议BY-NA-SA 4.0协议共享;
如未特殊说明,本站文章皆为原创文章,请规范转载。

0

scanz个人博客