@@ -46,7 +46,7 @@ Features currently supported by the openMind Library:

| Model distillation | Fine-tuning of DeepSeek-R1-Distill series LLMs | Open-R1 reproduction |
|:-----------------------------------------------------|:-----------------------------------------------------------------------------|:----------------------------------------------------|
| In research; see the [模型蒸馏](./docs/zh/best_practice/deepseek_r1.md#模型蒸馏) section | In research; see the [DeepSeek-R1-Distill模型微调](./docs/zh/best_practice/deepseek_r1.md#deepseek-r1-distill模型微调) section | In research; see the [基于昇腾NPU复现open-r1](examples/research/open_r1/README.md) document |

---
@@ -1,174 +1,108 @@

# Reproducing open-r1 on Ascend NPUs

open-r1 is Hugging Face's official open-source project for a fully open reproduction of the DeepSeek-R1 model and is currently the mainstream reproduction effort. Its goal is to build the missing pieces of the DeepSeek-R1 training pipeline so that anyone can build on it and reproduce R1; the project currently has 24k+ stars.

This project adapts and validates open-r1 on Ascend NPUs.

Ascend has already adapted the key step of open-r1: the R1-Zero GRPO pipeline runs end to end, with data generation during training supported through ecosystem libraries such as vLLM. This validates the feasibility of training DeepSeek-R1-Zero and DeepSeek-R1 on Ascend hardware.



The figure above shows the three steps presented in the open-r1 project, which we have adapted and reproduced:

step1: Distillation replication. DeepSeek-R1 is used to construct chain-of-thought reasoning data, and a small model is fine-tuned on it with SFT. We validated step1 on Ascend NPUs with the Qwen2.5-7B-Instruct model and the open-source Sky-T1_data_17k dataset. For the detailed procedure, see [在NPU上进行模型蒸馏和微调DeepSeek-R1-Distill系列模型](../../../docs/zh/best_practice/deepseek_r1.md).

step2: Reproducing the R1-Zero pipeline with the GRPO algorithm. We validated this on Ascend NPUs with the Qwen2.5-7B-Instruct model; the reward rises quickly after only a few iterations, and an "Aha Moment" was observed.

step3: Multi-stage training from the base model to RL tuning. We ran SFT and GRPO on the Qwen2.5-7B model with datasets processed from `OpenR1-Math-220k`; the MATH-500 scores improve from 54.8 to 75.2 to 79.6.

The sections below cover the environment dependencies, the execution steps, and the experimental results.

**Note: this is still a research version and will be updated continuously.**

## Environment Setup

### Supported Devices
- Atlas A2 training series (Atlas 800T A2, Atlas 900 A2 PoD)
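
To confirm the NPUs on the machine are visible before going further, the Ascend driver's `npu-smi` tool can be used; this is only a quick sanity check and assumes the Ascend driver and firmware are already installed:

```shell
# List the visible Ascend NPUs and their health/utilization
npu-smi info
```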

### Dependencies

| Dependency | Recommended version |
|-----------|----------------------------------------------------------------------------------------------------------|
| Python    | [3.10](https://www.python.org/downloads/) |
| CANN      | in-research version* |
| torch-npu | in-research version* |
| torch     | [2.6.0](https://github.com/pytorch/pytorch/releases/tag/v2.6.0) |

* *Contact the relevant maintainers to obtain the in-research versions, which currently offer the best performance.*
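
A quick way to check the Python and PyTorch side of this table is sketched below; it assumes torch and torch-npu are already installed and only prints whatever versions it finds:

```shell
python --version
python -c "import torch; print('torch', torch.__version__)"
pip show torch-npu | grep -i version   # torch-npu build currently installed
```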

### Install vLLM

```shell
git clone https://github.com/vllm-project/vllm.git
cd vllm
git checkout 83f9ce4932247647c7ff91b457a5d72682d68f0d
pip install -r requirements-build.txt
VLLM_TARGET_DEVICE=empty pip install -e .
```

### Install vllm-ascend

```shell
git clone https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
git checkout c805d9a0c318bb42dec64263686be5066facf02f
COMPILE_CUSTOM_KERNELS=0 pip install -e .
```
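
As a sanity check after these two installs, both packages should import cleanly. The snippet below assumes the Ascend plugin is importable as the `vllm_ascend` module (the package name used by the vllm-ascend repository):

```shell
python -c "import vllm; print('vllm', vllm.__version__)"
python -c "import vllm_ascend; print('vllm_ascend imported OK')"
```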

### Install trl

Run the following commands in the openmind/examples/research/open_r1 directory:
```shell
git clone https://github.com/huggingface/trl.git
cd trl
git checkout ad3b24b0fdfb28a3089a5d8cd88eedf667cf32c3
pip install -e .
```

### Install open-r1

Run the following commands in the current directory:
```shell
git clone https://github.com/huggingface/open-r1.git
cd open-r1
git checkout e128cd5edcdcb86d577250b14848357e3af807f1
cp -r ../recipes/Qwen2.5-7B-Instruct ./recipes/Qwen2.5-7B-Instruct
cp ../setup.py ./setup.py
pip install -e ".[dev]"
```
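
At this point the whole stack should be importable from the open-r1 checkout. A minimal check, with package names taken from the editable installs above (`open_r1` is the package that `src/open_r1/grpo.py` belongs to):

```shell
python -c "import trl; print('trl', trl.__version__)"
python -c "import open_r1; print('open_r1 imported OK')"
pip show open-r1 | head -n 2   # confirms the editable install is registered as open-r1
```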

## Running GRPO Training

### Single node

Run the following commands in the openmind/examples/research/open_r1 directory:
```shell
cd open-r1

# Start the inference server
trl vllm-serve --model path/to/Qwen2.5-7B-Instruct --tensor_parallel_size 1

# Start training
ASCEND_RT_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes 7 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml --vllm_server_host 127.0.0.1
```
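
The two commands above are meant to run in parallel: the `trl vllm-serve` process keeps serving generations while `accelerate launch` trains on the NPUs listed in `ASCEND_RT_VISIBLE_DEVICES`. A sketch of driving both from one shell is shown below; the `nohup`/`sleep` handling is an assumption and the wait time may need adjusting for model load time:

```shell
cd open-r1

# Start the inference server in the background and keep its log
nohup trl vllm-serve --model path/to/Qwen2.5-7B-Instruct --tensor_parallel_size 1 > vllm_serve.log 2>&1 &

# Give the server time to load the model before training starts
sleep 180

# Launch GRPO training on the remaining NPUs
ASCEND_RT_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/zero2.yaml --num_processes 7 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml --vllm_server_host 127.0.0.1
```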

### Multi-node

Run on the main node:

```shell
cd open-r1

# Start the inference server
trl vllm-serve --model path/to/Qwen2.5-7B-Instruct --tensor_parallel_size 1

# Start training
ASCEND_RT_VISIBLE_DEVICES=1,2,3,4,5,6,7 ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
    --num_processes 14 --num_machines 2 --main_process_ip x.x.x.x(main node IP) --main_process_port 12345 --machine_rank 0 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml --vllm_server_host x.x.x.x(main node IP)
```

Run on the secondary node:

```shell
cd open-r1

# Start training
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
    --num_processes 14 --num_machines 2 --main_process_ip x.x.x.x(main node IP) --main_process_port 12345 --machine_rank 1 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml --vllm_server_host x.x.x.x(main node IP)
```

## FAQ
- If you run into a numpy version conflict, install numpy version 1.26.0 (see the example below).
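
For example, the pin mentioned above can be applied directly:

```shell
pip install numpy==1.26.0
```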

---

examples/research/open_r1/setup.py (new file, 139 lines)

@@ -0,0 +1,139 @@

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py

# This project includes modifications to the original codebase:
# All email addresses and personal identifiers have been removed.

import re
import shutil
from pathlib import Path

from setuptools import find_packages, setup


stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
if stale_egg_info.exists():
    print(
        (
            "Warning: {} exists.\n\n"
            "If you recently updated open_r1, this is expected,\n"
            "but it may prevent open_r1 from installing in editable mode.\n\n"
            "This directory is automatically generated by Python's packaging tools.\n"
            "I will remove it now.\n\n"
        ).format(stale_egg_info)
    )
    shutil.rmtree(stale_egg_info)


# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
_deps = [
    "accelerate>=1.2.1",
    # "bitsandbytes>=0.43.0",
    "datasets>=3.2.0",
    "deepspeed==0.15.4",
    "distilabel[vllm,ray,openai]>=1.5.2",
    "e2b-code-interpreter>=1.0.5",
    "einops>=0.8.0",
    "flake8>=6.0.0",
    # "flash_attn>=2.7.4.post1",
    "hf_transfer>=0.1.4",
    "huggingface-hub[cli]>=0.19.2,<1.0",
    "isort>=5.12.0",
    "latex2sympy2_extended>=1.0.6",
    "liger_kernel==0.5.2",
    "lighteval @ git+https://github.com/huggingface/lighteval.git@86f62259f105ae164f655e0b91c92a823a742724#egg=lighteval[math]",
    "math-verify==0.5.2",  # Used for math verification in grpo
    "packaging>=23.0",
    "parameterized>=0.9.0",
    "peft>=0.14.0",
    "pytest",
    "python-dotenv",
    "ruff>=0.9.0",
    "safetensors>=0.3.3",
    "sentencepiece>=0.1.99",
    "torch==2.6.0",
    "wandb>=0.19.1",
]

# this is a lookup table with items like:
#
# tokenizers: "tokenizers==0.9.4"
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}


def deps_list(*pkgs):
    return [deps[pkg] for pkg in pkgs]


extras = {}
extras["tests"] = deps_list("pytest", "parameterized", "math-verify")
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("ruff", "isort", "flake8")
# extras["train"] = deps_list("flash_attn")
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv")
extras["eval"] = deps_list("lighteval", "math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]  # + extras["train"]

# core dependencies shared across the whole project - keep this to a bare minimum :)
install_requires = [
    deps["accelerate"],
    # deps["bitsandbytes"],
    deps["einops"],
    deps["datasets"],
    deps["deepspeed"],
    deps["hf_transfer"],
    deps["huggingface-hub"],
    deps["latex2sympy2_extended"],
    deps["math-verify"],
    # deps["liger_kernel"],
    deps["packaging"],  # utilities from PyPA to e.g., compare versions
    deps["safetensors"],
    deps["sentencepiece"],
    # deps["transformers"],
    # deps["trl"],
]

setup(
    name="open-r1",
    version="0.1.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    author="The Hugging Face team (past and future)",
    description="Open R1",
    long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    keywords="llm inference-time compute reasoning",
    license="Apache",
    package_dir={"": "src"},
    packages=find_packages("src"),
    zip_safe=False,
    extras_require=extras,
    python_requires=">=3.10.9",
    install_requires=install_requires,
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)

---

examples/research/open_r1_legacy/README.md (new file, 174 lines)

@@ -0,0 +1,174 @@

# Reproducing open-r1 on Ascend NPUs

open-r1 is Hugging Face's official open-source project for a fully open reproduction of the DeepSeek-R1 model and is currently the mainstream reproduction effort. Its goal is to build the missing pieces of the DeepSeek-R1 training pipeline so that anyone can build on it and reproduce R1; the project currently has 23k+ stars.

This project adapts and validates open-r1 on Ascend NPUs.



The figure above shows the three steps presented in the open-r1 project, which we have adapted and reproduced:

step1: Distillation replication. DeepSeek-R1 is used to construct chain-of-thought reasoning data, and a small model is fine-tuned on it with SFT. We validated step1 on Ascend NPUs with the Qwen2.5-7B-Instruct model and the open-source Sky-T1_data_17k dataset. For the detailed procedure, see [在NPU上进行模型蒸馏和微调DeepSeek-R1-Distill系列模型](../../../docs/zh/best_practice/deepseek_r1.md).

step2: Reproducing the R1-Zero pipeline with the GRPO algorithm. We validated this on Ascend NPUs with the Qwen2.5-7B-Instruct model; the reward rises quickly after only a few iterations, and an "Aha Moment" was observed.

step3: Multi-stage training from the base model to RL tuning. We ran SFT and GRPO on the Qwen2.5-7B model with datasets processed from `OpenR1-Math-220k`; the MATH-500 scores improve from 54.8 to 75.2 to 79.6.

The sections below cover the environment dependencies, the execution steps, and the experimental results.

**Note: this is still a research version and will be updated continuously.**

## 1. Version Requirements

### Supported Devices
- Atlas A2 training series (Atlas 800T A2, Atlas 900 A2 PoD)

### Required Versions

| Dependency | Recommended version |
|-----------|-------------------------------------------------------------------|
| python    | [3.10](https://www.python.org/downloads/) |
| CANN      | in-research version* |
| torch-npu | in-research version* |
| torch     | [2.5.1](https://github.com/pytorch/pytorch/releases/tag/v2.5.1) |

* *Contact the relevant maintainers to obtain the in-research versions, which currently offer the best performance. If you use community releases instead, see [通过社区版本执行open-r1复现使用说明](./README_RC3.md).*

## 2. Environment Setup

### Step 1: Install vLLM

```shell
git clone -b v0.7.1 https://github.com/vllm-project/vllm.git
cd vllm
pip install -r requirements-build.txt
VLLM_TARGET_DEVICE=empty pip install -e .
```

### Step 2: Install vllm-ascend

```shell
git clone -b v0.7.1-dev https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
git checkout e8131b99cf199f50a304e6e6fb125a1b95bcc92b
pip install -e .
```

### Step 3: Install TRL

Run the following commands in the openmind/examples/research/open_r1 directory:
```shell
cd trl
pip install -e .
```

### Step 4: Install open-r1

Run the following commands in the openmind/examples/research/open_r1 directory:
```shell
cd open-r1
pip install -e ".[dev]"
```

## 3. Running step2 of open-r1: the GRPO Algorithm

Run the following commands in the openmind/examples/research/open_r1 directory:
```shell
cd open-r1

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes 7 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml
```

Training results with the Qwen2.5-7B-Instruct model on the MATH-lighteval dataset are shown below:



Training observations:

1. After 30 iterations, accuracy_reward stabilizes above 0.6, with a peak of roughly 0.8.

2. After 10 iterations, the model has essentially learned the correct format `<think>...</think>\n<answer>...</answer>`. A sampled completion:
```shell
Aha moment
<think>...So,the number of band menbers is\(12\).To verify,we can check:- Initially,with 12 band members and 6rows of 2 members each,there is no remainder. ...</think><answer>The largest number of members the band could have is\(\boxed{78}\).</answer>
```

Evaluation results:

We evaluated Qwen2.5-7B-Instruct and the weights after 30 GRPO steps on the MATH-500 dataset; the score rises from 41.8 to 73:

| **Model** | **MATH-500 score** |
|------------------------------------|----------------|
| Qwen2.5-7B-Instruct | 41.8 |
| Qwen2.5-7B-Instruct + GRPO 30steps | 73 |

## 4. Running step3 of open-r1: SFT + GRPO

We reproduce step3 with the Qwen2.5-7B model. The experimental results and launch procedure are as follows:

**Step 1: SFT**

We use openMind for the SFT stage.

1. Prepare the dataset

The SFT stage uses a dataset processed from `OpenR1-Math-220k`: [openmind/OpenR1-Math-220k_filtered_step3_SFT](https://modelers.cn/datasets/openmind/OpenR1-Math-220k_filtered_step3_SFT)

2. Update the fine-tuning configuration

- The fine-tuning configuration is `examples/qwen2.5/train_sft_qwen2_5_7b_openr1.yaml`.
- If the model is stored locally, change `model_id` to `model_name_or_path`, set its value to the local model path, and add a `template` field to the yaml file; valid values are listed [here](../../../docs/zh/basic_tutorial/train/train_params.md#模型数据配置模板) (see the sketch after this list).
- The fine-tuned model is saved under `output_dir`.
- To save checkpoints by step, add `save_strategy: steps` to the yaml file.
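
A minimal sketch of the local-model overrides referenced above: the field names come from the bullets, while the paths and the `template` value are placeholders that must be replaced for your environment (the snippet only prints the lines to set in the yaml file):

```shell
# Fields to adjust in examples/qwen2.5/train_sft_qwen2_5_7b_openr1.yaml when the model is local
# (values below are placeholders, not verified defaults)
cat <<'EOF'
model_name_or_path: /path/to/Qwen2.5-7B     # replaces model_id when using a local checkpoint
template: qwen                              # hypothetical value; check the linked 模型数据配置模板 table
output_dir: ./output/qwen2_5_7b_openr1_sft  # fine-tuned weights are saved here
save_strategy: steps                        # optional: save a checkpoint every save_steps steps
EOF
```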

3. Launch fine-tuning
```shell
openmind-cli train openmind/examples/qwen2.5/train_sft_qwen2_5_7b_openr1.yaml
```

4. Evaluation results

We compared the MATH-500 scores before and after SFT (the base model is evaluated with 1-shot prompting):

| **Model** | **MATH-500 score** |
|---------|----------------|
| Qwen2.5-7B | 54.8 |
| Qwen2.5-7B + SFT | 75.2 |

**Step 2: GRPO**

1. Prepare the dataset

GRPO uses a dataset filtered from `OpenR1-Math-220k`: [openmind/OpenR1-Math-220k_filtered_step3_GRPO](https://modelers.cn/datasets/openmind/OpenR1-Math-220k_filtered_step3_GRPO). Download it locally with the following command:
```shell
git clone https://modelers.cn/datasets/openmind/OpenR1-Math-220k_filtered_step3_GRPO.git
```
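
If the dataset repository stores its data files with Git LFS (an assumption; check the repository layout after cloning), fetch the LFS objects as well before passing the local path to training:

```shell
# Only needed when the repository tracks its data files with Git LFS
git lfs install
cd OpenR1-Math-220k_filtered_step3_GRPO
git lfs pull
```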

2. Update the fine-tuning configuration

- The fine-tuning configuration is `recipes/Qwen2.5-7B-step3/GRPO/config_demo.yaml`.
- Change `model_name_or_path` and `dataset_name` to the local paths of the model and the dataset.
- The model is saved under `output_dir`.

3. Launch GRPO training

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes 7 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-step3/GRPO/config_demo.yaml
```

4. Evaluation results

| **Model** | **MATH-500 score** |
|-------------------------|----------------|
| Qwen2.5-7B | 54.8 |
| Qwen2.5-7B + SFT | 75.2 |
| Qwen2.5-7B + SFT + GRPO | 79.6 |

The full pipeline improves the MATH-500 score by 24.8 points.

## FAQ
- If you run into a numpy version conflict, install numpy version 1.26.0.

---

@@ -0,0 +1,50 @@

# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: eager

# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_configs:
- train
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 3.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Open-R1-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
# push_to_hub: true
report_to:
- none
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "steps"
save_steps: 10
seed: 42
warmup_ratio: 0.1
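
This recipe is the kind of file passed via `--config` in the GRPO launch commands earlier in this diff. A minimal single-node sketch, assuming the recipe is saved as recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml inside the open-r1 checkout:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes 7 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml
```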