DiT4DiT 是港科广团队在 StarVLA 基础上提出的视觉-动作模型(VAM),将视频生成 Diffusion Transformer 与 Flow Matching 动作预测结合,支持机械臂灵巧操作与人形机器人全身控制。原始代码基于 CUDA 生态,本文记录将其适配到华为昇腾 Atlas 800T A3 NPU 的过程。
一、项目关联:DiT4DiT 与 StarVLA
两个开源仓库的关系如下:
| StarVLA | DiT4DiT | |
|---|---|---|
| 定位 | VLA 基础框架 | 基于 StarVLA 的扩展 |
| 基座模型 | Qwen3-VL-4B-Instruct | Cosmos-Predict2.5-2B |
| 动作预测 | 标准 diffusion head | Flow-matching action head |
| 视频模型 | 无 | 视频 diffusion + VAE 编解码 |
| 华为适配 | ✅ 已有 NPU 版(DrivingSDK) | ❌ 暂无 |
| 仓库地址 | gitcode.com/Ascend/DrivingSDK | github.com/Mondo-Robotics/DiT4DiT |
DiT4DiT 复用了 StarVLA 的 dataloader(LeRobot 格式)、accelerate + DeepSpeed 训练框架以及 DiT-B 动作头,但将视觉编码器从 Qwen3-VL 替换为 Cosmos-Predict2.5-2B(NVIDIA 的视频生成扩散模型),并增加了 Flow Matching 动作预测 + 视频辅助损失。
适配思路:参照华为 DrivingSDK 中已跑通的 StarVLA NPU 版(starvla.patch),对 DiT4DiT 做相同的 NPU 移植操作。
二、环境准备
2.1 软件版本
软件 版本
CANN 9.0.0
PyTorch 2.7.1
torch_npu 随 CANN 附赠
DeepSpeed 0.16.9
accelerate 1.12.0
transformers 4.57.6
diffusers git main(Cosmos2.5 支持)
2.2 模型权重下载
huggingface-cli download nvidia/Cosmos-Predict2.5-2B \
--revision diffusers/base/post-trained \
--local-dir playground/Pretrained_models/Cosmos-Predict2.5-2B
2.3 LD_LIBRARY_PATH
| 软件 | 版本 |
|---|---|
| CANN | 9.0.0 |
| PyTorch | 2.7.1 |
| torch_npu | 随 CANN 附赠 |
| DeepSpeed | 0.16.9 |
| accelerate | 1.12.0 |
| transformers | 4.57.6 |
| diffusers | git main(Cosmos2.5 支持) |
huggingface-cli download nvidia/Cosmos-Predict2.5-2B \
--revision diffusers/base/post-trained \
--local-dir playground/Pretrained_models/Cosmos-Predict2.5-2B
2.3 LD_LIBRARY_PATH
conda 环境下 torch_npu 的 .so 文件不会自动加入动态库搜索路径,需在启动脚本中显式设置:
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH
也可写入 conda activate 钩子实现自动生效:
mkdir -p /opt/conda/envs/torch2.7.1/etc/conda/activate.d
cat > /opt/conda/envs/torch2.7.1/etc/conda/activate.d/ld_path.sh << 'EOF'
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH
EOF
三、代码适配
适配只需修改两个文件,核心改动点如下:
3.1 train.py(训练入口)
- torch_npu + patcher 导入:在
import torch.distributed 后插入
- autocast 替换:
torch.autocast("cuda") → torch.autocast("npu")
- AdamW 融合优化:加
fused=True
- patcher 初始化:
__main__ 入口处调用 default_patcher_builder.build().__enter__()
# === 改动 1: import 区域 ===
import torch
import torch.distributed as dist
import warnings
try:
import torch_npu
from torch_npu.contrib import transfer_to_npu
except ImportError as e:
warnings.warn(f"Failed to import torch_npu or its submodule: {e}", ImportWarning)
try:
from mx_driving.patcher import default_patcher_builder
except ImportError as e:
warnings.warn(f"Failed to import from mx_driving.patcher: {e}", ImportWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pandas")
# === 改动 2: AdamW fused=True ===
optimizer = torch.optim.AdamW(
param_groups,
...
fused=True,
)
# === 改动 3: autocast cuda → npu ===
# torch.autocast("cuda", dtype=torch.bfloat16)
torch.autocast("npu", dtype=torch.bfloat16)
# === 改动 4: __main__ 入口 ===
if __name__ == "__main__":
patcher = default_patcher_builder.build()
patcher.__enter__()
3.2 DiT4DiT.py(模型框架)
import torch.distributed 后插入torch.autocast("cuda") → torch.autocast("npu")fused=True__main__ 入口处调用 default_patcher_builder.build().__enter__()# === 改动 1: import 区域 ===
import torch
import torch.distributed as dist
import warnings
try:
import torch_npu
from torch_npu.contrib import transfer_to_npu
except ImportError as e:
warnings.warn(f"Failed to import torch_npu or its submodule: {e}", ImportWarning)
try:
from mx_driving.patcher import default_patcher_builder
except ImportError as e:
warnings.warn(f"Failed to import from mx_driving.patcher: {e}", ImportWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning, module="pandas")
# === 改动 2: AdamW fused=True ===
optimizer = torch.optim.AdamW(
param_groups,
...
fused=True,
)
# === 改动 3: autocast cuda → npu ===
# torch.autocast("cuda", dtype=torch.bfloat16)
torch.autocast("npu", dtype=torch.bfloat16)
# === 改动 4: __main__ 入口 ===
if __name__ == "__main__":
patcher = default_patcher_builder.build()
patcher.__enter__()该文件有 4 处 torch.autocast("cuda"),全部替换为 torch.autocast("npu"):
- forward() 中 backbone 前向(bf16)
- forward() 中 action 前向(fp32)
- predict_action() 中 backbone 前向(bf16)
- predict_action() 中 action 解码(fp32)
3.3 适配总结
改动项 原因
torch_npu import + transfer_to_npu 替换 torch.cuda.* 为 torch.npu.*
default_patcher_builder mx_driving 的 NPU monkey patch(mmcv/torch/numpy)
autocast("npu") NPU 混合精度(对应 CUDA 的 amp)
AdamW fused=True NPU 融合优化器(减少 kernel launch 开销)
LD_LIBRARY_PATH libc10.so、libtorch_npu.so 动态库路径
| 改动项 | 原因 |
|---|---|
| torch_npu import + transfer_to_npu | 替换 torch.cuda.* 为 torch.npu.* |
| default_patcher_builder | mx_driving 的 NPU monkey patch(mmcv/torch/numpy) |
| autocast("npu") | NPU 混合精度(对应 CUDA 的 amp) |
| AdamW fused=True | NPU 融合优化器(减少 kernel launch 开销) |
| LD_LIBRARY_PATH | libc10.so、libtorch_npu.so 动态库路径 |
注意:DiT4DiT 使用的 default_patcher_builder API 与 StarVLA 原始 starvla.patch 中的 Patcher().add(TransformersNPU) 不同,原因是新版本 mx_driving 将 Patcher 从无参构造改为了 PatcherBuilder 模式,TransformersNPU 类也已移除(功能整合进了 default_patcher_builder)。
四、8卡启动脚本
完整的 train_8p.sh:
#!/bin/bash
export WANDB_MODE=offline
export WANDB_OFFLINE=true
export TASK_QUEUE_ENABLE=2
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
#export TORCH_HCCL_ZERO_COPY=1 # A2 需注释
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH
num_processes=8
max_train_steps=150000
per_device_batch_size=4
data_mix=robotwin
base_model=./playground/Pretrained_models/Cosmos-Predict2.5-2B
config_yaml=./DiT4DiT/config/robotwin/dit4dit_robotwin.yaml
Framework_name=DiT4DiT
freeze_module_list=''
run_root_dir=./results/Checkpoints
action_video_freq_ratio=1
# ... 参数解析(略)
accelerate launch \
--config_file DiT4DiT/config/deepseeds/deepspeed_zero2.yaml \
--num_processes ${num_processes} \
DiT4DiT/training/train.py \
--config_yaml ${config_yaml} \
--framework.name ${Framework_name} \
--framework.cosmos25.base_model ${base_model} \
--datasets.vla_data.per_device_batch_size ${per_device_batch_size} \
--datasets.vla_data.data_mix ${data_mix} \
--trainer.max_train_steps ${max_train_steps} \
--run_root_dir ${run_root_dir} \
--run_id ${run_id} \
2>&1 | tee ${LOG_FILE}
使用方式:
bash train_8p.sh --num_processes=8 --max_train_steps=100 --per_device_batch_size=1
五、训练配置(Robotwin)
新建配置文件 DiT4DiT/config/robotwin/dit4dit_robotwin.yaml,关键参数:
- base_model: Cosmos-Predict2.5-2B(视频生成扩散模型)
- data_mix: robotwin(复用 StarVLA 数据集)
- action_dim/state_dim: 14(双臂 7 关节 × 2)
- video_delta_indices: [0-7](8 帧视频序列)
- training: joint(动作 + 视频联合训练)
- future_loss_type: flow_matching(潜空间流匹配损失)
- future_action_window_size: 7
六、踩坑记录
6.1 num_processes 必须匹配物理卡数
train_8p.sh 名带 8p,但实际使用 --num_processes=16 会导致 rank 0 和 rank 8 分配到同一张物理卡,HCCL 报错:same physical device ID0 as the rank8。必须保持 --num_processes=8。
6.2 mx_driving.patcher API 变更
旧版 StarVLA starvla.patch 使用 Patcher().add(TransformersNPU).apply(),新版 mx_driving 改为了 default_patcher_builder.build().__enter__()。TransformersNPU 类已移除,功能整合进了 default_patcher_builder 预设的 mmcv/torch/numpy/mmdet 全套 patch。
6.3 libc10.so / libtorch_npu.so 未自动发现
交互式 python shell 中 import mx_driving 会报 libc10.so: cannot open shared object file,需手动设置 LD_LIBRARY_PATH。训练脚本中 torch_npu 通过 transfer_to_npu 在运行时会正确设置路径,但排查问题时需注意。
6.4 Cosmos-Predict2.5-2B 显存占用
Cosmos2.5 包含 VAE + Transformer + Text Encoder 三部分,完整加载约需 10GB+ 显存。配合 DeepSpeed ZeRO-2 可将优化器状态分片到 8 张卡,但首卡仍需承载完整模型权重(约 4.5B 参数)。实际训练需根据卡显存(A3 单卡约 64GB HBM)调整 per_device_batch_size。
简记。







