DiT4DiT 昇腾NPU适配简记

发表于 12 小时前  14 次阅读


文章目录

DiT4DiT 是港科广团队在 StarVLA 基础上提出的视觉-动作模型(VAM),将视频生成 Diffusion Transformer 与 Flow Matching 动作预测结合,支持机械臂灵巧操作与人形机器人全身控制。原始代码基于 CUDA 生态,本文记录将其适配到华为昇腾 Atlas 800T A3 NPU 的全过程,包括 ZeRO-2/ZeRO-3 两种分布式策略的配置与踩坑,最终在 RoboTwin 双臂数据集上成功跑通训练。

Git 仓库版本:commit 1ae6efdgithub.com/Mondo-Robotics/DiT4DiT),适配 patch 已上传至服务器,文末附下载链接。

一、项目关联:DiT4DiT 与 StarVLA

两个开源仓库的关系如下:

StarVLADiT4DiT
定位VLA 基础框架基于 StarVLA 的扩展
基座模型Qwen3-VL-4B-InstructCosmos-Predict2.5-2B
动作预测标准 diffusion headFlow-matching action head
视频模型视频 diffusion + VAE 编解码
华为适配✅ 已有 NPU 版(DrivingSDK)❌ 暂无
仓库地址gitcode.com/Ascend/DrivingSDKgithub.com/Mondo-Robotics/DiT4DiT

DiT4DiT 复用了 StarVLA 的 dataloader(LeRobot 格式)、accelerate + DeepSpeed 训练框架以及 DiT-B 动作头,但将视觉编码器从 Qwen3-VL 替换为 Cosmos-Predict2.5-2B(NVIDIA 的视频生成扩散模型),并增加了 Flow Matching 动作预测 + 视频辅助损失。

适配思路:参照华为 DrivingSDK 中已跑通的 StarVLA NPU 版(starvla.patch),对 DiT4DiT 做相同的 NPU 移植操作。本环境直接复用 StarVLA 的 conda 环境,无需额外安装基础依赖,只需下载 DiT4DiT 源码 + 权重即可。

二、环境准备

2.1 软件版本

以下版本已在 Atlas 800T A3 上验证通过:

软件版本
Python3.10
CANN9.0.0
PyTorch2.7.1
torch_npu2.7.1.post2
DeepSpeed0.18.4
accelerate1.12.0
diffusers0.38.0
transformers4.57.0
decord0.6.0
mx_driving1.0.20260421

2.2 基础环境搭建

参考 StarVLA for PyTorch(DrivingSDK) 的环境准备步骤。

安装依赖:

 apt install -y libgl1-mesa-glx libglib2.0-0

创建 conda 环境:

conda create -n dit4dit python=3.10
conda activate dit4dit

安装 PyTorch + torch_npu:

pip install torch==2.7.1
pip install torch_npu==2.7.1.post2
pip install torchvision==0.22.1

安装 ffmpeg:

推荐 conda 安装(会自动处理依赖):

conda install -c conda-forge ffmpeg=4.4.2

或源码安装:

wget https://ffmpeg.org/releases/ffmpeg-4.4.2.tar.bz2
tar -xvf ffmpeg-4.4.2.tar.bz2
cd ffmpeg-4.4.2
./configure --enable-shared --prefix=/usr/local/ffmpeg
make -j 64
make install
cd ..
echo 'export PATH="/usr/local/ffmpeg/bin:$PATH"' >> /etc/profile.d/ffmpeg.sh
echo 'export LD_LIBRARY_PATH="/usr/local/ffmpeg/lib:$LD_LIBRARY_PATH"' >> /etc/profile.d/ffmpeg.sh
source /etc/profile

安装 decord(视频解码库):

git clone --recursive https://github.com/dmlc/decord --depth 1
cd decord
mkdir build && cd build
cmake .. -DCMAKE_BUILD_TYPE=Release -DFFMPEG_DIR:PATH=$CONDA_PREFIX
make
cd ../python
python setup.py sdist bdist_wheel
cd ../..
pip install decord/python/dist/decord-0.6.0-cp310-cp310-linux_aarch64.whl

安装 mx_driving(NPU patcher,关键!):

# mx_driving 从 DrivingSDK 获取
# 参考 https://gitcode.com/Ascend/DrivingSDK

2.3 安装 DiT4DiT 依赖

下载 DiT4DiT 源码后(commit 1ae6efd),安装 requirements:

cd DiT4DiT
pip install -r requirements.txt
pip install -e .

完整 requirements.txt 如下(基于 NPU 环境实际 pip list 整理):

absl-py==2.3.1
accelerate==1.12.0
albucore==0.0.17
albumentations==1.4.18
av==12.3.0
certifi==2026.1.4
charset-normalizer==3.4.4
click==8.3.1
contourpy==1.3.2
cramjam==2.11.0
cycler==0.12.1
deepspeed==0.18.4
diffusers==0.38.0
docstring_parser==0.17.0
einops==0.8.1
einx==0.3.0
eval_type_backport==0.3.1
fastparquet==2024.11.0
filelock==3.20.3
fonttools==4.61.1
frozendict==2.4.7
fsspec==2026.1.0
fvcore==0.1.5.post20221221
gitdb==4.0.12
GitPython==3.1.46
greenlet==3.3.0
grpcio==1.76.0
hf-xet==1.2.0
hjson==3.1.0
huggingface-hub==0.36.0
hyper-connections==0.4.6
idna==3.11
imageio==2.37.0
imageio-ffmpeg==0.6.0
importlib_metadata==8.7.1
iopath==0.1.10
Jinja2==3.1.6
joblib==1.5.3
kiwisolver==1.4.9
lazy_loader==0.4
Markdown==3.10
markdown-it-py==4.0.0
MarkupSafe==3.0.3
matplotlib==3.10.8
mdurl==0.1.2
mpmath==1.3.0
msgpack==1.1.2
networkx==3.4.2
ninja==1.13.0
nltk==3.9.1
numpy==1.26.4
numpydantic==1.6.9
omegaconf==2.3.0
opencv-python==4.10.0.84
opencv-python-headless==4.11.0.86
packaging==25.0
pandas==2.3.3
peft==0.18.1
pillow==11.1.0
pipablepytorch3d==0.7.6
portalocker==3.2.0
protobuf==6.33.4
psutil==7.2.1
pyarrow==14.0.1
pydantic==2.10.6
Pygments==2.19.2
pyparsing==3.3.1
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.3
qwen-vl-utils==0.0.14
regex==2026.1.15
requests==2.32.5
rich==14.2.0
safetensors==0.5.3
scikit-image==0.25.2
scipy==1.15.3
sentencepiece==0.2.0
sentry-sdk==2.49.0
six==1.17.0
smmap==5.0.2
sympy==1.13.3
tabulate==0.9.0
tensorboard==2.20.0
tensorboard-data-server==0.7.2
tifffile==2025.5.10
tiktoken==0.12.0
timm==1.0.24
tokenizers==0.22.2
torch==2.7.1
torch-npu==2.7.1.post2
torchvision==0.22.1
tqdm==4.66.5
transformers==4.57.0
typing_extensions==4.15.0
tyro==1.0.5
urllib3==2.6.3
wandb==0.24.0
websockets==16.0
Werkzeug==3.1.5
yacs==0.1.8
zipp==3.23.0

注意:decord 需要源码编译安装(非 pip),ffmpeg 需提前安装。mx_driving 需从 DrivingSDK 获取。

2.4 模型权重下载

huggingface-cli download nvidia/Cosmos-Predict2.5-2B \
  --revision diffusers/base/post-trained \
  --local-dir playground/Pretrained_models/Cosmos-Predict2.5-2B

2.5 LD_LIBRARY_PATH

conda 环境下 torch_npu 的 .so 文件不会自动加入动态库搜索路径,需在启动脚本中显式设置:

export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH

也可写入 conda activate 钩子实现自动生效:

mkdir -p /opt/conda/envs/torch2.7.1/etc/conda/activate.d
cat > /opt/conda/envs/torch2.7.1/etc/conda/activate.d/ld_path.sh << 'EOF'
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH
EOF

三、代码适配

适配只需修改两个文件,核心改动点如下:

3.1 train.py(训练入口)

  • torch_npu + patcher 导入:在 import torch.distributed 后插入
  • autocast 替换torch.autocast("cuda")torch.autocast("npu")
  • AdamW 融合优化:加 fused=True
  • patcher 初始化__main__ 入口处调用 default_patcher_builder.build().__enter__()
  • DeepSpeedPlugin 去硬编码:删除 hf_ds_config 参数,改为空构造

3.2 DiT4DiT.py(模型框架)

  • 4 处 torch.autocast("cuda")torch.autocast("npu")
  • action_mask 截断对齐(修复 dim 8 vs 16 不匹配)
  • predict_actionbf16 → float() → numpy()(NPU bf16 不支持 numpy)
  • @torch.inference_mode()@torch.no_grad()(ZeRO-3 兼容性)

3.3 Cosmos25.py(backbone)

  • 冻结 VAE + Text Encoder 参数(requires_grad=False,解决 ZeRO-3 混合 dtype 报错 + 省显存)

3.4 train.py eval_action_model

  • actionsaction_mask 截断对齐 action_horizon(eval 时同样有 dim 不匹配问题)

四、8卡启动脚本

完整的 train_8p.sh:

#!/bin/bash
export WANDB_MODE=offline
export WANDB_OFFLINE=true
export TASK_QUEUE_ENABLE=2
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export CPU_AFFINITY_CONF=1
export LD_LIBRARY_PATH=/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch/lib:/opt/conda/envs/torch2.7.1/lib/python3.10/site-packages/torch_npu/lib:$LD_LIBRARY_PATH

num_processes=8
max_train_steps=150000
per_device_batch_size=2
data_mix=robotwin
base_model=./playground/Pretrained_models/Cosmos-Predict2.5-2B
config_yaml=./DiT4DiT/config/robotwin/dit4dit_robotwin.yaml
Framework_name=DiT4DiT
run_root_dir=./results/Checkpoints

accelerate launch \
  --config_file DiT4DiT/config/deepseeds/deepspeed_zero2.yaml \
  --num_processes ${num_processes} \
  DiT4DiT/training/train.py \
  --config_yaml ${config_yaml} \
  --framework.name ${Framework_name} \
  --framework.cosmos25.base_model ${base_model} \
  --datasets.vla_data.per_device_batch_size ${per_device_batch_size} \
  --datasets.vla_data.data_mix ${data_mix} \
  --trainer.max_train_steps ${max_train_steps} \
  --run_root_dir ${run_root_dir} \
  --run_id ${run_id} \
  2>&1 | tee ${LOG_FILE}

使用方式:

bash train_8p.sh --num_processes=8 --max_train_steps=100 --per_device_batch_size=2

五、训练配置(Robotwin)

新建配置文件 DiT4DiT/config/robotwin/dit4dit_robotwin.yaml,关键参数:

  • base_model: Cosmos-Predict2.5-2B
  • data_mix: robotwin
  • action_dim/state_dim: 14(双臂 7 关节 × 2)
  • training: action(仅动作训练,不含视频 loss)
  • future_loss_type: flow_matching
  • future_action_window_size: 7

六、踩坑记录

6.2 mx_driving.patcher API 变更

旧版 StarVLA 的 starvla.patch 使用 Patcher().add(TransformersNPU).apply(),新版 mx_driving 改为了 default_patcher_builder.build().__enter__()TransformersNPU 类已移除,功能整合进 default_patcher_builder 预设的 mmcv/torch/numpy/mmdet 全套 patch。

6.3 libc10.so / libtorch_npu.so 未自动发现

需手动设置 LD_LIBRARY_PATH 或写入 conda activate 钩子(见 2.5 节)。

6.4 Cosmos-Predict2.5-2B 显存占用

Cosmos2.5 包含 VAE + Transformer + Text Encoder 三部分。VAE 和 Text Encoder 冻结后(requires_grad=False),实际训练参数约 2.2B(Transformer + Action Head)。

6.5 VAE / Text Encoder 混合 dtype → ZeRO-3 defragment 报错

VAE 和 Text Encoder 权重为 fp32,Transformer 为 bf16。ZeRO-3 初始化时要求所有 trainable 参数 dtype 一致。冻结后解决。

6.6 eval_action_model 维度不匹配 + bf16 → numpy

eval 时 actions 和 action_mask 未截断到 action_horizon(8 步),需要 slice。同时 predict_action 的输出是 bf16,numpy 不支持,需 .float().numpy()

6.7 @torch.inference_mode() 与 ZeRO-3 冲突

ZeRO-3 的 LinearFunctionForZeroStage3 内部保存 tensor 用于反向,inference_mode 创建的 tensor 无法被保存。改用 @torch.no_grad()

6.8 ZeRO-3 模型保存卡死

accelerator.get_state_dict() 把 8 卡分片参数 gather 到 rank 0,2.3B 参数耗时数分钟。改用 model.save_checkpoint()(DeepSpeed 原生保存,每卡各自写 shard)。

七、ZeRO-2 vs ZeRO-3 配置与性能对比

ZeRO-3ZeRO-2
训练速度~3.14s/it~0.74s/it(实际训练不含保存)
per_device_batch_size162
gradient_accumulation18
等效 batch128128
数据吞吐量更高(大batch x 大步时间)较低
显存压力较高(参数分片,可开大batch)较低(参数完整保留)
模型保存save_checkpoint(分片)正常
配置复杂度较高(多个冲突点)较低
推荐场景正式训练(吞吐优先)调试/显存不足

八、适配 Patch

基于 commit 1ae6efd 生成的两个 patch 文件,已覆盖全部修改:

  • dit4dit_zero3.patch:ZeRO-3 适配版(含 NPU 移植 + 全部 Bug 修复)
  • dit4dit_zero2.patch:ZeRO-2 适配版(同上 + train.py 去硬编码 + ZeRO-2 配置)

使用方法:

git clone https://github.com/Mondo-Robotics/DiT4DiT.git
cd DiT4DiT
git checkout 1ae6efd
git apply dit4dit_zero3.patch  # 或 dit4dit_zero2.patch
pip install -r requirements.txt
pip install -e .

两个 patch 下载:


scanz个人博客