- 一、项目关联:DiT4DiT 与 StarVLA
- 二、环境准备
- 2.1 软件版本
- 2.2 基础环境搭建
- 2.3 安装 DiT4DiT 依赖
- 2.4 模型权重下载
- 2.5 LD_LIBRARY_PATH
- 三、代码适配
- 3.1 train.py(训练入口)
- 3.2 DiT4DiT.py(模型框架)
- 3.3 Cosmos25.py(backbone)
- 3.4 train.py eval_action_model
- 四、8卡启动脚本
- 五、训练配置(Robotwin)
- 六、踩坑记录
- 6.2 mx_driving.patcher API 变更
- 6.3 libc10.so / libtorch_npu.so 未自动发现
- 6.4 Cosmos-Predict2.5-2B 显存占用
- 6.5 VAE / Text Encoder 混合 dtype → ZeRO-3 defragment 报错
- 6.6 eval_action_model 维度不匹配 + bf16 → numpy
- 6.7 @torch.inference_mode() 与 ZeRO-3 冲突
- 6.8 ZeRO-3 模型保存卡死
- 七、ZeRO-2 vs ZeRO-3 配置与性能对比
- 八、适配 Patch
DiT4DiT 是港科广团队在 StarVLA 基础上提出的视觉-动作模型(VAM),将视频生成 Diffusion Transformer 与 Flow Matching 动作预测结合,支持机械臂灵巧操作与人形机器人全身控制。原始代码基于 CUDA 生态,本文记录将其适配到华为昇腾 Atlas 800T A3 NPU 的全过程,包括 ZeRO-2/ZeRO-3 两种分布式策略的配置与踩坑,最终在 RoboTwin 双臂数据集上成功跑通训练。
Git 仓库版本:commit 1ae6efd(github.com/Mondo-Robotics/DiT4DiT),适配 patch 已上传至服务器,文末附下载链接。
一、项目关联:DiT4DiT 与 StarVLA
两个开源仓库的关系如下:
| StarVLA | DiT4DiT | |
|---|---|---|
| 定位 | VLA 基础框架 | 基于 StarVLA 的扩展 |
| 基座模型 | Qwen3-VL-4B-Instruct | Cosmos-Predict2.5-2B |
| 动作预测 | 标准 diffusion head | Flow-matching action head |
| 视频模型 | 无 | 视频 diffusion + VAE 编解码 |
| 华为适配 | ✅ 已有 NPU 版(DrivingSDK) | ❌ 暂无 |
| 仓库地址 | gitcode.com/Ascend/DrivingSDK | github.com/Mondo-Robotics/DiT4DiT |
DiT4DiT 复用了 StarVLA 的 dataloader(LeRobot 格式)、accelerate + DeepSpeed 训练框架以及 DiT-B 动作头,但将视觉编码器从 Qwen3-VL 替换为 Cosmos-Predict2.5-2B(NVIDIA 的视频生成扩散模型),并增加了 Flow Matching 动作预测 + 视频辅助损失。
适配思路:参照华为 DrivingSDK 中已跑通的 StarVLA NPU 版(starvla.patch),对 DiT4DiT 做相同的 NPU 移植操作。本环境直接复用 StarVLA 的 conda 环境,无需额外安装基础依赖,只需下载 DiT4DiT 源码 + 权重即可。
二、环境准备
2.1 软件版本
以下版本已在 Atlas 800T A3 上验证通过:
| 软件 | 版本 |
|---|---|
| Python | 3.10 |
| CANN | 9.0.0 |
| PyTorch | 2.7.1 |
| torch_npu | 2.7.1.post2 |
| DeepSpeed | 0.18.4 |
| accelerate | 1.12.0 |
| diffusers | 0.38.0 |
| transformers | 4.57.0 |
| decord | 0.6.0 |
| mx_driving | 1.0.20260421 |
2.2 基础环境搭建
参考 StarVLA for PyTorch(DrivingSDK) 的环境准备步骤。
安装依赖:
apt install -y libgl1-mesa-glx libglib2.0-0
创建 conda 环境:
conda create -n dit4dit python=3.10
conda activate dit4dit
安装 PyTorch + torch_npu:
pip install torch==2.7.1
pip install torch_npu==2.7.1.post2
pip install torchvision==0.22.1
安装 ffmpeg:
推荐 conda 安装(会自动处理依赖):
conda install -c conda-forge ffmpeg=4.4.2
或源码安装:
wget https://ffmpeg.org/releases/ffmpeg-4.4.2.tar.bz2
tar -xvf ffmpeg-4.4.2.tar.bz2
cd ffmpeg-4.4.2
./configure --enable-shared --prefix=/usr/local/ffmpeg
make -j 64
make install
cd ..
echo 'export PATH="/usr/local/ffmpeg/bin:$PATH"' >> /etc/profile.d/ffmpeg.sh
echo 'export LD_LIBRARY_PATH="/usr/local/ffmpeg/lib:$LD_LIBRARY_PATH"' >> /etc/profile.d/ffmpeg.sh
source /etc/profile
安装 decord(视频解码库):
git clone --recursive https://github.com/dmlc/decord --depth 1
cd decord
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DFFMPEG_DIR:PATH=$CONDA_PREFIX
make
cd ../python
python setup.py sdist bdist_wheel
cd ../..
pip install decord/python/dist/decord-0.6.0-cp310-cp310-linux_aarch64.whl
安装 mx_driving(NPU patcher,关键!):
# mx_driving 从 DrivingSDK 获取
# 参考 https://gitcode.com/Ascend/DrivingSDK
2.3 安装 DiT4DiT 依赖
下载 DiT4DiT 源码后(commit 1ae6efd),安装 requirements:
cd DiT4DiT
pip install -r requirements.txt
pip install -e .
完整 requirements.txt 如下(基于 NPU 环境实际 pip list 整理):
absl-py==2.3.1
accelerate==1.12.0
albucore==0.0.17
albumentations==1.4.18
av==12.3.0
certifi==2026.1.4
charset-normalizer==3.4.4
click==8.3.1
contourpy==1.3.2
cramjam==2.11.0
cycler==0.12.1
deepspeed==0.18.4
diffusers==0.38.0
docstring_parser==0.17.0
einops==0.8.1
einx==0.3.0
eval_type_backport==0.3.1
fastparquet==2024.11.0
filelock==3.20.3
fonttools==4.61.1
frozendict==2.4.7
fsspec==2026.1.0
fvcore==0.1.5.post20221221
gitdb==4.0.12
GitPython==3.1.46
greenlet==3.3.0
grpcio==1.76.0
hf-xet==1.2.0
hjson==3.1.0
huggingface-hub==0.36.0
hyper-connections==0.4.6
idna==3.11
imageio==2.37.0
imageio-ffmpeg==0.6.0
importlib_metadata==8.7.1
iopath==0.1.10
Jinja2==3.1.6
joblib==1.5.3
kiwisolver==1.4.9
lazy_loader==0.4
Markdown==3.10
markdown-it-py==4.0.0
MarkupSafe==3.0.3
matplotlib==3.10.8
mdurl==0.1.2
mpmath==1.3.0
msgpack==1.1.2
networkx==3.4.2
ninja==1.13.0
nltk==3.9.1
numpy==1.26.4
numpydantic==1.6.9
omegaconf==2.3.0
opencv-python==4.10.0.84
opencv-python-headless==4.11.0.86
packaging==25.0
pandas==2.3.3
peft==0.18.1
pillow==11.1.0
pipablepytorch3d==0.7.6
portalocker==3.2.0
protobuf==6.33.4
psutil==7.2.1
pyarrow==14.0.1
pydantic==2.10.6
Pygments==2.19.2
pyparsing==3.3.1
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.3
qwen-vl-utils==0.0.14
regex==2026.1.15
requests==2.32.5
rich==14.2.0
safetensors==0.5.3
scikit-image==0.25.2
scipy==1.15.3
sentencepiece==0.2.0
sentry-sdk==2.49.0
six==1.17.0
smmap==5.0.2
sympy==1.13.3
tabulate==0.9.0
tensorboard==2.20.0
tensorboard-data-server==0.7.2
tifffile==2025.5.10
tiktoken==0.12.0
timm==1.0.24
tokenizers==0.22.2
torch==2.7.1
torch-npu==2.7.1.post2
torchvision==0.22.1
tqdm==4.66.5
transformers==4.57.0
typing_extensions==4.15.0
tyro==1.0.5
urllib3==2.6.3
wandb==0.24.0
websockets==16.0
Werkzeug==3.1.5
yacs==0.1.8
zipp==3.23.0
注意:decord 需要源码编译安装(非 pip),ffmpeg 需提前安装。mx_driving 需从 DrivingSDK 获取。
2.4 模型权重下载
huggingface-cli download nvidia/Cosmos-Predict2.5-2B \
--revision diffusers/base/post-trained \
--local-dir playground/Pretrained_models/Cosmos-Predict2.5-2B
2.5 LD_LIBRARY_PATH
huggingface-cli download nvidia/Cosmos-Predict2.5-2B \
--revision diffusers/base/post-trained \
--local-dir playground/Pretrained_models/Cosmos-Predict2.5-2Bconda 环境下 torch_npu 的 .so 文件不会自动加入动态库搜索路径,需在启动脚本中显式设置:
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH
也可写入 conda activate 钩子实现自动生效:
mkdir -p /opt/conda/envs/torch2.7.1/etc/conda/activate.d
cat > /opt/conda/envs/torch2.7.1/etc/conda/activate.d/ld_path.sh << 'EOF'
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH
EOF
三、代码适配
适配只需修改两个文件,核心改动点如下:
3.1 train.py(训练入口)
- torch_npu + patcher 导入:在
import torch.distributed 后插入
- autocast 替换:
torch.autocast("cuda") → torch.autocast("npu")
- AdamW 融合优化:加
fused=True
- patcher 初始化:
__main__ 入口处调用 default_patcher_builder.build().__enter__()
- DeepSpeedPlugin 去硬编码:删除
hf_ds_config 参数,改为空构造
3.2 DiT4DiT.py(模型框架)
- 4 处
torch.autocast("cuda") → torch.autocast("npu")
action_mask 截断对齐(修复 dim 8 vs 16 不匹配)
predict_action 中 bf16 → float() → numpy()(NPU bf16 不支持 numpy)
@torch.inference_mode() → @torch.no_grad()(ZeRO-3 兼容性)
3.3 Cosmos25.py(backbone)
- 冻结 VAE + Text Encoder 参数(
requires_grad=False,解决 ZeRO-3 混合 dtype 报错 + 省显存)
3.4 train.py eval_action_model
actions 和 action_mask 截断对齐 action_horizon(eval 时同样有 dim 不匹配问题)
四、8卡启动脚本
import torch.distributed 后插入torch.autocast("cuda") → torch.autocast("npu")fused=True__main__ 入口处调用 default_patcher_builder.build().__enter__()hf_ds_config 参数,改为空构造- 4 处
torch.autocast("cuda")→torch.autocast("npu") action_mask截断对齐(修复 dim 8 vs 16 不匹配)predict_action中bf16 → float() → numpy()(NPU bf16 不支持 numpy)@torch.inference_mode()→@torch.no_grad()(ZeRO-3 兼容性)
3.3 Cosmos25.py(backbone)
- 冻结 VAE + Text Encoder 参数(
requires_grad=False,解决 ZeRO-3 混合 dtype 报错 + 省显存)
3.4 train.py eval_action_model
actions 和 action_mask 截断对齐 action_horizon(eval 时同样有 dim 不匹配问题)
四、8卡启动脚本
requires_grad=False,解决 ZeRO-3 混合 dtype 报错 + 省显存)actions和action_mask截断对齐 action_horizon(eval 时同样有 dim 不匹配问题)
四、8卡启动脚本
完整的 train_8p.sh:
#!/bin/bash
export WANDB_MODE=offline
export WANDB_OFFLINE=true
export TASK_QUEUE_ENABLE=2
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH
num_processes=8
max_train_steps=150000
per_device_batch_size=2
data_mix=robotwin
base_model=./playground/Pretrained_models/Cosmos-Predict2.5-2B
config_yaml=./DiT4DiT/config/robotwin/dit4dit_robotwin.yaml
Framework_name=DiT4DiT
run_root_dir=./results/Checkpoints
accelerate launch \
--config_file DiT4DiT/config/deepseeds/deepspeed_zero2.yaml \
--num_processes ${num_processes} \
DiT4DiT/training/train.py \
--config_yaml ${config_yaml} \
--framework.name ${Framework_name} \
--framework.cosmos25.base_model ${base_model} \
--datasets.vla_data.per_device_batch_size ${per_device_batch_size} \
--datasets.vla_data.data_mix ${data_mix} \
--trainer.max_train_steps ${max_train_steps} \
--run_root_dir ${run_root_dir} \
--run_id ${run_id} \
2>&1 | tee ${LOG_FILE}
使用方式:
bash train_8p.sh --num_processes=8 --max_train_steps=100 --per_device_batch_size=2
五、训练配置(Robotwin)
新建配置文件 DiT4DiT/config/robotwin/dit4dit_robotwin.yaml,关键参数:
- base_model: Cosmos-Predict2.5-2B
- data_mix: robotwin
- action_dim/state_dim: 14(双臂 7 关节 × 2)
- training: action(仅动作训练,不含视频 loss)
- future_loss_type: flow_matching
- future_action_window_size: 7
六、踩坑记录
6.2 mx_driving.patcher API 变更
旧版 StarVLA 的 starvla.patch 使用 Patcher().add(TransformersNPU).apply(),新版 mx_driving 改为了 default_patcher_builder.build().__enter__()。TransformersNPU 类已移除,功能整合进 default_patcher_builder 预设的 mmcv/torch/numpy/mmdet 全套 patch。
6.3 libc10.so / libtorch_npu.so 未自动发现
需手动设置 LD_LIBRARY_PATH 或写入 conda activate 钩子(见 2.5 节)。
6.4 Cosmos-Predict2.5-2B 显存占用
Cosmos2.5 包含 VAE + Transformer + Text Encoder 三部分。VAE 和 Text Encoder 冻结后(requires_grad=False),实际训练参数约 2.2B(Transformer + Action Head)。
6.5 VAE / Text Encoder 混合 dtype → ZeRO-3 defragment 报错
VAE 和 Text Encoder 权重为 fp32,Transformer 为 bf16。ZeRO-3 初始化时要求所有 trainable 参数 dtype 一致。冻结后解决。
6.6 eval_action_model 维度不匹配 + bf16 → numpy
eval 时 actions 和 action_mask 未截断到 action_horizon(8 步),需要 slice。同时 predict_action 的输出是 bf16,numpy 不支持,需 .float().numpy()。
6.7 @torch.inference_mode() 与 ZeRO-3 冲突
ZeRO-3 的 LinearFunctionForZeroStage3 内部保存 tensor 用于反向,inference_mode 创建的 tensor 无法被保存。改用 @torch.no_grad()。
6.8 ZeRO-3 模型保存卡死
accelerator.get_state_dict() 把 8 卡分片参数 gather 到 rank 0,2.3B 参数耗时数分钟。改用 model.save_checkpoint()(DeepSpeed 原生保存,每卡各自写 shard)。
七、ZeRO-2 vs ZeRO-3 配置与性能对比
| ZeRO-3 | ZeRO-2 | |
|---|---|---|
| 训练速度 | ~3.14s/it | ~0.74s/it(实际训练不含保存) |
| per_device_batch_size | 16 | 2 |
| gradient_accumulation | 1 | 8 |
| 等效 batch | 128 | 128 |
| 数据吞吐量 | 更高(大batch x 大步时间) | 较低 |
| 显存压力 | 较高(参数分片,可开大batch) | 较低(参数完整保留) |
| 模型保存 | save_checkpoint(分片) | 正常 |
| 配置复杂度 | 较高(多个冲突点) | 较低 |
| 推荐场景 | 正式训练(吞吐优先) | 调试/显存不足 |
八、适配 Patch
基于 commit 1ae6efd 生成的两个 patch 文件,已覆盖全部修改:
- dit4dit_zero3.patch:ZeRO-3 适配版(含 NPU 移植 + 全部 Bug 修复)
- dit4dit_zero2.patch:ZeRO-2 适配版(同上 + train.py 去硬编码 + ZeRO-2 配置)
使用方法:
git clone https://github.com/Mondo-Robotics/DiT4DiT.git
cd DiT4DiT
git checkout 1ae6efd
git apply dit4dit_zero3.patch # 或 dit4dit_zero2.patch
pip install -r requirements.txt
pip install -e .
两个 patch 下载:







