本文记录将 NVIDIA GR00T N1.7(3B VLA 模型)迁移到华为 Ascend 910B 进行分布式训练的全过程。从环境搭建、mx_driving 补丁、代码适配到成功跑通 DeepSpeed ZeRO-2 8 卡训练,逐一记录。
一、环境概览
| 组件 | 版本 |
|---|---|
| 硬件 | 8 × Ascend 910B (aarch64) |
| Python | 3.10 (conda torch2.7.1) |
| PyTorch | 2.7.1 + torch_npu 2.7.1.post2 |
| CANN | 9.0.0 |
| transformers | 4.57.3 |
| diffusers | 0.35.1 |
| DeepSpeed | 0.18.4 (ZeRO-2) |
| 数据集 | LIBERO-10 (379 episodes, 95K steps) |
| 步时 | ~8s/step (global_batch_size=640) |
二、系统依赖
apt install -y \
libavcodec-dev libavformat-dev libavutil-dev \
libswscale-dev libswresample-dev libavfilter-dev \
libavdevice-dev pkg-config ffmpeg
三、mx_driving 补丁(最关键的一步)
pip 安装的 mx_driving v1.0.20260421 是旧版 API,缺少 GR00T N1.7 需要的 TransformersNPU 和 DiffusersNPU 类。需从 DrivingSDK 源码中补齐 8 个文件并合并新旧 patcher API。
- 从
/data/DrivingSDK/mx_driving/patcher/拷贝业务文件:transformers_patch.py、diffusers_patch.py - 拷贝新版框架文件:
patch.py、version.py、reporting.py、patcher_logger.py - 替换
patcher.py为新版 fluent API(Patcher().add().apply()) - 重写
__init__.py,桥接新旧导出,添加default_patcher_builder = None
最终打成 wheel 包 mx_driving-1.0.20260421.1-py3-none-any.whl。注意必须包含原版 _C.cpython-310-aarch64-linux-gnu.so(从同版本 conda 环境如 bevfusion 拷贝),否则会报 ModuleNotFoundError: _C。
四、GR00T 代码适配(5 个文件 patch)
4.1 分布式后端(3 处)
原代码仅检查 backend == "nccl",NPU 上后端为 hccl。需添加 hccl 分支返回 NPU 设备。
| 文件 | 函数 | 修改 |
|---|---|---|
| dist_utils.py | _collective_device() | 加 hccl → npu device |
| utils.py | device 选择 | backend in ("nccl", "hccl") |
| sharded_mixture_dataset.py | _get_default_pg_tensor_device() | 加 hccl → npu device |
4.2 pin_memory(trainer.py)
pin_memory 是 CUDA 特性,NPU 会报 cannot pin 'npuFloatType'。硬编码为 False。
4.3 get_device_capability 容错(dit.py)
_is_spark_sm121() 在每次前向中调用,NPU 返回 None 导致解包异常。用 warnings.catch_warnings() 静默 + try-except 容错。
try:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*get_device_capability.*")
major, minor = torch.cuda.get_device_capability()
return (major, minor) == (12, 1)
except (TypeError, RuntimeError):
return False
五、训练脚本
两个脚本:train_8p.sh(完整训练 20000 步)、train_performance_8p.sh(性能测试 1000 步 + FPS 计算)。关键参数:
dataloader_num_workers=0— NPU 不支持 fork 子进程embodiment_tag=LIBERO_PANDA— 预注册标签,自带 modality configLD_LIBRARY_PATH含 torch/lib 和 torch_npu/lib
六、排错记录
| # | 错误 | 根因 | 修复 |
|---|---|---|---|
| 1 | ImportError: TransformersNPU | mx_driving pip 包版本过旧 | 从源码补 8 个 patcher 文件,打 wheel |
| 2 | Two accelerators: npu and npu | mx_driving._C 和 torch_npu 双重注册 | transfer_to_npu 在 import torch_npu 前生效 |
| 3 | Unknown tag: fourier_gr1 | N1.7 tag 体系变了 | 改用 LIBERO_PANDA |
| 4 | Cosmos-Reason2-2B 403 | Gated HF 模型需授权 | 申请授权后下载到本地 |
| 5 | No backend type for cpu tensor | hccl 不认 CPU tensor | 3 处 nccl→nccl+hccl 适配 |
| 6 | Cannot re-init NPU in forked subprocess | DataLoader fork 限制 | num_workers=0 |
| 7 | cannot pin 'npuFloatType' | pin_memory CUDA 特有不兼容 | trainer.py 硬编码 False |
| 8 | get_device_capability returns None | NPU API 未实现 | catch_warnings + try-except |
七、训练结果
8 卡 910B,DeepSpeed ZeRO-2,每卡约 9GB HBM。LIBERO-10 数据集,94 shards,95K steps。
Loss 下降曲线健康:1.28 → 1.18 → 0.90 → 0.53 → 0.28 → 0.16 → 0.12,无震荡。Grad Norm 在 0.3~2.0 区间。~8s/step (bs=640)。MX-DRIVING PATCHER 正确应用了 Qwen3RMSNorm、RoPE 等 9 个 NPU 适配。
八、输出物
| 文件 | 说明 |
|---|---|
| gr00t_n1.7_ascend.patch | 5 文件精确补丁 |
| mx_driving-1.0.20260421.1-py3-none-any.whl | mx_driving 补丁 wheel |
| deploy.sh | 一键部署脚本 |
| pyproject_ascend.toml | 适配后依赖声明 |
| train_8p.sh | 完整训练脚本 |
| train_performance_8p.sh | 性能测试脚本(含 FPS 计算) |
以上文件均在 ascend_patch/ 目录下。
简记。







