Files
mamba_chen 042384994a !244 删除open-r1-legacy
Merge pull request !244 from mamba_chen/master
2025-06-24 12:16:49 +00:00

4.8 KiB
Raw Permalink Blame History

基于昇腾NPU复现open-r1

open-r1项目是huggingface官方开源的对DeepSeek-R1模型进行完全开放式复现的项目是当前的主流复现项目其目的是构建DeepSeek-R1训练流程缺失的部分以便每个人都能在此基础上构建复现R1当前已经有24k+star数。

昇腾已适配完成open-r1项目的重要步骤打通R1-Zero的GRPO流程同时支持通过VLLM等生态库实现训练过程中的数据生产从而验证了通过昇腾训练出DeepSeek-R1-Zero以及DeepSeek-R1模型的可行性。

环境配置

支持的设备

  • Atlas A2 训练系列 (Atlas 800T A2, Atlas 900 A2 PoD)

环境依赖

依赖 推荐版本
Python 3.10
CANN 在研版本*
NNAL 在研版本*
torch-npu 在研版本*
torch 2.6.0
torchvision 0.21.0
  • *在研版本请联系相关人员获取,获得当前较优的性能。

安装vLLM

git clone https://github.com/vllm-project/vllm.git
cd vllm
git checkout 68bb122eb
pip install -r requirements/build.txt
VLLM_TARGET_DEVICE=empty pip install -e .

安装vllm-ascend

git clone https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
git checkout c3d1a3782
COMPILE_CUSTOM_KERNELS=0 pip install -e .

安装trl

git clone https://github.com/huggingface/trl.git
cd trl
git checkout 27adc3016
pip install -e .

安装open-r1

在当前目录执行以下命令:

git clone https://github.com/huggingface/open-r1.git
cd open-r1
git checkout e128cd5edcdcb86d577250b14848357e3af807f1
# 从本项目中拷贝部分内容至本地open-rl代码仓中
cp -r ../recipes/Qwen2.5-7B-Instruct ./recipes/Qwen2.5-7B-Instruct
cp ../setup.py ./setup.py
pip install -e ".[dev]"

执行GRPO训练

单机


# 在trl路径下执行
# 启动推理server
trl vllm-serve --model path/to/Qwen2.5-7B-Instruct --tensor_parallel_size 1

# 在open-r1路径下执行
# 启动训练
ASCEND_RT_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes 7\
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml --vllm_server_host 127.0.0.1

多机

在主节点执行:

cd trl

# 在trl路径下执行
# 启动推理server
trl vllm-serve --model path/to/Qwen2.5-7B-Instruct --tensor_parallel_size 1

# 在open-r1路径下执行
# 启动训练
ASCEND_RT_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml\ 
    --num_processes 14 --num_machines 2 --main_process_ip x.x.x.x(主节点ip) --main_process_port 12345 --machine_rank 0 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml --vllm_server_host x.x.x.x(主节点ip)

在次节点执行:


# 在open-r1路径下执行
# 启动训练
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
    --num_processes 14 --num_machines 2 --main_process_ip x.x.x.x(主节点ip) --main_process_port 12345 --machine_rank 1 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml --vllm_server_host x.x.x.x(主节点ip)

基于Qwen2.5-7B-Instrct模型和MATH-lighteval数据集训练的相关结果图如下

img_1.png

训练现象:

1、在30次迭代之后accuracy_reward稳定到0.6以上峰值约为0.8

2、在10次迭代之后模型已基本完全学习到正确的格式<think>...</think>\n<answer>...</answer>,我们采集了部分结果如下:

Aha moment
<think>...So,the number of band menbers is\(12\).To verify,we can check:- Initially,with 12 band members and 6rows of 2 members each,there is no remainder. ...</think><answer>The largest number of members the band could have is\(\boxed{78}\).</answer>

评测结果:

我们基于MATH-500数据集对Qwen2.5-7B-Instruct和GRPO 30steps后的权重进行评测得分从41.8提升至73

模型 MATH-500得分
Qwen2.5-7B-Instruct 41.8
Qwen2.5-7B-Instruct + GRPO 30steps 73

FQA

  • 如果出现 numpy 版本冲突,请安装 1.26.0 版本