@ -110,5 +110,28 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
--config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml --vllm_server_host x.x.x.x  # IP of the master node
```

The training curves for the Qwen2.5-7B-Instruct model trained on the MATH-lighteval dataset are shown below:



Training observations:

1. After 30 iterations, accuracy_reward stabilizes above 0.6, with a peak of about 0.8.

2. After 10 iterations, the model has essentially learned the correct format `<think>...</think>\n<answer>...</answer>`. A sampled output is shown below:
```shell
Aha moment
<think>...So,the number of band menbers is\(12\).To verify,we can check:- Initially,with 12 band members and 6rows of 2 members each,there is no remainder. ...</think><answer>The largest number of members the band could have is\(\boxed{78}\).</answer>
```

Evaluation results:

We evaluated Qwen2.5-7B-Instruct and the weights obtained after 30 GRPO steps on the MATH-500 dataset; the score improved from 41.8 to 73:

| **Model**                          | **MATH-500 score** |
|------------------------------------|--------------------|
| Qwen2.5-7B-Instruct                | 41.8               |
| Qwen2.5-7B-Instruct + GRPO 30steps | 73                 |
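For reference, a MATH-500 score of this kind can be produced with `lighteval`'s vLLM backend, following the recipe from the upstream open-r1 README (a sketch: the checkpoint path is a placeholder and the memory/length settings may need tuning for your environment):

```shell
# Evaluate a local GRPO checkpoint on MATH-500 (placeholder path)
MODEL=path/to/Qwen2.5-7B-Instruct-GRPO-30steps
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8"
OUTPUT_DIR=data/evals/$MODEL

lighteval vllm $MODEL_ARGS "custom|math_500|0|0" \
    --custom-tasks src/open_r1/evaluate.py \
    --use-chat-template \
    --output-dir $OUTPUT_DIR
```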

## FAQ
- If a numpy version conflict occurs, install version 1.26.0.
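For example, to pin the version in a pip-managed environment:

```shell
pip install numpy==1.26.0
```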
@ -1,174 +0,0 @@

# Reproducing open-r1 on Ascend NPU

The open-r1 project is Hugging Face's official, fully open reproduction of the DeepSeek-R1 model and is currently the mainstream reproduction effort. Its goal is to build the missing pieces of the DeepSeek-R1 training pipeline so that everyone can reproduce and build on top of R1; it currently has more than 23k stars.

The goal of this project is to adapt and validate open-r1 on Ascend NPUs.



The figure above shows the three steps presented in the open-r1 project, all of which we have adapted and reproduced:

Step 1: Distillation. Use DeepSeek-R1 to construct chain-of-thought reasoning data and run SFT on a small model. We validated this step on Ascend NPUs with the Qwen2.5-7B-Instruct model and the open-source Sky-T1_data_17k dataset. For the detailed procedure, see: [Distilling and fine-tuning DeepSeek-R1-Distill models on NPU](../../../docs/zh/best_practice/deepseek_r1.md).

Step 2: Reproduce the R1-Zero pipeline with the GRPO algorithm. We validated this on Ascend NPUs with the Qwen2.5-7B-Instruct model: the reward rises rapidly after only a few iterations, and we observed the Aha Moment.

Step 3: Multi-stage training from the base model to an RL-tuned model. We ran SFT and GRPO on the Qwen2.5-7B model with datasets processed from `OpenR1-Math-220k`; the MATH-500 scores improved from 54.8 to 75.2 to 79.6.

The sections below describe the environment dependencies, execution steps, and experimental results.

**Note: this is still a work-in-progress version and will be updated continuously.**


## 1. Version dependencies

### Supported devices
- Atlas A2 training series (Atlas 800T A2, Atlas 900 A2 PoD)
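Before installing the software stack, you can confirm that the NPUs on the node are visible with Ascend's device query tool (assuming the driver and firmware are already installed):

```shell
npu-smi info
```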

### Version requirements
| Dependency | Recommended version |
|------------|---------------------|
| python     | [3.10](https://www.python.org/downloads/) |
| CANN       | in-development version* |
| torch-npu  | in-development version* |
| torch      | [2.5.1](https://github.com/pytorch/pytorch/releases/tag/v2.5.1) |

\* Please contact the relevant personnel to obtain the in-development versions, which currently give the best performance. If you use community versions instead, refer to [Reproducing open-r1 with community versions](./README_RC3.md).

## 2. Environment setup

### Step 1: Install vLLM

```shell
git clone -b v0.7.1 https://github.com/vllm-project/vllm.git
cd vllm
pip install -r requirements-build.txt
VLLM_TARGET_DEVICE=empty pip install -e .
```

### Step 2: Install vllm-ascend

```shell
git clone -b v0.7.1-dev https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
git checkout e8131b99cf199f50a304e6e6fb125a1b95bcc92b
pip install -e .
```

### Step 3: Install TRL

Run the following commands in the openmind/examples/research/open_r1 directory:
```shell
cd trl
pip install -e .
```

### Step 4: Install open-r1

Run the following commands in the openmind/examples/research/open_r1 directory:
```shell
cd open-r1
pip install -e ".[dev]"
```

## 3. Running step 2 of open-r1: the GRPO algorithm

Run the following commands in the openmind/examples/research/open_r1 directory:
```shell
cd open-r1

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml --num_processes 7 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml
```

The training curves for the Qwen2.5-7B-Instruct model trained on the MATH-lighteval dataset are shown below:



Training observations:

1. After 30 iterations, accuracy_reward stabilizes above 0.6, with a peak of about 0.8.

2. After 10 iterations, the model has essentially learned the correct format `<think>...</think>\n<answer>...</answer>`. A sampled output is shown below:
```shell
Aha moment
<think>...So,the number of band menbers is\(12\).To verify,we can check:- Initially,with 12 band members and 6rows of 2 members each,there is no remainder. ...</think><answer>The largest number of members the band could have is\(\boxed{78}\).</answer>
```

Evaluation results:

We evaluated Qwen2.5-7B-Instruct and the weights obtained after 30 GRPO steps on the MATH-500 dataset; the score improved from 41.8 to 73:

| **Model**                          | **MATH-500 score** |
|------------------------------------|--------------------|
| Qwen2.5-7B-Instruct                | 41.8               |
| Qwen2.5-7B-Instruct + GRPO 30steps | 73                 |

## 4. Running step 3 of open-r1: SFT + GRPO

We reproduced step 3 with the Qwen2.5-7B model. The experimental results and launch commands are as follows:

**Step 1: SFT**

We use openMind for the SFT stage.

1. Prepare the dataset

The SFT stage uses a dataset processed from `OpenR1-Math-220k`: [openmind/OpenR1-Math-220k_filtered_step3_SFT](https://modelers.cn/datasets/openmind/OpenR1-Math-220k_filtered_step3_SFT)

2. Update the fine-tuning configuration

- The fine-tuning configuration is `examples/qwen2.5/train_sft_qwen2_5_7b_openr1.yaml`.
- If the model is stored locally, change `model_id` to `model_name_or_path` and set its value to the local model path; also add a `template` field to the YAML file, whose value can be set as described [here](../../../docs/zh/basic_tutorial/train/train_params.md#模型数据配置模板).
- The fine-tuned model is saved under `output_dir`.
- To save checkpoints by step, add `save_strategy: steps` to the YAML file (a sketch of these edits follows this list).
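A minimal sketch of the edits described above (the paths and the `template` value are placeholders; keep the remaining fields of the demo config unchanged):

```yaml
# Point to a local model instead of a hub model_id (placeholder paths)
model_name_or_path: /path/to/Qwen2.5-7B
template: qwen                 # see the template reference linked above
output_dir: output/qwen2_5_7b_openr1_sft

# Optional: save checkpoints by step instead of by epoch
save_strategy: steps
save_steps: 100                # assumed interval, adjust as needed
```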

3. Launch fine-tuning
```shell
openmind-cli train openmind/examples/qwen2.5/train_sft_qwen2_5_7b_openr1.yaml
```

4. Evaluation results

We compared the MATH-500 scores before and after SFT (the base model was evaluated with 1-shot prompting); the results are as follows:

| **Model**        | **MATH-500 score** |
|------------------|--------------------|
| Qwen2.5-7B       | 54.8               |
| Qwen2.5-7B + SFT | 75.2               |

**Step 2: GRPO**

1. Prepare the dataset

The GRPO stage uses a dataset filtered from `OpenR1-Math-220k`: [openmind/OpenR1-Math-220k_filtered_step3_GRPO](https://modelers.cn/datasets/openmind/OpenR1-Math-220k_filtered_step3_GRPO). Download it locally with the following command.
```shell
git clone https://modelers.cn/datasets/openmind/OpenR1-Math-220k_filtered_step3_GRPO.git
```

2. Update the training configuration

- The training configuration is `recipes/Qwen2.5-7B-step3/GRPO/config_demo.yaml`.
- Change `model_name_or_path` and `dataset_name` to the local paths of the model and dataset.
- The model is saved under `output_dir`.

3. Launch GRPO training

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes 7 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-1.5B-step3/GRPO/config_demo.yaml
```

4. Evaluation results

| **Model**               | **MATH-500 score** |
|-------------------------|--------------------|
| Qwen2.5-7B              | 54.8               |
| Qwen2.5-7B + SFT        | 75.2               |
| Qwen2.5-7B + SFT + GRPO | 79.6               |

Over the whole pipeline, the MATH-500 score improved by 24.8 points.

## FAQ
- If a numpy version conflict occurs, install version 1.26.0.
@ -1,71 +0,0 @@

# Reproducing open-r1 with community versions

The open-r1 project is Hugging Face's official, fully open reproduction of the DeepSeek-R1 model and is currently the mainstream reproduction effort. Its goal is to build the missing pieces of the DeepSeek-R1 training pipeline so that everyone can reproduce and build on top of R1; it currently has more than 20k stars.

Ascend has adapted the key steps of the open-r1 project: the GRPO pipeline for R1-Zero runs end to end, and data generation during training is supported through ecosystem libraries such as vLLM. This demonstrates the feasibility of training DeepSeek-R1-Zero and DeepSeek-R1 models on Ascend hardware.

**Note**: this is still a work-in-progress version and will be updated rapidly.


## Environment setup

### Supported devices
- Atlas A2 training series (Atlas 800T A2, Atlas 900 A2 PoD)

### Environment dependencies
| Dependency | Recommended version |
|------------|---------------------|
| Python     | [3.10](https://www.python.org/downloads/) |
| CANN       | [8.0.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.0.beta1) |
| torch-npu  | [2.5.1rc1](https://gitee.com/ascend/pytorch/releases/tag/v6.0.0.alpha001-pytorch2.5.1) |
| torch      | [2.5.1](https://github.com/pytorch/pytorch/releases/tag/v2.5.1) |

### Install vLLM

```shell
git clone https://github.com/vllm-project/vllm.git -b v0.7.1
cd vllm
pip install -r requirements-build.txt
VLLM_TARGET_DEVICE=empty pip install -e .
```

### Install vllm-ascend

```shell
git clone https://github.com/vllm-project/vllm-ascend.git
cd vllm-ascend
git checkout 36991b2052db0b33c0f2b84021768a588360b735
pip install -e .
```

### Install trl

Run the following commands in the current directory:
```shell
cd trl
pip install -e .
```

### Install open-r1

Run the following commands in the current directory:
```shell
cd open-r1
pip install -e ".[dev]"
```

## Run GRPO training

```shell
cd open-r1

ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml --num_processes 7 \
    src/open_r1/grpo.py \
    --config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
```

Detailed experimental results will be added later. We will also continue performance tuning and build the open-r1 step 3 pipeline. This document will be updated continuously; you are welcome to follow and star the project.


## FAQ
- If a numpy version conflict occurs, install version 1.26.0.
Binary file not shown.
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,44 +0,0 @@
|
||||
.PHONY: style quality
|
||||
|
||||
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
|
||||
export PYTHONPATH = src
|
||||
|
||||
check_dirs := src tests
|
||||
|
||||
style:
|
||||
ruff format --line-length 119 --target-version py310 $(check_dirs) setup.py
|
||||
isort $(check_dirs) setup.py
|
||||
|
||||
quality:
|
||||
ruff check --line-length 119 --target-version py310 $(check_dirs) setup.py
|
||||
isort --check-only $(check_dirs) setup.py
|
||||
flake8 --max-line-length 119 $(check_dirs) setup.py
|
||||
|
||||
test:
|
||||
pytest -sv tests/
|
||||
|
||||
# Evaluation
|
||||
|
||||
evaluate:
|
||||
$(eval PARALLEL_ARGS := $(if $(PARALLEL),$(shell \
|
||||
if [ "$(PARALLEL)" = "data" ]; then \
|
||||
echo "data_parallel_size=$(NUM_GPUS)"; \
|
||||
elif [ "$(PARALLEL)" = "tensor" ]; then \
|
||||
echo "tensor_parallel_size=$(NUM_GPUS)"; \
|
||||
fi \
|
||||
),))
|
||||
$(if $(filter tensor,$(PARALLEL)),export VLLM_WORKER_MULTIPROC_METHOD=spawn &&,) \
|
||||
MODEL_ARGS="pretrained=$(MODEL),dtype=bfloat16,$(PARALLEL_ARGS),max_model_length=32768,gpu_memory_utilisation=0.8" && \
|
||||
lighteval vllm $$MODEL_ARGS "custom|$(TASK)|0|0" \
|
||||
--custom-tasks src/open_r1/evaluate.py \
|
||||
--use-chat-template \
|
||||
--system-prompt="Please reason step by step, and put your final answer within \boxed{}." \
|
||||
--output-dir data/evals/$(MODEL)
|
||||
|
||||
# Example usage:
|
||||
# Single GPU:
|
||||
# make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24
|
||||
# Data parallel:
|
||||
# make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=data NUM_GPUS=8
|
||||
# Tensor parallel:
|
||||
# make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=tensor NUM_GPUS=8
|
@ -1,503 +0,0 @@
|
||||
# Open R1
|
||||
|
||||
*A fully open reproduction of DeepSeek-R1. This repo is a work in progress, let's build it together!*
|
||||
|
||||
**Table of Contents**
|
||||
1. [Overview](#overview)
|
||||
2. [Plan of attack](#plan-of-attack)
|
||||
3. [Installation](#installation)
|
||||
4. [Training models](#training-models)
|
||||
- [SFT](#sft)
|
||||
- [GRPO](#grpo)
|
||||
5. [Evaluating models](#evaluating-models)
|
||||
6. [Reproducing Deepseek's evaluation results](#reproducing-deepseeks-evaluation-results)
|
||||
7. [Data generation](#data-generation)
|
||||
- [Generate data from a smol distilled R1 model](#generate-data-from-a-smol-distilled-r1-model)
|
||||
- [Generate data from DeepSeek-R1](#generate-data-from-deepseek-r1)
|
||||
8. [Contributing](#contributing)
|
||||
|
||||
## Overview
|
||||
|
||||
The goal of this repo is to build the missing pieces of the R1 pipeline such that everybody can reproduce and build on top of it. The project is simple by design and mostly consists of:
|
||||
|
||||
|
||||
- `src/open_r1`: contains the scripts to train and evaluate models as well as generate synthetic data:
|
||||
- `grpo.py`: trains a model with GRPO on a given dataset.
|
||||
- `sft.py`: performs a simple SFT of a model on a dataset.
|
||||
- `evaluate.py`: evaluates a model on the R1 benchmarks.
|
||||
- `generate.py`: generates synthetic data from a model using [Distilabel](https://github.com/argilla-io/distilabel).
|
||||
- `Makefile`: contains easy-to-run commands for each step in the R1 pipeline leveraging the scripts above.
|
||||
|
||||
### Plan of attack
|
||||
|
||||
We will use the DeepSeek-R1 [tech report](https://github.com/deepseek-ai/DeepSeek-R1) as a guide, which can roughly be broken down into three main steps:
|
||||
|
||||
* Step 1: replicate the R1-Distill models by distilling a high-quality corpus from DeepSeek-R1.
|
||||
* Step 2: replicate the pure RL pipeline that DeepSeek used to create R1-Zero. This will likely involve curating new, large-scale datasets for math, reasoning, and code.
|
||||
* Step 3: show we can go from base model to RL-tuned via multi-stage training.
|
||||
|
||||
<center>
|
||||
<img src="assets/plan-of-attack.png" width="500">
|
||||
</center>
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
> [!CAUTION]
|
||||
> Libraries rely on CUDA 12.4. If you see errors related to segmentation faults, double check the version your system is running with `nvcc --version`.
|
||||
|
||||
To run the code in this project, first, create a Python virtual environment using e.g. `uv`.
|
||||
To install `uv`, follow the [UV Installation Guide](https://docs.astral.sh/uv/getting-started/installation/).
|
||||
|
||||
|
||||
```shell
|
||||
uv venv openr1 --python 3.11 && source openr1/bin/activate && uv pip install --upgrade pip --link-mode=copy
|
||||
```
|
||||
|
||||
Next, install vLLM:
|
||||
|
||||
```shell
|
||||
uv pip install vllm==0.7.2 --link-mode=copy
|
||||
```
|
||||
|
||||
This will also install PyTorch `v2.5.1` and it is **very important** to use this version since the vLLM binaries are compiled for it. You can then install the remaining dependencies for your specific use case via `pip install -e .[LIST OF MODES]`. For most contributors, we recommend:
|
||||
|
||||
```shell
|
||||
GIT_LFS_SKIP_SMUDGE=1 uv pip install -e ".[dev]" --link-mode=copy
|
||||
```
|
||||
|
||||
Next, log into your Hugging Face and Weights and Biases accounts as follows:
|
||||
|
||||
```shell
|
||||
huggingface-cli login
|
||||
wandb login
|
||||
```
|
||||
|
||||
Finally, check whether your system has Git LFS installed so that you can load and push models/datasets to the Hugging Face Hub:
|
||||
|
||||
```shell
|
||||
git-lfs --version
|
||||
```
|
||||
|
||||
If it isn't installed, run:
|
||||
|
||||
```shell
|
||||
sudo apt-get install git-lfs
|
||||
```
|
||||
|
||||
## Training models
|
||||
|
||||
We support training models with either DDP or DeepSpeed (ZeRO-2 and ZeRO-3). For example, to run SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k), run:
|
||||
|
||||
```shell
|
||||
# Train via command line
|
||||
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
|
||||
--dataset_name open-r1/OpenR1-Math-220k \
|
||||
--learning_rate 1.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--max_seq_length 16384 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--gradient_checkpointing \
|
||||
--bf16 \
|
||||
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
|
||||
|
||||
# Train via YAML config
|
||||
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
|
||||
```
|
||||
|
||||
Currently, the following tasks are supported:
|
||||
|
||||
* Supervised Fine-Tuning `sft`
|
||||
* Group Relative Policy Optimization `grpo`
|
||||
|
||||
> [!TIP]
|
||||
> If you scale up/down the number of GPUs, we recommend also scaling up the per-device batch size or number of gradient accumulation steps to keep the global batch size constant.
|
||||
|
||||
By default, these scripts will push each model to your Hugging Face Hub username, i.e. `{username}/{model_name}-{task}`. You can override the parameters in each YAML config by appending them to the command as follows:
|
||||
|
||||
```shell
|
||||
# Change batch size, number of epochs etc
|
||||
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml \
|
||||
--per_device_train_batch_size=1 --num_train_epochs=5
|
||||
```
|
||||
|
||||
If you also wish to override the Weights and Biases default settings, you can do so as follows:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml \
|
||||
--wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> The training commands below are configured for a node of 8 x H100s (80GB). For different hardware and topologies, you may need to tune the batch size and number of gradient accumulation steps.
|
||||
|
||||
### SFT
|
||||
|
||||
To run SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k), run:
|
||||
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero3.yaml \
|
||||
src/open_r1/sft.py \
|
||||
--config recipes/Qwen2.5-1.5B-Instruct/sft/config_demo.yaml
|
||||
```
|
||||
|
||||
### GRPO
|
||||
|
||||
To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, set `--num_processes` to override the default value in the `accelerate` configs:
|
||||
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
|
||||
--num_processes=7 src/open_r1/grpo.py \
|
||||
--config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> The chat template used in the distilled DeepSeek models omits the contents of the reasoning block within the `<think>` and `</think>` tags. It also prefills the assistant response with `<think>` which interferes with the format reward function. To handle that, it is important to override the chat template as done in e.g. [recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml](./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml).
|
||||
|
||||
|
||||
We provide a minimal reproducible experiment using GRPO for mathematical reasoning, referencing the approach from [SimpleRL-Reason](https://hkust-nlp.notion.site/simplerl-reason) which uses a 7B model trained on 8K examples. Running this on 8 H100 80GB GPUs takes about 3 hours:
|
||||
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
|
||||
--num_processes=7 src/open_r1/grpo.py \
|
||||
--config recipes/Qwen2.5-Math-7B/grpo/config_simple_rl.yaml
|
||||
```
|
||||
|
||||
Our final [model](https://huggingface.co/Dongwei/Qwen-2.5-7B_Base_Math_smalllr), while using different learning rates, loss functions and reward structures, achieves 69.4% accuracy on MATH-500, demonstrating a 17%+ improvement over the base model.
|
||||
|
||||
#### 👨💻 Training with a code interpreter
|
||||
|
||||
We provide a `code` reward function for executing code generated by the policy during training. Currently, this reward function targets code contests like [Codeforces](https://codeforces.com), where solutions are executed against a set of test cases and the overall success rate is returned as the final reward. To ensure safe execution, we use [E2B](https://e2b.dev) sandboxes, which are fast and cheap to run. To use this reward function, first install the necessary dependencies:
|
||||
|
||||
```shell
|
||||
uv pip install -e '.[code]'
|
||||
```
|
||||
|
||||
Then create a `.env` file and place an API token from E2B within it:
|
||||
|
||||
```
|
||||
E2B_API_KEY="e2b_xxx"
|
||||
```
|
||||
|
||||
Then make sure your dataset contains a `verification_info` column with the following schema (adopted from PrimeIntellect's excellent [datasets](https://huggingface.co/collections/PrimeIntellect/synthetic-1-67a2c399cfdd6c9f7fae0c37) of verifiable problems):
|
||||
|
||||
```python
|
||||
{
|
||||
"language": "python",
|
||||
"test_cases": [
|
||||
{
|
||||
"input": "4\n4\n0001\n1000\n0011\n0111\n3\n010\n101\n0\n2\n00000\n00001\n4\n01\n001\n0001\n00001\n",
|
||||
"output": "1\n3 \n-1\n0\n\n2\n1 2 \n",
|
||||
"type": "stdin_stdout",
|
||||
}
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
For example, to train a smol model on Python problems, run:
|
||||
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
|
||||
--num_processes=7 src/open_r1/grpo.py \
|
||||
--config recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
|
||||
```
|
||||
|
||||
### Launching jobs on a Slurm cluster
|
||||
|
||||
If you have access to a Slurm cluster, we provide a `slurm/train.slurm` script that will automatically queue training jobs for you. Here's how you can use it:
|
||||
|
||||
```shell
|
||||
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm {model_name} {task} {config_suffix} {accelerator}
|
||||
```
|
||||
|
||||
Here `{model_name}` and `{task}` are defined as above, while `{config_suffix}` refers to the specific config and `{accelerator}` refers to the choice of 🤗 Accelerate config in `recipes/accelerate_configs`. If you wish to override the default config parameters, you can provide them by appending a space-separated string like `'--arg1=value1 --arg2=value2'`. Here's a concrete example to run SFT on 1 node of 8 GPUs:
|
||||
|
||||
```shell
|
||||
# Launch on Slurm and override default hyperparameters
|
||||
sbatch --job-name=open_r1 --nodes=1 slurm/train.slurm Qwen2.5-1.5B-Instruct sft demo zero3 '--per_device_train_batch_size=1 --num_train_epochs=5'
|
||||
```
|
||||
|
||||
You can scale the number of nodes by increasing the `--nodes` flag.
|
||||
|
||||
> [!NOTE]
|
||||
> The configuration in `slurm/train.slurm` is optimised for the Hugging Face Compute Cluster and may require tweaking to be adapted to your own compute nodes.
|
||||
|
||||
## Evaluating models
|
||||
|
||||
We use `lighteval` to evaluate models, with custom tasks defined in `src/open_r1/evaluate.py`. For models which fit on a single GPU, run:
|
||||
|
||||
```shell
|
||||
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
||||
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8"
|
||||
OUTPUT_DIR=data/evals/$MODEL
|
||||
|
||||
# AIME 2024
|
||||
TASK=aime24
|
||||
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
|
||||
--custom-tasks src/open_r1/evaluate.py \
|
||||
--use-chat-template \
|
||||
--output-dir $OUTPUT_DIR
|
||||
|
||||
# MATH-500
|
||||
TASK=math_500
|
||||
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
|
||||
--custom-tasks src/open_r1/evaluate.py \
|
||||
--use-chat-template \
|
||||
--output-dir $OUTPUT_DIR
|
||||
|
||||
# GPQA Diamond
|
||||
TASK=gpqa:diamond
|
||||
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
|
||||
--custom-tasks src/open_r1/evaluate.py \
|
||||
--use-chat-template \
|
||||
--output-dir $OUTPUT_DIR
|
||||
```
|
||||
|
||||
> [!IMPORTANT]
|
||||
> You must set `max_model_length=32768` in the `vllm` command to align with the `generation_size` we define per eval. Without this, `lighteval` will throw an error.
|
||||
|
||||
To increase throughput across multiple GPUs, use _data parallel_ as follows:
|
||||
|
||||
```shell
|
||||
NUM_GPUS=8
|
||||
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
||||
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
|
||||
TASK=aime24
|
||||
OUTPUT_DIR=data/evals/$MODEL
|
||||
|
||||
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
|
||||
--custom-tasks src/open_r1/evaluate.py \
|
||||
--use-chat-template \
|
||||
--output-dir $OUTPUT_DIR
|
||||
```
|
||||
|
||||
For large models which require sharding across GPUs, use _tensor parallel_ and run:
|
||||
|
||||
```shell
|
||||
NUM_GPUS=8
|
||||
MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
|
||||
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
|
||||
TASK=aime24
|
||||
OUTPUT_DIR=data/evals/$MODEL
|
||||
|
||||
export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
lighteval vllm $MODEL_ARGS "custom|$TASK|0|0" \
|
||||
--custom-tasks src/open_r1/evaluate.py \
|
||||
--use-chat-template \
|
||||
--output-dir $OUTPUT_DIR
|
||||
```
|
||||
|
||||
You can also launch an evaluation with `make evaluate`, specifying the model, task, and optionally the parallelism technique and number of GPUs.
|
||||
|
||||
To evaluate on a single GPU:
|
||||
|
||||
```shell
|
||||
make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24
|
||||
```
|
||||
|
||||
To use Data Parallelism:
|
||||
|
||||
```shell
|
||||
make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=data NUM_GPUS=8
|
||||
```
|
||||
|
||||
To use Tensor Parallelism:
|
||||
|
||||
```shell
|
||||
make evaluate MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-32B TASK=aime24 PARALLEL=tensor NUM_GPUS=8
|
||||
```
|
||||
|
||||
## Reproducing Deepseek's evaluation results
|
||||
|
||||
> [!NOTE]
|
||||
> The DeepSeek-R1 paper uses sampling with a temperature of 0.6, a top-p value of 0.95, and 64 responses per query to estimate `pass@1`. Below, we report the results from greedy decoding, which likely explains the small 1-3σ discrepancies between our results and theirs.
|
||||
|
||||
### MATH-500
|
||||
|
||||
We are able to reproduce Deepseek's reported results on the MATH-500 benchmark within ~1-3 standard deviations:
|
||||
|
||||
| Model | MATH-500 (🤗 LightEval) | MATH-500 (DeepSeek Reported) |
|
||||
|:------------------------------|:-----------------------:|:----------------------------:|
|
||||
| DeepSeek-R1-Distill-Qwen-1.5B | 81.2 | 83.9 |
|
||||
| DeepSeek-R1-Distill-Qwen-7B | 91.8 | 92.8 |
|
||||
| DeepSeek-R1-Distill-Qwen-14B | 94.2 | 93.9 |
|
||||
| DeepSeek-R1-Distill-Qwen-32B | 95.0 | 94.3 |
|
||||
| DeepSeek-R1-Distill-Llama-8B | 85.4 | 89.1 |
|
||||
| DeepSeek-R1-Distill-Llama-70B | 93.4 | 94.5 |
|
||||
|
||||
To reproduce these results use the following command:
|
||||
|
||||
```shell
|
||||
NUM_GPUS=1 # Set to 8 for 32B and 70B models
|
||||
MODEL=deepseek-ai/{model_name}
|
||||
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8,tensor_parallel_size=$NUM_GPUS"
|
||||
OUTPUT_DIR=data/evals/$MODEL
|
||||
|
||||
lighteval vllm $MODEL_ARGS "custom|math_500|0|0" \
|
||||
--custom-tasks src/open_r1/evaluate.py \
|
||||
--use-chat-template \
|
||||
--output-dir $OUTPUT_DIR
|
||||
```
|
||||
|
||||
Alternatively, you can launch Slurm jobs as follows:
|
||||
|
||||
```shell
|
||||
python scripts/run_benchmarks.py --model-id={model_id} --benchmarks math_500
|
||||
```
|
||||
|
||||
### GPQA Diamond
|
||||
|
||||
We are able to reproduce Deepseek's reported results on the GPQA Diamond benchmark within ~1-3 standard deviations:
|
||||
|
||||
| Model | GPQA Diamond (🤗 LightEval) | GPQA Diamond (DeepSeek Reported) |
|
||||
|:------------------------------|:---------------------------:|:--------------------------------:|
|
||||
| DeepSeek-R1-Distill-Qwen-1.5B | 33.3 | 33.8 |
|
||||
| DeepSeek-R1-Distill-Qwen-7B | 48.4 | 49.1 |
|
||||
| DeepSeek-R1-Distill-Qwen-14B | 55.6 | 59.1 |
|
||||
| DeepSeek-R1-Distill-Qwen-32B | 58.6 | 62.1 |
|
||||
| DeepSeek-R1-Distill-Llama-8B | 51.0 | 49.0 |
|
||||
| DeepSeek-R1-Distill-Llama-70B | 65.2 | 65.2 |
|
||||
|
||||
To reproduce these results use the following command:
|
||||
|
||||
```shell
|
||||
NUM_GPUS=1 # Set to 8 for 32B and 70B models
|
||||
MODEL=deepseek-ai/{model_name}
|
||||
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8,tensor_parallel_size=$NUM_GPUS"
|
||||
OUTPUT_DIR=data/evals/$MODEL
|
||||
|
||||
lighteval vllm $MODEL_ARGS "custom|gpqa:diamond|0|0" \
|
||||
--custom-tasks src/open_r1/evaluate.py \
|
||||
--use-chat-template \
|
||||
--output-dir $OUTPUT_DIR
|
||||
```
|
||||
|
||||
```shell
|
||||
python scripts/run_benchmarks.py --model-id={model_id} --benchmarks gpqa
|
||||
```
|
||||
|
||||
### LiveCodeBench
|
||||
|
||||
We are able to reproduce Deepseek's reported results on the LiveCodeBench code generation benchmark within ~1-3 standard deviations:
|
||||
|
||||
| Model                         | LiveCodeBench (🤗 LightEval) | LiveCodeBench (DeepSeek Reported) |
|
||||
|:------------------------------|:---------------------------:|:--------------------------------:|
|
||||
| DeepSeek-R1-Distill-Qwen-1.5B | 16.3 | 16.9 |
|
||||
| DeepSeek-R1-Distill-Qwen-7B | 36.6 | 37.6 |
|
||||
| DeepSeek-R1-Distill-Qwen-14B | 51.5 | 53.1 |
|
||||
| DeepSeek-R1-Distill-Qwen-32B | 56.6 | 57.2 |
|
||||
| DeepSeek-R1-Distill-Llama-8B | 37.0 | 39.6 |
|
||||
| DeepSeek-R1-Distill-Llama-70B | 54.5 | 57.5 |
|
||||
|
||||
To reproduce these results use the following command:
|
||||
|
||||
```shell
|
||||
NUM_GPUS=1 # Set to 8 for 32B and 70B models, or data_parallel_size=8 with the smaller models for speed
|
||||
MODEL=deepseek-ai/{model_name}
|
||||
MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilisation=0.8,tensor_parallel_size=$NUM_GPUS,generation_parameters={temperature:0.6,top_p:0.95}"
|
||||
OUTPUT_DIR=data/evals/$MODEL
|
||||
|
||||
lighteval vllm $MODEL_ARGS "extended|lcb:codegeneration|0|0" \
|
||||
--use-chat-template \
|
||||
--output-dir $OUTPUT_DIR
|
||||
```
|
||||
|
||||
```shell
|
||||
python scripts/run_benchmarks.py --model-id={model_id} --benchmarks lcb
|
||||
```
|
||||
|
||||
## Data generation
|
||||
|
||||
### Generate data from a smol distilled R1 model
|
||||
|
||||
The following example can be run on 1xH100.
|
||||
First install the following dependencies:
|
||||
|
||||
```shell
|
||||
uv pip install "distilabel[vllm]>=1.5.2"
|
||||
```
|
||||
|
||||
Now save the following snippet into a file named `pipeline.py` and run it with `python pipeline.py`. It will generate 4 outputs for each of the 10 examples (change the username for the repository to your org/user name):
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from distilabel.models import vLLM
|
||||
from distilabel.pipeline import Pipeline
|
||||
from distilabel.steps.tasks import TextGeneration
|
||||
|
||||
|
||||
prompt_template = """\
|
||||
You will be given a problem. Please reason step by step, and put your final answer within \boxed{}:
|
||||
{{ instruction }}"""
|
||||
|
||||
dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train").select(range(10))
|
||||
|
||||
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" # Exchange with another smol distilled r1
|
||||
|
||||
with Pipeline(
|
||||
name="distill-qwen-7b-r1",
|
||||
description="A pipeline to generate data from a distilled r1 model",
|
||||
) as pipeline:
|
||||
|
||||
llm = vLLM(
|
||||
model=model_id,
|
||||
tokenizer=model_id,
|
||||
extra_kwargs={
|
||||
"tensor_parallel_size": 1,
|
||||
"max_model_len": 8192,
|
||||
},
|
||||
generation_kwargs={
|
||||
"temperature": 0.6,
|
||||
"max_new_tokens": 8192,
|
||||
},
|
||||
)
|
||||
prompt_column = "problem"
|
||||
text_generation = TextGeneration(
|
||||
llm=llm,
|
||||
template=prompt_template,
|
||||
num_generations=4,
|
||||
input_mappings={"instruction": prompt_column} if prompt_column is not None else {}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
distiset = pipeline.run(dataset=dataset)
|
||||
distiset.push_to_hub(repo_id="username/numina-deepseek-r1-qwen-7b")
|
||||
```
|
||||
|
||||
Take a look at the sample dataset at [HuggingFaceH4/numina-deepseek-r1-qwen-7b](https://huggingface.co/datasets/HuggingFaceH4/numina-deepseek-r1-qwen-7b).
|
||||
|
||||
|
||||
### Generate data from DeepSeek-R1
|
||||
|
||||
To run the bigger DeepSeek-R1, we used 2 nodes, each with 8×H100 GPUs using the slurm file present in this repo at `slurm/generate.slurm`. First, install the dependencies:
|
||||
|
||||
(for now we need to install the vllm dev wheel that [fixes the R1 cuda graph capture](https://github.com/vllm-project/vllm/commits/221d388cc5a836fa189305785ed7e887cea8b510/csrc/moe/moe_align_sum_kernels.cu))
|
||||
```shell
|
||||
pip install https://wheels.vllm.ai/221d388cc5a836fa189305785ed7e887cea8b510/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
uv pip install "distilabel[vllm,ray,openai]>=1.5.2"
|
||||
```
|
||||
|
||||
And then run the following command:
|
||||
|
||||
```shell
|
||||
sbatch slurm/generate.slurm \
|
||||
--hf-dataset AI-MO/NuminaMath-TIR \
|
||||
--temperature 0.6 \
|
||||
--prompt-column problem \
|
||||
--model deepseek-ai/DeepSeek-R1 \
|
||||
--hf-output-dataset username/r1-dataset
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> While the job is running, you can set up an SSH tunnel through the cluster login node to access the Ray dashboard from your computer by running `ssh -L 8265:ray_ip_head_node:8265 <login_node>` and then browsing to `http://localhost:8265`.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome. Please refer to https://github.com/huggingface/open-r1/issues/23.
|
Binary file not shown.
@ -1,58 +0,0 @@
|
||||
# Model arguments
|
||||
model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
# We edit the DeepSeek chat template to ensure (a) the reasoning block within <think> and </think> is included in the completion and (b) the <think> tag is not part of the prefill so that the format reward works
|
||||
chat_template: "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}"
|
||||
dataset_name: open-r1/OpenR1-Math-220k
|
||||
dataset_configs:
|
||||
- default
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: DeepSeek-R1-Distill-Qwen-1.5B-GRPO
|
||||
hub_strategy: every_save
|
||||
learning_rate: 1.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 2048
|
||||
max_steps: -1
|
||||
num_generations: 16
|
||||
num_train_epochs: 1
|
||||
output_dir: data/DeepSeek-R1-Distill-Qwen-1.5B-GRPO
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 16
|
||||
per_device_train_batch_size: 16
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 1.0
|
||||
save_strategy: "epoch"
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 0.7
|
||||
warmup_ratio: 0.1
|
@ -1,44 +0,0 @@
|
||||
# To start the training, run the following command:
|
||||
# sbatch -N 4 --job-name=mistral_sft slurm/train.slurm Mistral-Small-24B-Instruct-2501 sft numina zero3
|
||||
|
||||
model_name_or_path: mistralai/Mistral-Small-24B-Instruct-2501
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
# dataset_name: yentinglin/s1K-1.1-trl-format
|
||||
dataset_name: yentinglin/OpenR1-Math-220k-trl-format
|
||||
dataset_configs:
|
||||
- all
|
||||
preprocessing_num_workers: 8
|
||||
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: true
|
||||
eval_strategy: no
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: Mistral-Small-24B-Instruct-2501-Open-R1-Distill
|
||||
hub_strategy: every_save
|
||||
learning_rate: 2.0e-05
|
||||
log_level: info
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
packing: true
|
||||
max_seq_length: 32768
|
||||
max_steps: -1
|
||||
num_train_epochs: 5
|
||||
output_dir: data/Mistral-Small-24B-Instruct-2501-Open-R1-Distill
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 1
|
||||
per_device_train_batch_size: 1
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: epoch
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
@ -1,53 +0,0 @@
|
||||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: eager
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: AI-MO/NuminaMath-TIR
|
||||
dataset_configs:
|
||||
- default
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.7
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 16
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
# hub_model_id: Qwen2.5-1.5B-Open-R1-GRPO
|
||||
# hub_strategy: every_save
|
||||
learning_rate: 2.0e-05
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine
|
||||
max_prompt_length: 512
|
||||
max_completion_length: 1024
|
||||
max_steps: -1
|
||||
num_generations: 7
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-1.5B-Open-R1-GRPO
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 4
|
||||
per_device_train_batch_size: 2
|
||||
push_to_hub: false
|
||||
report_to:
|
||||
- none
|
||||
reward_funcs:
|
||||
- accuracy
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 1.0
|
||||
save_strategy: "epoch"
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
warmup_ratio: 0.1
|
@ -1,57 +0,0 @@
|
||||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/verifiable-coding-problems-python-10k
|
||||
dataset_configs:
|
||||
- default
|
||||
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"
|
||||
|
||||
# GRPO trainer config
|
||||
beta: 0.01
|
||||
bf16: true
|
||||
use_vllm: true
|
||||
vllm_device: auto
|
||||
vllm_gpu_memory_utilization: 0.9
|
||||
do_eval: false
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: Qwen2.5-1.5B-Open-R1-Code-GRPO
|
||||
hub_strategy: every_save
|
||||
learning_rate: 5.0e-06
|
||||
log_completions: true
|
||||
log_level: info
|
||||
logging_first_step: true
|
||||
logging_steps: 1
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
max_prompt_length: 1024
|
||||
max_completion_length: 2048
|
||||
max_steps: 500
|
||||
num_generations: 14
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-1.5B-Open-R1-Code-GRPO
|
||||
overwrite_output_dir: true
|
||||
per_device_train_batch_size: 16
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
reward_funcs:
|
||||
- code
|
||||
- format
|
||||
reward_weights:
|
||||
- 1.0
|
||||
- 0.1
|
||||
save_strategy: "steps"
|
||||
save_steps: 50
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
temperature: 1.0
|
||||
warmup_ratio: 0.03
|
@ -1,46 +0,0 @@
|
||||
# Model arguments
|
||||
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
|
||||
model_revision: main
|
||||
torch_dtype: bfloat16
|
||||
attn_implementation: flash_attention_2
|
||||
|
||||
# Data training arguments
|
||||
dataset_name: open-r1/OpenR1-Math-220k
|
||||
dataset_configs:
|
||||
- default
|
||||
dataset_num_proc: 48
|
||||
|
||||
# SFT trainer config
|
||||
bf16: true
|
||||
do_eval: false
|
||||
eval_strategy: 'no'
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: false
|
||||
hub_model_id: Qwen2.5-1.5B-Open-R1-Distill
|
||||
hub_strategy: every_save
|
||||
learning_rate: 5.0e-05
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
packing: true
|
||||
max_seq_length: 16384
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
output_dir: data/Qwen2.5-1.5B-Open-R1-Distill
|
||||
overwrite_output_dir: true
|
||||
per_device_eval_batch_size: 16
|
||||
per_device_train_batch_size: 16
|
||||
push_to_hub: true
|
||||
report_to:
|
||||
- wandb
|
||||
save_strategy: "steps"
|
||||
save_steps: 100
|
||||
save_total_limit: 1
|
||||
seed: 42
|
||||
use_liger: true
|
||||
warmup_ratio: 0.05
|
@ -1,41 +0,0 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: eager

# Data training arguments
dataset_name: HuggingFaceH4/Bespoke-Stratos-17k
dataset_configs:
- all
preprocessing_num_workers: 8

# SFT trainer config
bf16: true
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 2.0e-05
log_level: info
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
packing: true
max_length: 4096
max_steps: -1
num_train_epochs: 6
output_dir: data/Qwen2.5-1.5B-Open-R1-Distill
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 2
push_to_hub: false
report_to:
- none
save_strategy: "no"
seed: 42
warmup_ratio: 0.1
@ -1,50 +0,0 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: eager

# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_configs:
- train
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 3.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Open-R1-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
# push_to_hub: true
report_to:
- none
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "steps"
save_steps: 10
seed: 42
warmup_ratio: 0.1
@ -1,50 +0,0 @@
# Model arguments
model_name_or_path: path/to/model_sfted
model_revision: main
torch_dtype: bfloat16
attn_implementation: eager

# Data training arguments
dataset_name: path/to/OpenR1-Math-220k_filtered_step3
dataset_configs:
- train
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.8
do_eval: false
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
learning_rate: 3.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 4096
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Open-R1-step3-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 1
# push_to_hub: true
report_to:
- none
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "steps"
save_steps: 10
seed: 42
warmup_ratio: 0.1
@ -1,55 +0,0 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-Math-7B
model_revision: main
torch_dtype: bfloat16
attn_implementation: eager

# Data training arguments
dataset_name: DigitalLearningGmbH/MATH-lighteval
dataset_configs:
- train
system_prompt: "You are a helpful AI Assistant, designed to provide well-reasoned and detailed responses. You FIRST think about the reasoning process as an internal monologue and then provide the user with the answer. The reasoning process MUST BE enclosed within <think> and </think> tags."

# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
gradient_accumulation_steps: 8
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
# hub_model_id: Qwen-2.5-7B-Simple-RL
# hub_strategy: every_save
learning_rate: 3.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: cosine
max_prompt_length: 512
max_completion_length: 1024
max_steps: -1
num_generations: 7
num_train_epochs: 1
output_dir: data/Qwen-2.5-7B-Simple-RL
overwrite_output_dir: true
per_device_eval_batch_size: 4
per_device_train_batch_size: 4
# push_to_hub: true
report_to:
- none
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 1.0
save_strategy: "steps"
seed: 42
warmup_ratio: 0.1
save_steps: 10
@ -1 +0,0 @@
**TODO:** we will add more recipes in the future, in the same spirit as the alignment-handbook; this is the purpose of adding recipes to this project.
@ -1,16 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@ -1,21 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@ -1,22 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@ -1,174 +0,0 @@
import argparse
import asyncio
import hashlib
import json
import os
import random
from asyncio import Lock
from typing import Set

from datasets import load_dataset
from tqdm.asyncio import tqdm

import aiofiles
import aiohttp
import uvloop


file_lock = Lock()


async def generate_completion(session, prompt, args):
    retry_budget = 10
    while retry_budget > 0:
        try:
            await asyncio.sleep(random.uniform(0.0, 0.1))
            async with session.post(
                f"http://{args.api_addr}/v1/chat/completions",
                json={
                    "model": "default",
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": args.max_tokens,
                    "temperature": args.temperature,
                    "top_p": args.top_p,
                },
                headers={"Authorization": "Bearer EMPTY"},
            ) as response:
                return await response.json(content_type=None)
        except Exception as e:
            print(f"API error (will retry): {e}")
            retry_budget -= 1
            await asyncio.sleep(10)
    return None


async def process_example(example, session, args, output_file, pbar):
    prompt = args.prompt_template.format(prompt=example[args.prompt_column])

    try:
        tasks = [generate_completion(session, prompt, args) for _ in range(args.num_generations)]

        completions = await asyncio.gather(*tasks)

        if any(completion is None for completion in completions):
            print("Error processing example")
            pbar.update(1)
            return None

        generations = []
        finish_reasons = []
        api_metadata = []

        for completion in completions:
            generations.append(completion["choices"][0]["message"]["content"])
            finish_reasons.append(completion["choices"][0]["finish_reason"])
            api_metadata.append(completion["usage"])

        # Combine original dataset fields with generations
        result = {
            **example,  # Preserve all original dataset fields
            "generations": generations,
            "finish_reasons": finish_reasons,
            "api_metadata": api_metadata,
        }

        # Write to file with lock
        async with file_lock:
            async with aiofiles.open(output_file, mode="a") as f:
                await f.write(json.dumps(result) + "\n")
                await f.flush()

        pbar.set_postfix(active=len(pbar.active_tasks), refresh=False)
        pbar.update(1)

        return result
    except Exception as e:
        print(f"Error processing example: {e}")
        pbar.update(1)
        return None


async def load_processed_uuids(output_file, uuid_column):
    processed_uuids = set()
    if os.path.exists(output_file):
        async with aiofiles.open(output_file, mode="r") as f:
            async for line in f:
                try:
                    data = json.loads(line)
                    processed_uuids.add(hashlib.md5(str(data[uuid_column]).encode()).hexdigest())
                except json.JSONDecodeError:
                    continue
    return processed_uuids


async def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-name", type=str, required=True)
    parser.add_argument("--output-file", type=str, required=True)
    parser.add_argument("--prompt-column", type=str, required=True)
    parser.add_argument("--uuid-column", type=str, required=True)
    parser.add_argument("--api-addr", type=str, default="localhost:39876")
    parser.add_argument("--num-generations", type=int, default=4)
    parser.add_argument(
        "--prompt-template",
        type=str,
        default="You will be given a problem. Please reason step by step, and put your final answer within \\boxed{{}}:\n{prompt}",
    )
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument("--max-tokens", type=int, default=16384)
    parser.add_argument("--max-concurrent", type=int, default=1000)
    args = parser.parse_args()

    dataset = load_dataset(args.dataset_name, split="train").shuffle()
    processed_uuids = await load_processed_uuids(args.output_file, args.uuid_column)
    if processed_uuids:
        print(f"Found {len(processed_uuids)} already processed examples, resuming from there...")

    if not os.path.exists(args.output_file):
        async with aiofiles.open(args.output_file, mode="w") as f:
            await f.write("")

    active_tasks: Set[asyncio.Task] = set()

    pbar = tqdm(
        total=len(dataset) - len(processed_uuids),
        desc="Generating responses",
        unit="row",
        mininterval=2,
        smoothing=0.0001,
    )
    pbar.active_tasks = active_tasks

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=60 * 60),
        connector=aiohttp.TCPConnector(limit=args.max_concurrent, ttl_dns_cache=300, keepalive_timeout=60 * 60),
    ) as session:
        for example in dataset:
            uuid = hashlib.md5(str(example[args.uuid_column]).encode()).hexdigest()
            if uuid not in processed_uuids:
                # Wait if we've hit the concurrency limit
                while len(active_tasks) >= args.max_concurrent:
                    done, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)
                    for task in done:
                        try:
                            await task
                        except Exception as e:
                            print(f"Task failed: {e}")

                task = asyncio.create_task(process_example(example, session, args, args.output_file, pbar))
                active_tasks.add(task)
                task.add_done_callback(active_tasks.discard)

                pbar.set_postfix(active=len(active_tasks), refresh=True)

        # Wait for remaining tasks
        if active_tasks:
            await asyncio.gather(*active_tasks, return_exceptions=True)

    pbar.close()


if __name__ == "__main__":
    uvloop.install()
    asyncio.run(main())
@ -1,61 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import List, Optional

from open_r1.utils.evaluation import SUPPORTED_BENCHMARKS, run_benchmark_jobs
from open_r1.configs import SFTConfig
from trl import ModelConfig, TrlParser


@dataclass
class ScriptArguments:
    model_id: str = field(
        default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        metadata={"help": "The Hub model id to push the model to."},
    )
    model_revision: str = field(default="main", metadata={"help": "The Hub model branch to push the model to."})
    trust_remote_code: bool = field(default=False, metadata={"help": "Trust the remote code."})
    benchmarks: List[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    list_benchmarks: bool = field(default=False, metadata={"help": "List all supported benchmarks."})
    system_prompt: Optional[str] = field(
        default=None, metadata={"help": "The system prompt to use for the benchmark."}
    )


def main():
    parser = TrlParser(ScriptArguments)
    args = parser.parse_args_and_config()[0]
    if args.list_benchmarks:
        print("Supported benchmarks:")
        for benchmark in SUPPORTED_BENCHMARKS:
            print(f" - {benchmark}")
        return
    benchmark_args = SFTConfig(
        output_dir="",
        hub_model_id=args.model_id,
        hub_model_revision=args.model_revision,
        benchmarks=args.benchmarks,
        system_prompt=args.system_prompt,
    )
    run_benchmark_jobs(
        benchmark_args,
        ModelConfig(model_name_or_path="", model_revision="", trust_remote_code=args.trust_remote_code),
    )


if __name__ == "__main__":
    main()
@ -1,55 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Push the details from a LightEval run to the Hub.

Usage:

python src/open_r1/utils/upload_details.py \
    --data_files {path_to_parquet_file} \
    --hub_repo_id {hub_repo_id} \
    --config_name {config_name}
"""

from dataclasses import dataclass, field
from typing import List

from datasets import load_dataset
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    data_files: List[str] = field(default_factory=list)
    hub_repo_id: str = None
    config_name: str = None


def main():
    parser = HfArgumentParser(ScriptArguments)
    args = parser.parse_args_into_dataclasses()[0]

    if all(file.endswith(".json") for file in args.data_files):
        ds = load_dataset("json", data_files=args.data_files)
    elif all(file.endswith(".jsonl") for file in args.data_files):
        ds = load_dataset("json", data_files=args.data_files)
    else:
        ds = load_dataset("parquet", data_files=args.data_files)
    url = ds.push_to_hub(args.hub_repo_id, config_name=args.config_name, private=True)
    print(f"Dataset available at: {url}")


if __name__ == "__main__":
    main()
@ -1,41 +0,0 @@
[isort]
default_section = FIRSTPARTY
ensure_newline_before_comments = True
force_grid_wrap = 0
include_trailing_comma = True
known_first_party = open_r1
known_third_party =
    transformers
    datasets
    fugashi
    git
    h5py
    matplotlib
    nltk
    numpy
    packaging
    pandas
    psutil
    pytest
    rouge_score
    sacrebleu
    seqeval
    sklearn
    streamlit
    torch
    tqdm

line_length = 119
lines_after_imports = 2
multi_line_output = 3
use_parentheses = True

[flake8]
ignore = E203, E501, E741, W503, W605
max-line-length = 119
per-file-ignores =
    # imported but unused
    __init__.py: F401

[tool:pytest]
doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
@ -1,139 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py

# This project includes modifications to the original codebase:
# All email addresses and personal identifiers have been removed.

import re
import shutil
from pathlib import Path

from setuptools import find_packages, setup


stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
if stale_egg_info.exists():
    print(
        (
            "Warning: {} exists.\n\n"
            "If you recently updated open_r1, this is expected,\n"
            "but it may prevent open_r1 from installing in editable mode.\n\n"
            "This directory is automatically generated by Python's packaging tools.\n"
            "I will remove it now.\n\n"
        ).format(stale_egg_info)
    )
    shutil.rmtree(stale_egg_info)


# IMPORTANT: all dependencies should be listed here with their version requirements, if any.
# * If a dependency is fast-moving (e.g. transformers), pin to the exact version
_deps = [
    "accelerate>=1.2.1",
    # "bitsandbytes>=0.43.0",
    "datasets>=3.2.0",
    "deepspeed==0.15.4",
    "distilabel[vllm,ray,openai]>=1.5.2",
    "e2b-code-interpreter>=1.0.5",
    "einops>=0.8.0",
    "flake8>=6.0.0",
    # "flash_attn>=2.7.4.post1",
    "hf_transfer>=0.1.4",
    "huggingface-hub[cli]>=0.19.2,<1.0",
    "isort>=5.12.0",
    "latex2sympy2_extended>=1.0.6",
    "liger_kernel==0.5.2",
    "lighteval @ git+https://github.com/huggingface/lighteval.git@86f62259f105ae164f655e0b91c92a823a742724#egg=lighteval[math]",
    "math-verify==0.5.2",  # Used for math verification in grpo
    "packaging>=23.0",
    "parameterized>=0.9.0",
    "peft>=0.14.0",
    "pytest",
    "python-dotenv",
    "ruff>=0.9.0",
    "safetensors>=0.3.3",
    "sentencepiece>=0.1.99",
    "torch==2.5.1",
    "wandb>=0.19.1",
]

# this is a lookup table with items like:
#
# tokenizers: "tokenizers==0.9.4"
# packaging: "packaging"
#
# some of the values are versioned whereas others aren't.
deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}


def deps_list(*pkgs):
    return [deps[pkg] for pkg in pkgs]


extras = {}
extras["tests"] = deps_list("pytest", "parameterized", "math-verify")
extras["torch"] = deps_list("torch")
extras["quality"] = deps_list("ruff", "isort", "flake8")
# extras["train"] = deps_list("flash_attn")
extras["code"] = deps_list("e2b-code-interpreter", "python-dotenv")
extras["eval"] = deps_list("lighteval", "math-verify")
extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]  # + extras["train"]

# core dependencies shared across the whole project - keep this to a bare minimum :)
install_requires = [
    deps["accelerate"],
    # deps["bitsandbytes"],
    deps["einops"],
    deps["datasets"],
    deps["deepspeed"],
    deps["hf_transfer"],
    deps["huggingface-hub"],
    deps["latex2sympy2_extended"],
    deps["math-verify"],
    # deps["liger_kernel"],
    deps["packaging"],  # utilities from PyPA to e.g., compare versions
    deps["safetensors"],
    deps["sentencepiece"],
    # deps["transformers"],
    # deps["trl"],
]

setup(
    name="open-r1",
    version="0.1.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
    author="The Hugging Face team (past and future)",
    description="Open R1",
    long_description=open("README.md", "r", encoding="utf-8").read(),
    long_description_content_type="text/markdown",
    keywords="llm inference-time compute reasoning",
    license="Apache",
    package_dir={"": "src"},
    packages=find_packages("src"),
    zip_safe=False,
    extras_require=extras,
    python_requires=">=3.10.9",
    install_requires=install_requires,
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Education",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)
@ -1,30 +0,0 @@
## Serving DeepSeek-R1 on 2x8 H100 SLURM nodes with SGLang

1. Set up the environment (adjust for your cuda version):
```bash
conda create -n sglang124 python=3.11
conda activate sglang124

pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cu124

pip install sgl-kernel --force-reinstall --no-deps
pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
```

2. Run the server and wait for the model to load:
```bash
sbatch slurm/serve_r1.slurm -m "/fsx/deepseek-r1-checkpoint" -e "sglang124"
```

3. Run the data generation script:
```bash
python scripts/generate_reasoning.py \
    --dataset-name "AI-MO/NuminaMath-1.5" \
    --output-file "numinamath_r1_generations.jsonl" \
    --prompt-column "problem" \
    --uuid-column "problem" \
    --api-addr "<SGLANG_SERVER_ADDRESS>:39877" \
    --num-generations 2 \
    --max-tokens 16384 \
    --max-concurrent 200
```
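4. (Optional) Sanity-check the generations file. This is a minimal sketch, assuming the output file name and prompt column used in step 3; adjust the path and keys if you changed them:
```python
# Inspect a few rows written by generate_reasoning.py (one JSON object per line).
import json

path = "numinamath_r1_generations.jsonl"  # output file from step 3 (assumed)
with open(path) as f:
    rows = [json.loads(line) for line in f]

print(f"{len(rows)} problems completed")
first = rows[0]
print(first["problem"][:200])       # the original prompt column
print(first["finish_reasons"])      # one entry per sampled generation
print(first["generations"][0][:500])  # start of the first reasoning trace
```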
@ -1,89 +0,0 @@
#!/bin/bash
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod
#SBATCH --output=./logs/%x-%j.out
#SBATCH --err=./logs/%x-%j.err
#SBATCH --requeue

# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
module load cuda/12.4

set -x -e

source ~/.bashrc
source openr1/bin/activate

TASK_NAME=$1
TASKS=$2
MODEL_ID=$3
MODEL_REVISION=$4
# Optional args
[ -z "$5" ] && TENSOR_PARALLEL=False || TENSOR_PARALLEL=$5
[ -z "$6" ] && TRUST_REMOTE_CODE=False || TRUST_REMOTE_CODE=$6
# $7 is reserved for system_prompt, see line 51
NUM_GPUS=$(nvidia-smi -L | wc -l)

# Set whether to use tensor parallelism or data parallelism
if [ "$TENSOR_PARALLEL" = "True" ]; then
    # use TP to shard model across NUM_GPUS
    export VLLM_WORKER_MULTIPROC_METHOD=spawn
    MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,tensor_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
else
    MODEL_ARGS="pretrained=$MODEL_ID,revision=$MODEL_REVISION,trust_remote_code=$TRUST_REMOTE_CODE,dtype=bfloat16,data_parallel_size=$NUM_GPUS,max_model_length=32768,gpu_memory_utilisation=0.8"
fi

LM_EVAL_REPO_ID="open-r1/open-r1-eval-leaderboard"
MODEL_NAME=$(echo $MODEL_ID | sed 's/\//_/g') # replaces / with _
DETAILS_REPO_ID="open-r1/details-$MODEL_NAME"
OUTPUT_DIR="eval_results/$MODEL_ID/$MODEL_REVISION/$TASK_NAME"
# We need this flag since we run this script from training jobs that use DeepSpeed and the env vars get propagated which causes errors during evaluation
ACCELERATE_USE_DEEPSPEED=false
# Enable fast downloads
HF_HUB_ENABLE_HF_TRANSFER=1

echo "Running lighteval script ..."
echo "Eval results will be saved to $OUTPUT_DIR"
# Check if "custom" is a substring of TASKS
if [[ $TASKS == *"custom"* ]]; then
    echo "Custom task detected. Running custom task evaluation script ..."
    lighteval vllm $MODEL_ARGS $TASKS \
        --custom-tasks "src/open_r1/evaluate.py" \
        --use-chat-template \
        --output-dir $OUTPUT_DIR \
        --save-details \
        ${7:+--system-prompt "$7"}
else
    lighteval vllm $MODEL_ARGS $TASKS \
        --use-chat-template \
        --output-dir $OUTPUT_DIR \
        --save-details \
        ${7:+--system-prompt "$7"}
fi

OUTPUT_FILEPATHS=$(find $OUTPUT_DIR/results/ -type f \( -name "*.json" \))
for filepath in $OUTPUT_FILEPATHS; do
    echo "Uploading $filepath to Hugging Face Hub..."
    filename=$(basename -- "$filepath")
    for attempt in {1..20}; do
        if huggingface-cli upload --repo-type space --private $LM_EVAL_REPO_ID $filepath $OUTPUT_DIR/$filename; then
            echo "Upload succeeded for $filepath"
            break
        else
            echo "Upload failed for $filepath. Attempt $attempt of 20. Retrying in 5 seconds..."
            sleep 5
        fi
    done
done

echo "Uploading details to Hugging Face Hub..."
DETAILS_FILEPATHS=$(find $OUTPUT_DIR/details/ -type f \( -name "*.parquet" \))
echo "DETAILS_FILEPATHS: $DETAILS_FILEPATHS"
TIMESTAMP=$(date +"%Y-%m-%dT%H-%M-%S")
python scripts/upload_details.py --data_files $DETAILS_FILEPATHS --hub_repo_id $DETAILS_REPO_ID --config_name $MODEL_REVISION.$TASK_NAME.$TIMESTAMP

echo "Cleaning up ..."
rm -rf $OUTPUT_DIR

echo "Done!"
@ -1,132 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=r1-vllm
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=normal
|
||||
#SBATCH --nodes=4
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --output=./logs/%x_%j_%n.out
|
||||
#SBATCH --error=./logs/%x_%j_%n.err
|
||||
#SBATCH --time=7-00:00:00
|
||||
#SBATCH --ntasks-per-node=1
|
||||
|
||||
set -exuo pipefail
|
||||
|
||||
MODEL_PATH="deepseek-ai/DeepSeek-R1"
|
||||
CONDA_ENV="vllm7"
|
||||
SERVER_PORT=8000
|
||||
RAY_PORT=6379
|
||||
RAY_DASHBOARD_PORT=8265
|
||||
|
||||
while getopts "m:e:h" opt; do
|
||||
case $opt in
|
||||
m) MODEL_PATH="$OPTARG" ;;
|
||||
e) CONDA_ENV="$OPTARG" ;;
|
||||
h|?) echo "Usage: sbatch $0 [-m MODEL_PATH] [-e CONDA_ENV]"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Environment setup
|
||||
module load cuda/12.1
|
||||
source ~/.bashrc
|
||||
source "$CONDA_PREFIX/etc/profile.d/conda.sh"
|
||||
conda activate "$CONDA_ENV" || { echo "Failed to activate conda env $CONDA_ENV"; exit 1; }
|
||||
|
||||
# Get nodes information
|
||||
NODES=($(scontrol show hostnames "$SLURM_JOB_NODELIST"))
|
||||
HEAD_NODE="${NODES[0]}"
|
||||
HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
|
||||
|
||||
echo "SLURM_JOB_ID: $SLURM_JOB_ID"
|
||||
echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
|
||||
echo "Head node: $HEAD_NODE ($HEAD_NODE_IP)"
|
||||
|
||||
# Start Ray head node
|
||||
echo "Starting Ray head node at $HEAD_NODE"
|
||||
srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" \
|
||||
ray start --head \
|
||||
--node-ip-address="$HEAD_NODE_IP" \
|
||||
--port=$RAY_PORT \
|
||||
--dashboard-host=0.0.0.0 \
|
||||
--dashboard-port=$RAY_DASHBOARD_PORT \
|
||||
--block &
|
||||
|
||||
sleep 10
|
||||
|
||||
# Start Ray worker nodes
|
||||
WORKER_COUNT=$((SLURM_JOB_NUM_NODES - 1))
|
||||
for ((i = 1; i <= WORKER_COUNT; i++)); do
|
||||
WORKER_NODE="${NODES[$i]}"
|
||||
echo "Starting Ray worker $i at $WORKER_NODE"
|
||||
srun --nodes=1 --ntasks=1 -w "$WORKER_NODE" \
|
||||
ray start --address "$HEAD_NODE_IP:$RAY_PORT" \
|
||||
--block &
|
||||
sleep 5
|
||||
done
|
||||
|
||||
echo "Waiting for Ray cluster to initialize..."
|
||||
sleep 60
|
||||
|
||||
# Start vLLM server
|
||||
echo "Starting vLLM server..."
|
||||
RAY_ADDRESS="http://$HEAD_NODE_IP:$RAY_DASHBOARD_PORT" ray job submit \
|
||||
--working-dir src/open_r1 \
|
||||
--no-wait \
|
||||
--job-id vllm-server \
|
||||
-- vllm serve "$MODEL_PATH" \
|
||||
--tensor-parallel-size 8 \
|
||||
--pipeline-parallel-size 4 \
|
||||
--gpu-memory-utilization 0.90 \
|
||||
--max-model-len 32768 \
|
||||
--max-num-batched-tokens 262144 \
|
||||
--max-num-seqs 128 \
|
||||
--max-seq-len-to-capture 32768 \
|
||||
--enable-chunked-prefill true \
|
||||
--preemption-mode recompute \
|
||||
--swap-space 128 \
|
||||
--trust-remote-code \
|
||||
--distributed-executor-backend ray
|
||||
|
||||
# Wait for server with timeout
|
||||
TIMEOUT=3600 # 1h
|
||||
START_TIME=$(date +%s)
|
||||
echo "Waiting for vLLM server (http://$HEAD_NODE_IP:$SERVER_PORT)..."
|
||||
|
||||
while true; do
|
||||
if curl -s -o /dev/null -w "%{http_code}" "http://$HEAD_NODE_IP:$SERVER_PORT/health" >/dev/null 2>&1; then
|
||||
echo "Server is ready at http://$HEAD_NODE_IP:$SERVER_PORT"
|
||||
break
|
||||
fi
|
||||
|
||||
CURRENT_TIME=$(date +%s)
|
||||
if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then
|
||||
echo "Error: Server failed to start within $TIMEOUT seconds"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Still waiting... ($(($CURRENT_TIME - $START_TIME)) seconds elapsed)"
|
||||
sleep 60
|
||||
done
|
||||
|
||||
echo "Checking available models..."
|
||||
curl "http://$HEAD_NODE_IP:$SERVER_PORT/v1/models"
|
||||
sleep 10
|
||||
|
||||
echo "Executing sanity check..."
|
||||
curl "http://$HEAD_NODE_IP:$SERVER_PORT/v1/completions" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"model\": \"default\",
|
||||
\"prompt\": \"<|begin▁of▁sentence|><|User|>hi, how are you?<|Assistant|>\",
|
||||
\"max_tokens\": 2048,
|
||||
\"temperature\": 0.6
|
||||
}"
|
||||
|
||||
# Keep the job running with health checks
|
||||
while true; do
|
||||
if ! curl -s -o /dev/null "http://$HEAD_NODE_IP:$SERVER_PORT/health"; then
|
||||
echo "Error: Server health check failed"
|
||||
exit 1
|
||||
fi
|
||||
sleep 300
|
||||
done
|
@ -1,244 +0,0 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=deepseek-r1-generation
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --qos=normal
|
||||
#SBATCH --nodes=2
|
||||
#SBATCH --exclusive
|
||||
#SBATCH --gpus-per-node=8
|
||||
#SBATCH --output=./logs/%x-%j.out
|
||||
#SBATCH --err=./logs/%x-%j.err
|
||||
#SBATCH --time=04-00:00:00
|
||||
|
||||
# Parse command line arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--hf-dataset)
|
||||
HF_DATASET="$2"
|
||||
shift 2
|
||||
;;
|
||||
--hf-dataset-config)
|
||||
HF_DATASET_CONFIG="$2"
|
||||
shift 2
|
||||
;;
|
||||
--hf-dataset-split)
|
||||
HF_DATASET_SPLIT="$2"
|
||||
shift 2
|
||||
;;
|
||||
--prompt-column)
|
||||
PROMPT_COLUMN="$2"
|
||||
shift 2
|
||||
;;
|
||||
--prompt-template)
|
||||
PROMPT_TEMPLATE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--model)
|
||||
MODEL="$2"
|
||||
shift 2
|
||||
;;
|
||||
--temperature)
|
||||
TEMPERATURE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--top-p)
|
||||
TOP_P="$2"
|
||||
shift 2
|
||||
;;
|
||||
--max-new-tokens)
|
||||
MAX_NEW_TOKENS="$2"
|
||||
shift 2
|
||||
;;
|
||||
--num-generations)
|
||||
NUM_GENERATIONS="$2"
|
||||
shift 2
|
||||
;;
|
||||
--input-batch-size)
|
||||
INPUT_BATCH_SIZE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--client-replicas)
|
||||
CLIENT_REPLICAS="$2"
|
||||
shift 2
|
||||
;;
|
||||
--timeout)
|
||||
TIMEOUT="$2"
|
||||
shift 2
|
||||
;;
|
||||
--retries)
|
||||
RETRIES="$2"
|
||||
shift 2
|
||||
;;
|
||||
--hf-output-dataset)
|
||||
HF_OUTPUT_DATASET="$2"
|
||||
shift 2
|
||||
;;
|
||||
--private)
|
||||
PRIVATE="true"
|
||||
shift
|
||||
;;
|
||||
*)
|
||||
echo "Unknown parameter: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [ -z "$MODEL" ] || [ -z "$HF_DATASET" ]; then
|
||||
echo "Error: --model and --hf-dataset are required parameters"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set default values for optional parameters
|
||||
HF_DATASET_SPLIT=${HF_DATASET_SPLIT:-"train"}
|
||||
PROMPT_COLUMN=${PROMPT_COLUMN:-"prompt"}
|
||||
PROMPT_TEMPLATE=${PROMPT_TEMPLATE:-"{{ instruction }}"}
|
||||
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-8192}
|
||||
NUM_GENERATIONS=${NUM_GENERATIONS:-1}
|
||||
INPUT_BATCH_SIZE=${INPUT_BATCH_SIZE:-64}
|
||||
CLIENT_REPLICAS=${CLIENT_REPLICAS:-1}
|
||||
TIMEOUT=${TIMEOUT:-900}
|
||||
RETRIES=${RETRIES:-0}
|
||||
PRIVATE=${PRIVATE:-"false"}
|
||||
|
||||
# Print all input arguments
|
||||
echo "Input arguments:"
|
||||
echo "MODEL: $MODEL"
|
||||
echo "HF_DATASET: $HF_DATASET"
|
||||
echo "HF_DATASET_CONFIG: $HF_DATASET_CONFIG"
|
||||
echo "HF_DATASET_SPLIT: $HF_DATASET_SPLIT"
|
||||
echo "PROMPT_COLUMN: $PROMPT_COLUMN"
|
||||
echo "PROMPT_TEMPLATE: $PROMPT_TEMPLATE"
|
||||
echo "TEMPERATURE: $TEMPERATURE"
|
||||
echo "TOP_P: $TOP_P"
|
||||
echo "MAX_NEW_TOKENS: $MAX_NEW_TOKENS"
|
||||
echo "NUM_GENERATIONS: $NUM_GENERATIONS"
|
||||
echo "INPUT_BATCH_SIZE: $INPUT_BATCH_SIZE"
|
||||
echo "CLIENT_REPLICAS: $CLIENT_REPLICAS"
|
||||
echo "TIMEOUT: $TIMEOUT"
|
||||
echo "RETRIES: $RETRIES"
|
||||
echo "HF_OUTPUT_DATASET: $HF_OUTPUT_DATASET"
|
||||
echo "PRIVATE: $PRIVATE"
|
||||
echo "-------------------"
|
||||
|
||||
set -ex
|
||||
|
||||
module load cuda/12.4
|
||||
|
||||
export LD_LIBRARY_PATH=.venv/lib/python3.11/site-packages/nvidia/nvjitlink/lib
|
||||
|
||||
echo "SLURM_JOB_ID: $SLURM_JOB_ID"
|
||||
echo "SLURM_JOB_NODELIST: $SLURM_JOB_NODELIST"
|
||||
|
||||
source openr1/bin/activate
|
||||
|
||||
# Getting the node names
|
||||
nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
|
||||
nodes_array=($nodes)
|
||||
|
||||
# Get the IP address of the head node
|
||||
head_node=${nodes_array[0]}
|
||||
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
|
||||
|
||||
# Start Ray head node
|
||||
port=6379
|
||||
ip_head=$head_node_ip:$port
|
||||
export ip_head
|
||||
echo "IP Head: $ip_head"
|
||||
|
||||
echo "Starting HEAD at $head_node"
|
||||
srun --nodes=1 --ntasks=1 -w "$head_node" \
|
||||
ray start --head --node-ip-address="$head_node_ip" --port=$port \
|
||||
--dashboard-host=0.0.0.0 \
|
||||
--dashboard-port=8265 \
|
||||
--block &
|
||||
|
||||
# Give some time to head node to start...
|
||||
sleep 10
|
||||
|
||||
# Start Ray worker nodes
|
||||
worker_num=$((SLURM_JOB_NUM_NODES - 1))
|
||||
|
||||
# Start from 1 (0 is head node)
|
||||
for ((i = 1; i <= worker_num; i++)); do
|
||||
node_i=${nodes_array[$i]}
|
||||
echo "Starting WORKER $i at $node_i"
|
||||
srun --nodes=1 --ntasks=1 -w "$node_i" \
|
||||
ray start --address "$ip_head" \
|
||||
--block &
|
||||
sleep 5
|
||||
done
|
||||
|
||||
# Give some time to the Ray cluster to gather info
|
||||
echo "Waiting a bit for Ray cluster to gather node info..."
|
||||
sleep 60
|
||||
|
||||
# Run vllm
|
||||
RAY_ADDRESS="http://$head_node_ip:8265" ray job submit \
|
||||
--working-dir src/open_r1 \
|
||||
--no-wait \
|
||||
--job-id vllm-server \
|
||||
-- vllm serve $MODEL \
|
||||
--tensor-parallel-size $SLURM_GPUS_PER_NODE \
|
||||
--pipeline-parallel-size $SLURM_JOB_NUM_NODES \
|
||||
--gpu-memory-utilization=0.85 \
|
||||
--max-model-len 16384 \
|
||||
--enable-chunked-prefill \
|
||||
--trust-remote-code \
|
||||
--distributed-executor-backend ray
|
||||
|
||||
# wait for vllm to load the model
|
||||
echo "Waiting for vLLM (http://$head_node_ip:8000) server to be up..."
|
||||
|
||||
# wait for vllm to load and serve the model
|
||||
while true; do
|
||||
if curl -s -o /dev/null -w "%{http_code}" http://$head_node_ip:8000 >/dev/null 2>&1; then
|
||||
echo "Received response from http://$head_node_ip:8000"
|
||||
break
|
||||
else
|
||||
echo "Still waiting... (Press Ctrl+C to cancel)"
|
||||
sleep 60
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Checking available models..."
|
||||
curl http://$head_node_ip:8000/v1/models
|
||||
|
||||
echo "Executing sanity check..."
|
||||
curl http://$head_node_ip:8000/v1/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "{
|
||||
\"model\": \"$MODEL\",
|
||||
\"prompt\": \"<|begin▁of▁sentence|><|User|>hi, how are you?<|Assistant|>\",
|
||||
\"max_tokens\": 2048,
|
||||
\"temperature\": 0.6
|
||||
}"
|
||||
|
||||
# Finally submit the job to the cluster
|
||||
echo "Submitting job to ray cluster..."
|
||||
RAY_ADDRESS="http://$head_node_ip:8265" ray job submit \
|
||||
--working-dir src/open_r1 \
|
||||
--job-id generate \
|
||||
-- python -u generate.py \
|
||||
--model "$MODEL" \
|
||||
--hf-dataset "$HF_DATASET" \
|
||||
${HF_DATASET_CONFIG:+--hf-dataset-config "$HF_DATASET_CONFIG"} \
|
||||
--hf-dataset-split "$HF_DATASET_SPLIT" \
|
||||
--prompt-column "$PROMPT_COLUMN" \
|
||||
--prompt-template "$PROMPT_TEMPLATE" \
|
||||
${TEMPERATURE:+--temperature "$TEMPERATURE"} \
|
||||
${TOP_P:+--top-p "$TOP_P"} \
|
||||
--max-new-tokens "$MAX_NEW_TOKENS" \
|
||||
--num-generations "$NUM_GENERATIONS" \
|
||||
--input-batch-size "$INPUT_BATCH_SIZE" \
|
||||
--client-replicas "$CLIENT_REPLICAS" \
|
||||
--timeout "$TIMEOUT" \
|
||||
--retries "$RETRIES" \
|
||||
${HF_OUTPUT_DATASET:+--hf-output-dataset "$HF_OUTPUT_DATASET"} \
|
||||
${PRIVATE:+--private} \
|
||||
--vllm-server-url "http://$head_node_ip:8000/v1"
|
||||
|
||||
mkdir -p ray_logs
|
||||
|
||||
echo "Downloading Ray job logs..."
|
||||
RAY_ADDRESS="http://$head_node_ip:8265" ray job logs --job-id vllm-server > ray_logs/vllm-server-${SLURM_JOB_ID}.log
|
||||
RAY_ADDRESS="http://$head_node_ip:8265" ray job logs --job-id generate > ray_logs/generate-${SLURM_JOB_ID}.log
|
@ -1,109 +0,0 @@
#!/bin/bash
#SBATCH --job-name=r1-server
#SBATCH --partition=hopper-prod
#SBATCH --qos=normal
#SBATCH --nodes=2
#SBATCH --gpus-per-node=8
#SBATCH --exclusive
#SBATCH --output=./logs/%x_%j_%n.out
#SBATCH --error=./logs/%x_%j_%n.err
#SBATCH --time=7-00:00:00
#SBATCH --ntasks-per-node=1

set -exuo pipefail

MODEL_PATH="deepseek-ai/DeepSeek-R1"
CONDA_ENV="sglang124"
ROUTER_ADDRESS=""
SERVER_PORT=39877
DIST_PORT=45000

# TODO: Adjust these variables to your cluster configuration
export OUTLINES_CACHE_DIR=/scratch/serve_r1/ocache/
export TRITON_HOME=/scratch/serve_r1/triton/
export GLOO_SOCKET_IFNAME="enp71s0"
export NCCL_SOCKET_IFNAME="enp71s0"

while getopts "m:e:r:h" opt; do
    case $opt in
        m) MODEL_PATH="$OPTARG" ;;
        e) CONDA_ENV="$OPTARG" ;;
        r) ROUTER_ADDRESS="$OPTARG" ;;
        h|?) echo "Usage: sbatch $0 [-m MODEL_PATH] [-e CONDA_ENV] [-r ROUTER_ADDRESS]"; exit 1 ;;
    esac
done

# TODO: Environment setup, adjust to your cluster configuration
module load cuda/12.4
source ~/.bashrc
source "$CONDA_PREFIX/etc/profile.d/conda.sh"
conda activate "$CONDA_ENV" || { echo "Failed to activate conda env $CONDA_ENV"; exit 1; }

FIRST_NODE=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n1)
FIRST_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$FIRST_NODE" hostname --ip-address)

# Launch servers synchronously across all nodes
# (--max-running-requests=56 is rough estimate to avoid too many evicted/preempted 16k-long requests)
srun --nodes=2 --ntasks=2 --ntasks-per-node=1 \
    bash -c "python -m sglang.launch_server \
        --model-path '$MODEL_PATH' \
        --tp 16 \
        --dist-init-addr '$FIRST_NODE_IP:$DIST_PORT' \
        --nnodes 2 \
        --node-rank \$SLURM_PROCID \
        --port '$SERVER_PORT' \
        --host 0.0.0.0 \
        --trust-remote-code \
        --max-running-requests 56 \
        --context-length 32768" &

# Wait for server with timeout
TIMEOUT=3600  # 1h, but model loading should take ~30min
START_TIME=$(date +%s)
echo "Waiting for SGLang server (http://$FIRST_NODE_IP:$SERVER_PORT)..."

while true; do
    if curl -s -o /dev/null -w "%{http_code}" "http://$FIRST_NODE_IP:$SERVER_PORT/health" >/dev/null 2>&1; then
        echo "Server is ready at http://$FIRST_NODE_IP:$SERVER_PORT"
        break
    fi

    CURRENT_TIME=$(date +%s)
    if [ $((CURRENT_TIME - START_TIME)) -gt $TIMEOUT ]; then
        echo "Error: Server failed to start within $TIMEOUT seconds"
        exit 1
    fi

    echo "Still waiting... ($(($CURRENT_TIME - $START_TIME)) seconds elapsed)"
    sleep 60
done

# Register with router only if address was provided
if [ -n "$ROUTER_ADDRESS" ]; then
    echo "Registering with router at $ROUTER_ADDRESS..."
    curl -X POST "http://$ROUTER_ADDRESS/add_worker?url=http://$FIRST_NODE_IP:$SERVER_PORT" || true
    sleep 10
fi

echo "Checking available models..."
curl "http://$FIRST_NODE_IP:$SERVER_PORT/v1/models"
sleep 10

echo "Executing sanity check..."
curl "http://$FIRST_NODE_IP:$SERVER_PORT/v1/completions" \
    -H "Content-Type: application/json" \
    -d "{
        \"model\": \"default\",
        \"prompt\": \"<|begin▁of▁sentence|><|User|>hi, how are you?<|Assistant|>\",
        \"max_tokens\": 2048,
        \"temperature\": 0.6
    }"

# Keep the job running with health checks
while true; do
    if ! curl -s -o /dev/null "http://$FIRST_NODE_IP:$SERVER_PORT/health"; then
        echo "Error: Server health check failed"
        exit 1
    fi
    sleep 300
done
@ -1,45 +0,0 @@
#!/bin/bash
#SBATCH --job-name=r1-router
#SBATCH --partition=hopper-cpu
#SBATCH --qos=high
#SBATCH --nodes=1
#SBATCH --cpus-per-task=8
#SBATCH --mem-per-cpu=1875m
#SBATCH --output=./logs/%x_%j_%n.out
#SBATCH --error=./logs/%x_%j_%n.err
#SBATCH --time=30-00:00:00
#SBATCH --requeue

set -exuo pipefail

# TODO: Adjust these variables to your cluster configuration
CONDA_ENV="sglang124"
ROUTER_PORT=39876

trap 'scontrol requeue ${SLURM_JOB_ID}; exit 15' SIGUSR1

while getopts "e:h" opt; do
    case $opt in
        e) CONDA_ENV="$OPTARG" ;;
        h|?) echo "Usage: sbatch $0 [-e CONDA_ENV]"; exit 1 ;;
    esac
done

# TODO: Environment setup, adjust to your cluster configuration
source ~/.bashrc
source "$CONDA_PREFIX/etc/profile.d/conda.sh"
conda activate "$CONDA_ENV" || { echo "Failed to activate conda env $CONDA_ENV"; exit 1; }

python -m sglang_router.launch_router \
    --port "$ROUTER_PORT" \
    --host 0.0.0.0 \
    --worker-startup-timeout-secs 300

# Keep the job running with health checks
while true; do
    if ! curl -s -o /dev/null "http://localhost:$ROUTER_PORT/health"; then
        echo "Error: Router health check failed"
        exit 1
    fi
    sleep 300
done
@ -1,94 +0,0 @@
#!/bin/bash
#SBATCH --job-name=open-r1-sft
#SBATCH --ntasks-per-node=1
#SBATCH --exclusive
#SBATCH --gres=gpu:8
#SBATCH --partition=hopper-prod  # Adjust this for your cluster
#SBATCH --output=./logs/%x-%j.out
#SBATCH --err=./logs/%x-%j.err
#SBATCH --requeue

# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
module load cuda/12.4

set -x -e

source ~/.bashrc
source openr1/bin/activate
echo "START TIME: $(date)"

MODEL=$1
TASK=$2
CONFIG_SUFFIX=$3
ACCELERATOR=$4
OPTIONAL_ARGS=$5

# Training setup
NUM_NODES=$SLURM_NNODES
GPUS_PER_NODE=8
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match
CONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml
GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}')
USE_VLLM=$(grep 'use_vllm:\s*true' $CONFIG_FILE)  # Match "use_vllm: true" (with optional whitespace)

if [ -n "$USE_VLLM" ]; then  # Check if USE_VLLM is *not* empty (found)
    WORLD_SIZE=$(($WORLD_SIZE-1))
fi

# Split the string into individual arguments
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"

# Loop through the arguments and find the one with "--gradient_accumulation_steps"
for arg in "${ARGS[@]}"; do
    if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
        # Extract the value after the equals sign
        GRAD_ACC_STEPS="${arg#*=}"
        break  # Exit the loop once we find the desired argument
    fi
done

echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000

export CMD=" \
    src/open_r1/$TASK.py --config $CONFIG_FILE $OPTIONAL_ARGS
    "

export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORMERS_VERBOSITY=info accelerate launch \
    --config_file recipes/accelerate_configs/$ACCELERATOR.yaml \
    --gradient_accumulation_steps $GRAD_ACC_STEPS \
    --num_machines $NUM_NODES \
    --num_processes $WORLD_SIZE \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
    --max_restarts 1 \
    --role \$(hostname -s): \
    --tee 3 \
    "

# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1

# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
    --wait=60 \
    --kill-on-bad-exit=1 \
    "

clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1

echo "END TIME: $(date)"
@ -1,13 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -1,85 +0,0 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

import trl


# TODO: add the shared options with a mixin to reduce code duplication
@dataclass
class GRPOConfig(trl.GRPOConfig):
    """
    args for callbacks, benchmarks etc
    """

    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
    system_prompt: Optional[str] = field(
        default=None,
        metadata={"help": "The optional system prompt to use."},
    )
    hub_model_revision: Optional[str] = field(
        default="main", metadata={"help": "The Hub model branch to push the model to."}
    )
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )


@dataclass
class SFTConfig(trl.SFTConfig):
    """
    args for callbacks, benchmarks etc
    """

    benchmarks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
    )
    callbacks: list[str] = field(
        default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
    )
    chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
    system_prompt: Optional[str] = field(
        default=None,
        metadata={"help": "The optional system prompt to use for benchmarking."},
    )
    hub_model_revision: Optional[str] = field(
        default="main",
        metadata={"help": "The Hub model branch to push the model to."},
    )
    overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
    push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
    wandb_entity: Optional[str] = field(
        default=None,
        metadata={"help": ("The entity to store runs under.")},
    )
    wandb_project: Optional[str] = field(
        default=None,
        metadata={"help": ("The project to store runs under.")},
    )
@ -1,165 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom evaluation tasks for LightEval."""

import random

from lighteval.metrics.dynamic_metrics import (
    ExprExtractionConfig,
    IndicesExtractionConfig,
    LatexExtractionConfig,
    multilingual_extractive_match_metric,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc
from lighteval.utils.language import Language


latex_gold_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(LatexExtractionConfig(),),
    # Match boxed first before trying other regexes
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
    aggregation_function=max,
)

expr_gold_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    fallback_mode="first_match",
    precision=5,
    gold_extraction_target=(ExprExtractionConfig(),),
    # Match boxed first before trying other regexes
    pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
    aggregation_function=max,
)

gpqa_metric = multilingual_extractive_match_metric(
    language=Language.ENGLISH,
    gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
    precision=5,
)


def prompt_fn(line, task_name: str = None):
    """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
    return Doc(
        task_name=task_name,
        query=line["problem"],
        choices=[line["solution"]],
        gold_index=0,
    )


def aime_prompt_fn(line, task_name: str = None):
    return Doc(
        task_name=task_name,
        query=line["problem"],
        choices=[line["answer"]],
        gold_index=0,
    )


def gpqa_prompt_fn(line, task_name: str = None):
    """Prompt template adapted from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14"""
    gold_index = random.randint(0, 3)
    choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
    choices.insert(gold_index, line["Correct Answer"])
    query_template = "Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.\n\n{Question}\n\nA) {A}\nB) {B}\nC) {C}\nD) {D}"
    query = query_template.format(A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"])

    return Doc(
        task_name=task_name,
        query=query,
        choices=["A", "B", "C", "D"],
        gold_index=gold_index,
        instruction=query,
    )


# Define tasks
aime24 = LightevalTaskConfig(
    name="aime24",
    suite=["custom"],
    prompt_function=aime_prompt_fn,
    hf_repo="HuggingFaceH4/aime_2024",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[expr_gold_metric],
    version=1,
)
aime25 = LightevalTaskConfig(
    name="aime25",
    suite=["custom"],
    prompt_function=aime_prompt_fn,
    hf_repo="yentinglin/aime_2025",
    hf_subset="default",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[expr_gold_metric],
    version=1,
)
math_500 = LightevalTaskConfig(
    name="math_500",
    suite=["custom"],
    prompt_function=prompt_fn,
    hf_repo="HuggingFaceH4/MATH-500",
    hf_subset="default",
    hf_avail_splits=["test"],
    evaluation_splits=["test"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,
    metric=[latex_gold_metric],
    version=1,
)
gpqa_diamond = LightevalTaskConfig(
    name="gpqa:diamond",
    suite=["custom"],
    prompt_function=gpqa_prompt_fn,
    hf_repo="Idavidrein/gpqa",
    hf_subset="gpqa_diamond",
    hf_avail_splits=["train"],
    evaluation_splits=["train"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=32768,  # needed for reasoning models like R1
    metric=[gpqa_metric],
    stop_sequence=[],  # no stop sequence, will use eos token
    trust_dataset=True,
    version=1,
)


# Add tasks to the table
TASKS_TABLE = []
TASKS_TABLE.append(aime24)
|
||||
TASKS_TABLE.append(aime25)
|
||||
TASKS_TABLE.append(math_500)
|
||||
TASKS_TABLE.append(gpqa_diamond)
|
||||
|
||||
# MODULE LOGIC
|
||||
if __name__ == "__main__":
|
||||
print([t["name"] for t in TASKS_TABLE])
|
||||
print(len(TASKS_TABLE))
|
@ -1,208 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from distilabel.llms import OpenAILLM
|
||||
from distilabel.pipeline import Pipeline
|
||||
from distilabel.steps import StepResources
|
||||
from distilabel.steps.tasks import TextGeneration
|
||||
|
||||
|
||||
def build_distilabel_pipeline(
|
||||
model: str,
|
||||
base_url: str = "http://localhost:8000/v1",
|
||||
prompt_column: Optional[str] = None,
|
||||
prompt_template: str = "{{ instruction }}",
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
max_new_tokens: int = 8192,
|
||||
num_generations: int = 1,
|
||||
input_batch_size: int = 64,
|
||||
client_replicas: int = 1,
|
||||
timeout: int = 900,
|
||||
retries: int = 0,
|
||||
) -> Pipeline:
|
||||
generation_kwargs = {"max_new_tokens": max_new_tokens}
|
||||
|
||||
if temperature is not None:
|
||||
generation_kwargs["temperature"] = temperature
|
||||
|
||||
if top_p is not None:
|
||||
generation_kwargs["top_p"] = top_p
|
||||
|
||||
with Pipeline().ray() as pipeline:
|
||||
TextGeneration(
|
||||
llm=OpenAILLM(
|
||||
base_url=base_url,
|
||||
api_key="something",
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
max_retries=retries,
|
||||
generation_kwargs=generation_kwargs,
|
||||
),
|
||||
template=prompt_template,
|
||||
input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
|
||||
input_batch_size=input_batch_size,
|
||||
num_generations=num_generations,
|
||||
group_generations=True,
|
||||
resources=StepResources(replicas=client_replicas),
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
|
||||
parser.add_argument(
|
||||
"--hf-dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace dataset to load",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-dataset-config",
|
||||
type=str,
|
||||
required=False,
|
||||
help="Dataset config to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-dataset-split",
|
||||
type=str,
|
||||
default="train",
|
||||
help="Dataset split to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-column",
|
||||
type=str,
|
||||
default="prompt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-template",
|
||||
type=str,
|
||||
default="{{ instruction }}",
|
||||
help="Template string for formatting prompts.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model name to use for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm-server-url",
|
||||
type=str,
|
||||
default="http://localhost:8000/v1",
|
||||
help="URL of the vLLM server",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
help="Temperature for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
help="Top-p value for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-new-tokens",
|
||||
type=int,
|
||||
default=8192,
|
||||
help="Maximum number of new tokens to generate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-generations",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generations per problem",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-batch-size",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Batch size for input processing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--client-replicas",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of client replicas for parallel processing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=600,
|
||||
help="Request timeout in seconds (default: 600)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--retries",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of retries for failed requests (default: 0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf-output-dataset",
|
||||
type=str,
|
||||
required=False,
|
||||
help="HuggingFace repo to push results to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
help="Whether to make the output dataset private when pushing to HF Hub",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("\nRunning with arguments:")
|
||||
for arg, value in vars(args).items():
|
||||
print(f" {arg}: {value}")
|
||||
print()
|
||||
|
||||
print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
|
||||
dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
|
||||
print("Dataset loaded!")
|
||||
|
||||
pipeline = build_distilabel_pipeline(
|
||||
model=args.model,
|
||||
base_url=args.vllm_server_url,
|
||||
prompt_template=args.prompt_template,
|
||||
prompt_column=args.prompt_column,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
num_generations=args.num_generations,
|
||||
input_batch_size=args.input_batch_size,
|
||||
client_replicas=args.client_replicas,
|
||||
timeout=args.timeout,
|
||||
retries=args.retries,
|
||||
)
|
||||
|
||||
print("Running generation pipeline...")
|
||||
distiset = pipeline.run(
|
||||
dataset=dataset,
|
||||
dataset_batch_size=args.input_batch_size * 1000,
|
||||
use_cache=False,
|
||||
)
|
||||
print("Generation pipeline finished!")
|
||||
|
||||
if args.hf_output_dataset:
|
||||
print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
|
||||
distiset.push_to_hub(args.hf_output_dataset, private=args.private)
|
||||
print("Dataset pushed!")
|
@ -1,267 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from transformers import set_seed
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from open_r1.configs import GRPOConfig
|
||||
from open_r1.rewards import (
|
||||
accuracy_reward,
|
||||
code_reward,
|
||||
format_reward,
|
||||
get_cosine_scaled_reward,
|
||||
get_repetition_penalty_reward,
|
||||
len_reward,
|
||||
reasoning_steps_reward,
|
||||
)
|
||||
from open_r1.utils import get_tokenizer
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from trl import GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRPOScriptArguments(ScriptArguments):
|
||||
"""
|
||||
Script arguments for the GRPO training script.
|
||||
|
||||
Args:
|
||||
reward_funcs (`list[str]`):
|
||||
List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'.
|
||||
cosine_min_value_wrong (`float`):
|
||||
Minimum reward for cosine scaling for wrong answers.
|
||||
cosine_max_value_wrong (`float`):
|
||||
Maximum reward for cosine scaling for wrong answers.
|
||||
cosine_min_value_correct (`float`):
|
||||
Minimum reward for cosine scaling for correct answers.
|
||||
cosine_max_value_correct (`float`):
|
||||
Maximum reward for cosine scaling for correct answers.
|
||||
cosine_max_len (`int`):
|
||||
Maximum length for cosine scaling.
|
||||
"""
|
||||
|
||||
reward_funcs: list[str] = field(
|
||||
default_factory=lambda: ["accuracy", "format"],
|
||||
metadata={
|
||||
"help": "List of reward functions. Possible values: 'accuracy', 'format', 'format_deepseek', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"
|
||||
},
|
||||
)
|
||||
cosine_min_value_wrong: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "Minimum reward for wrong answers"},
|
||||
)
|
||||
cosine_max_value_wrong: float = field(
|
||||
default=-0.5,
|
||||
metadata={"help": "Maximum reward for wrong answers"},
|
||||
)
|
||||
cosine_min_value_correct: float = field(
|
||||
default=0.5,
|
||||
metadata={"help": "Minimum reward for correct answers"},
|
||||
)
|
||||
cosine_max_value_correct: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Maximum reward for correct answers"},
|
||||
)
|
||||
cosine_max_len: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Maximum length for scaling"},
|
||||
)
|
||||
repetition_n_grams: int = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of n-grams for repetition penalty reward"},
|
||||
)
|
||||
repetition_max_penalty: float = field(
|
||||
default=-1.0,
|
||||
metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
|
||||
)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Set seed for reproducibility
|
||||
set_seed(training_args.seed)
|
||||
|
||||
###############
|
||||
# Setup logging
|
||||
###############
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Log on each process a small summary
|
||||
logger.warning(
|
||||
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
||||
+ f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
||||
)
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Script parameters {script_args}")
|
||||
logger.info(f"Training parameters {training_args}")
|
||||
|
||||
# Check for last checkpoint
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
|
||||
|
||||
if "wandb" in training_args.report_to:
|
||||
init_wandb_training(training_args)
|
||||
|
||||
# Load the dataset
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
################
|
||||
# Load tokenizer
|
||||
################
|
||||
tokenizer = get_tokenizer(model_args, training_args)
|
||||
|
||||
# Get reward functions
|
||||
REWARD_FUNCS_REGISTRY = {
|
||||
"accuracy": accuracy_reward,
|
||||
"format": format_reward,
|
||||
"reasoning_steps": reasoning_steps_reward,
|
||||
"cosine": get_cosine_scaled_reward(
|
||||
min_value_wrong=script_args.cosine_min_value_wrong,
|
||||
max_value_wrong=script_args.cosine_max_value_wrong,
|
||||
min_value_correct=script_args.cosine_min_value_correct,
|
||||
max_value_correct=script_args.cosine_max_value_correct,
|
||||
max_len=script_args.cosine_max_len,
|
||||
),
|
||||
"repetition_penalty": get_repetition_penalty_reward(
|
||||
ngram_size=script_args.repetition_n_grams,
|
||||
max_penalty=script_args.repetition_max_penalty,
|
||||
),
|
||||
"length": len_reward,
|
||||
"code": code_reward,
|
||||
}
|
||||
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
|
||||
|
||||
# Format into conversation
|
||||
def make_conversation(example):
|
||||
prompt = []
|
||||
|
||||
if training_args.system_prompt is not None:
|
||||
prompt.append({"role": "system", "content": training_args.system_prompt})
|
||||
|
||||
prompt.append({"role": "user", "content": example["problem"]})
|
||||
return {"prompt": prompt}
|
||||
|
||||
dataset = dataset.map(make_conversation)
|
||||
|
||||
for split in dataset:
|
||||
if "messages" in dataset[split].column_names:
|
||||
dataset[split] = dataset[split].remove_columns("messages")
|
||||
|
||||
logger.info("*** Initializing model kwargs ***")
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
)
|
||||
training_args.model_init_kwargs = model_kwargs
|
||||
|
||||
#############################
|
||||
# Initialize the GRPO trainer
|
||||
#############################
|
||||
trainer = GRPOTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
reward_funcs=reward_funcs,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
peft_config=get_peft_config(model_args),
|
||||
callbacks=get_callbacks(training_args, model_args),
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
|
||||
###############
|
||||
# Training loop
|
||||
###############
|
||||
logger.info("*** Train ***")
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
|
||||
# Save everything else on main process
|
||||
kwargs = {
|
||||
"dataset_name": script_args.dataset_name,
|
||||
"tags": ["open-r1"],
|
||||
}
|
||||
if trainer.accelerator.is_main_process:
|
||||
trainer.create_model_card(**kwargs)
|
||||
# Restore k,v cache for fast inference
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
#############
|
||||
# push to hub
|
||||
#############
|
||||
if training_args.push_to_hub:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
main(script_args, training_args, model_args)
|
@ -1,353 +0,0 @@
|
||||
"""Reward functions for GRPO training."""
|
||||
# This project includes modifications to the original codebase:
|
||||
# All email addresses and personal identifiers have been removed.
|
||||
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
from typing import Dict
|
||||
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
from .utils import is_e2b_available
|
||||
|
||||
|
||||
if is_e2b_available():
|
||||
from dotenv import load_dotenv
|
||||
from e2b_code_interpreter import Sandbox
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def accuracy_reward(completions, solution, **kwargs):
|
||||
"""Reward function that checks if the completion is the same as the ground truth."""
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
rewards = []
|
||||
for content, sol in zip(contents, solution):
|
||||
gold_parsed = parse(
|
||||
sol,
|
||||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
if len(gold_parsed) != 0:
|
||||
# We require the answer to be provided in correct latex (no malformed operators)
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
equations=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
# Ensures that boxed is tried first
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
# Reward 1 if the content is the same as the ground truth, 0 otherwise
|
||||
reward = float(verify(answer_parsed, gold_parsed))
|
||||
else:
|
||||
# If the gold solution is not parseable, we reward 1 to skip this example
|
||||
reward = 1.0
|
||||
print("Failed to parse gold solution: ", sol)
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
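# Example (illustrative; the same cases appear in the unit tests at the end of this diff):
# completions use the conversational format passed by GRPOTrainer, i.e. a list of message lists.
#   accuracy_reward([[{"content": r"\boxed{\frac{63}{400}}"}]], [r"\frac{63}{400}"])  -> [1.0]
#   accuracy_reward([[{"content": r"\boxed{\frac{64}{400}}"}]], [r"\frac{63}{400}"])  -> [0.0]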
|
||||
|
||||
|
||||
def format_reward(completions, **kwargs):
|
||||
"""Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags."""
|
||||
pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
|
||||
completion_contents = [completion[0]["content"] for completion in completions]
|
||||
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
|
||||
return [1.0 if match else 0.0 for match in matches]
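# Example (illustrative): the DOTALL/MULTILINE regex requires the whole completion to be a
# <think> block followed by an <answer> block, e.g.
#   format_reward([[{"content": "<think>Some reasoning</think>\n<answer>42</answer>"}]])  -> [1.0]
# Missing, unclosed, or reordered tags score 0.0.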
|
||||
|
||||
|
||||
def reasoning_steps_reward(completions, **kwargs):
|
||||
r"""Reward function that checks for clear step-by-step reasoning.
|
||||
Regex pattern:
|
||||
Step \d+: - matches "Step 1:", "Step 2:", etc.
|
||||
^\d+\. - matches numbered lists like "1.", "2.", etc. at start of line
|
||||
\n- - matches bullet points with hyphens
|
||||
\n\* - matches bullet points with asterisks
|
||||
First,|Second,|Next,|Finally, - matches transition words
|
||||
"""
|
||||
pattern = r"(Step \d+:|^\d+\.|\n-|\n\*|First,|Second,|Next,|Finally,)"
|
||||
completion_contents = [completion[0]["content"] for completion in completions]
|
||||
matches = [len(re.findall(pattern, content)) for content in completion_contents]
|
||||
|
||||
# Magic number 3 gives full credit for three or more reasoning steps; fewer steps earn partial credit
|
||||
return [min(1.0, count / 3) for count in matches]
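# Worked example (illustrative): "Step 1: ...\nStep 2: ...\nStep 3: ..." produces 3 matches,
# so the reward is min(1.0, 3/3) = 1.0, while a lone "Step 1: ..." earns 1/3 partial credit.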
|
||||
|
||||
|
||||
def len_reward(completions: list[Dict[str, str]], solutions: list[str], **kwargs) -> list[float]:
|
||||
"""Compute length-based rewards to discourage overthinking and promote token efficiency.
|
||||
|
||||
Args:
|
||||
completions: List of model completions
|
||||
solutions: List of ground truth solutions
|
||||
|
||||
Returns:
|
||||
List of rewards where:
|
||||
- For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)
|
||||
- For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))
|
||||
"""
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
|
||||
# First check correctness of answers
|
||||
correctness = []
|
||||
for content, sol in zip(contents, solutions):
|
||||
gold_parsed = parse(
|
||||
sol,
|
||||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
if len(gold_parsed) == 0:
|
||||
# Skip unparseable examples
|
||||
correctness.append(True) # Treat as correct to avoid penalizing
|
||||
print("Failed to parse gold solution: ", sol)
|
||||
continue
|
||||
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
equations=True,
|
||||
boxed=True,
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
correctness.append(verify(answer_parsed, gold_parsed))
|
||||
|
||||
# Calculate lengths
|
||||
lengths = [len(content) for content in contents]
|
||||
min_len = min(lengths)
|
||||
max_len = max(lengths)
|
||||
|
||||
# If all responses have the same length, return zero rewards
|
||||
if max_len == min_len:
|
||||
return [0.0] * len(completions)
|
||||
|
||||
rewards = []
|
||||
for length, is_correct in zip(lengths, correctness):
|
||||
lambda_val = 0.5 - (length - min_len) / (max_len - min_len)
|
||||
|
||||
if is_correct:
|
||||
reward = lambda_val
|
||||
else:
|
||||
reward = min(0, lambda_val)
|
||||
|
||||
rewards.append(float(reward))
|
||||
|
||||
return rewards
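# Worked example (illustrative): for two correct completions of 20 and 80 characters,
# min_len = 20 and max_len = 80, so lambda = 0.5 - (len - 20) / 60:
#   len 20 -> +0.5 (the shortest correct answer gets the full bonus)
#   len 80 -> -0.5 (the longest correct answer is penalized)
# An incorrect completion of length 20 would instead get min(0, 0.5) = 0.0.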
|
||||
|
||||
|
||||
def get_cosine_scaled_reward(
|
||||
min_value_wrong: float = -1.0,
|
||||
max_value_wrong: float = -0.5,
|
||||
min_value_correct: float = 0.5,
|
||||
max_value_correct: float = 1.0,
|
||||
max_len: int = 1000,
|
||||
):
|
||||
def cosine_scaled_reward(completions, solution, **kwargs):
|
||||
"""Reward function that scales based on completion length using a cosine schedule.
|
||||
|
||||
Shorter correct solutions are rewarded more than longer ones.
|
||||
Longer incorrect solutions are penalized less than shorter ones.
|
||||
|
||||
Args:
|
||||
completions: List of model completions
|
||||
solution: List of ground truth solutions
|
||||
|
||||
This function is parameterized by the following arguments:
|
||||
min_value_wrong: Minimum reward for wrong answers
|
||||
max_value_wrong: Maximum reward for wrong answers
|
||||
min_value_correct: Minimum reward for correct answers
|
||||
max_value_correct: Maximum reward for correct answers
|
||||
max_len: Maximum length for scaling
|
||||
"""
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
rewards = []
|
||||
|
||||
for content, sol in zip(contents, solution):
|
||||
gold_parsed = parse(sol, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
|
||||
if len(gold_parsed) == 0:
|
||||
rewards.append(1.0) # Skip unparseable examples
|
||||
print("Failed to parse gold solution: ", sol)
|
||||
continue
|
||||
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
equations=True,
|
||||
boxed=True,
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
|
||||
is_correct = verify(answer_parsed, gold_parsed)
|
||||
gen_len = len(content)
|
||||
|
||||
# Apply cosine scaling based on length
|
||||
progress = gen_len / max_len
|
||||
cosine = math.cos(progress * math.pi)
|
||||
|
||||
if is_correct:
|
||||
min_value = min_value_correct
|
||||
max_value = max_value_correct
|
||||
else:
|
||||
# Swap min/max for incorrect answers
|
||||
min_value = max_value_wrong
|
||||
max_value = min_value_wrong
|
||||
|
||||
reward = min_value + 0.5 * (max_value - min_value) * (1.0 + cosine)
|
||||
rewards.append(float(reward))
|
||||
|
||||
return rewards
|
||||
|
||||
return cosine_scaled_reward
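# Worked example (illustrative) with the default values above (max_len=1000):
#   correct,  len 0    -> 0.5 + 0.5 * (1.0 - 0.5) * (1 + cos(0))   = 1.0
#   correct,  len 1000 -> 0.5 + 0.5 * (1.0 - 0.5) * (1 + cos(pi))  = 0.5
#   wrong,    len 0    -> -0.5 + 0.5 * (-1.0 + 0.5) * (1 + cos(0)) = -1.0
#   wrong,    len 1000 -> -0.5
# i.e. short correct answers earn the most and short wrong answers are penalized hardest.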
|
||||
|
||||
|
||||
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
|
||||
"""
|
||||
Args:
|
||||
ngram_size: size of the n-grams
|
||||
max_penalty: Maximum (negative) penalty for wrong answers
|
||||
"""
|
||||
if max_penalty > 0:
|
||||
raise ValueError(f"max_penalty {max_penalty} should not be positive")
|
||||
|
||||
def zipngram(text: str, ngram_size: int):
|
||||
words = text.lower().split()
|
||||
return zip(*[words[i:] for i in range(ngram_size)])
|
||||
|
||||
def repetition_penalty_reward(completions, **kwargs) -> list[float]:
|
||||
"""
|
||||
Args:
|
||||
completions: List of model completions
|
||||
"""
|
||||
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
rewards = []
|
||||
for completion in contents:
|
||||
if completion == "":
|
||||
rewards.append(0.0)
|
||||
continue
|
||||
if len(completion.split()) < ngram_size:
|
||||
rewards.append(0.0)
|
||||
continue
|
||||
|
||||
ngrams = set()
|
||||
total = 0
|
||||
for ng in zipngram(completion, ngram_size):
|
||||
ngrams.add(ng)
|
||||
total += 1
|
||||
|
||||
scaling = 1 - len(ngrams) / total
|
||||
reward = scaling * max_penalty
|
||||
rewards.append(reward)
|
||||
return rewards
|
||||
|
||||
return repetition_penalty_reward
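# Worked example (illustrative) with ngram_size=3 and max_penalty=-1.0:
#   "the cat the cat the cat" -> 4 trigrams, 2 unique -> scaling = 1 - 2/4 = 0.5 -> reward = -0.5
#   a completion whose trigrams are all unique -> scaling = 0 -> reward = 0.0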
|
||||
|
||||
|
||||
def extract_code(completion: str) -> str:
|
||||
pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
|
||||
matches = pattern.findall(completion)
|
||||
extracted_answer = matches[-1] if len(matches) >= 1 else ""
|
||||
return extracted_answer
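# Example (illustrative): only the last fenced Python block is returned, e.g.
#   extract_code("```python\nprint(1)\n``` text ```python\nprint(2)\n```")  -> "print(2)\n"
# and a completion without a ```python fence yields "".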
|
||||
|
||||
|
||||
def code_reward(completions, **kwargs) -> list[float]:
|
||||
"""Reward function that evaluates code snippets using the E2B code interpreter.
|
||||
|
||||
Assumes the dataset contains a `verification_info` column with test cases.
|
||||
"""
|
||||
if not is_e2b_available():
|
||||
raise ImportError(
|
||||
"E2B is not available and required for this reward function. Please install E2B with "
|
||||
"`pip install e2b-code-interpreter` and add an API key to a `.env` file."
|
||||
)
|
||||
|
||||
rewards = []
|
||||
try:
|
||||
"""Returns a reward function that evaluates code snippets in a sandbox."""
|
||||
evaluation_script_template = """
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
def evaluate_code(code, test_cases):
|
||||
passed = 0
|
||||
total = len(test_cases)
|
||||
exec_timeout = 5
|
||||
|
||||
for case in test_cases:
|
||||
process = subprocess.run(
|
||||
["python3", "-c", code],
|
||||
input=case["input"],
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=exec_timeout
|
||||
)
|
||||
|
||||
if process.returncode != 0: # Error in execution
|
||||
continue
|
||||
|
||||
output = process.stdout.strip()
|
||||
if output.strip() == case["output"].strip():
|
||||
passed += 1
|
||||
|
||||
success_rate = (passed / total)
|
||||
return success_rate
|
||||
|
||||
code_snippet = {code}
|
||||
test_cases = json.loads({test_cases})
|
||||
|
||||
evaluate_code(code_snippet, test_cases)
|
||||
"""
|
||||
code_snippets = [extract_code(completion[-1]["content"]) for completion in completions]
|
||||
verification_info = kwargs["verification_info"]
|
||||
scripts = [
|
||||
evaluation_script_template.format(
|
||||
code=json.dumps(code), test_cases=json.dumps(json.dumps(info["test_cases"]))
|
||||
)
|
||||
for code, info in zip(code_snippets, verification_info)
|
||||
]
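# Note: json.dumps(code) turns the snippet into a quoted string literal before it is substituted
# into `code_snippet = {code}`, and test_cases is dumped twice so that json.loads({test_cases})
# inside the sandboxed script recovers the original list of test cases.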
|
||||
with Sandbox(timeout=30, request_timeout=3) as sbx:
|
||||
for script, info in zip(scripts, verification_info):
|
||||
execution = sbx.run_code(script, language=info["language"])  # verification_info is a list of per-sample dicts
|
||||
try:
|
||||
output = float(execution.text)
|
||||
except (TypeError, ValueError):
|
||||
output = 0.0
|
||||
rewards.append(output)
|
||||
except Exception as e:
|
||||
print(f"Error from E2B executor: {e}")
|
||||
rewards = [0.0] * len(completions)
|
||||
return rewards
|
@ -1,198 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Supervised fine-tuning script for decoder language models.
|
||||
|
||||
Usage:
|
||||
|
||||
# On 1 node of 8 x H100s
|
||||
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
|
||||
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--max_seq_length 4096 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--gradient_checkpointing \
|
||||
--bf16 \
|
||||
--logging_steps 5 \
|
||||
--eval_strategy steps \
|
||||
--eval_steps 100 \
|
||||
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from transformers import set_seed
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from open_r1.configs import SFTConfig
|
||||
from open_r1.utils import get_tokenizer
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
# Set seed for reproducibility
|
||||
set_seed(training_args.seed)
|
||||
|
||||
###############
|
||||
# Setup logging
|
||||
###############
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
handlers=[logging.StreamHandler(sys.stdout)],
|
||||
)
|
||||
log_level = training_args.get_process_log_level()
|
||||
logger.setLevel(log_level)
|
||||
datasets.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.set_verbosity(log_level)
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
logger.info(f"Model parameters {model_args}")
|
||||
logger.info(f"Script parameters {script_args}")
|
||||
logger.info(f"Training parameters {training_args}")
|
||||
|
||||
# Check for last checkpoint
|
||||
last_checkpoint = None
|
||||
if os.path.isdir(training_args.output_dir):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
||||
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
|
||||
|
||||
if "wandb" in training_args.report_to:
|
||||
init_wandb_training(training_args)
|
||||
|
||||
################
|
||||
# Load datasets
|
||||
################
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
################
|
||||
# Load tokenizer
|
||||
################
|
||||
tokenizer = get_tokenizer(model_args, training_args)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
###################
|
||||
# Model init kwargs
|
||||
###################
|
||||
logger.info("*** Initializing model kwargs ***")
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
training_args.model_init_kwargs = model_kwargs
|
||||
|
||||
############################
|
||||
# Initialize the SFT Trainer
|
||||
############################
|
||||
trainer = SFTTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
processing_class=tokenizer,
|
||||
peft_config=get_peft_config(model_args),
|
||||
callbacks=get_callbacks(training_args, model_args),
|
||||
)
|
||||
|
||||
###############
|
||||
# Training loop
|
||||
###############
|
||||
logger.info("*** Train ***")
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
elif last_checkpoint is not None:
|
||||
checkpoint = last_checkpoint
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
logger.info("*** Save model ***")
|
||||
trainer.save_model(training_args.output_dir)
|
||||
logger.info(f"Model saved to {training_args.output_dir}")
|
||||
|
||||
# Save everything else on main process
|
||||
kwargs = {
|
||||
"dataset_name": script_args.dataset_name,
|
||||
"tags": ["open-r1"],
|
||||
}
|
||||
if trainer.accelerator.is_main_process:
|
||||
trainer.create_model_card(**kwargs)
|
||||
# Restore k,v cache for fast inference
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
#############
|
||||
# push to hub
|
||||
#############
|
||||
if training_args.push_to_hub:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
main(script_args, training_args, model_args)
|
@ -1,5 +0,0 @@
from .import_utils import is_e2b_available
from .model_utils import get_tokenizer


__all__ = ["get_tokenizer", "is_e2b_available"]
@ -1,86 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import subprocess
|
||||
from typing import List
|
||||
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_callback import TrainerControl, TrainerState
|
||||
from transformers.training_args import TrainingArguments
|
||||
|
||||
from .evaluation import run_benchmark_jobs
|
||||
from .hub import push_to_hub_revision
|
||||
|
||||
|
||||
def is_slurm_available() -> bool:
|
||||
# Returns True if a Slurm queueing system is available.
|
||||
try:
|
||||
subprocess.run(["sinfo"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
def __init__(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
class PushToHubRevisionCallback(TrainerCallback):
|
||||
def __init__(self, model_config) -> None:
|
||||
self.model_config = model_config
|
||||
|
||||
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
global_step = state.global_step
|
||||
|
||||
# WARNING: if you use dataclasses.replace(args, ...) the accelerator dist state will be broken, so I do this workaround
|
||||
# Also if you instantiate a new SFTConfig, the accelerator dist state will be broken
|
||||
dummy_config = DummyConfig(
|
||||
hub_model_id=args.hub_model_id,
|
||||
hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}",
|
||||
output_dir=f"{args.output_dir}/checkpoint-{global_step}",
|
||||
system_prompt=args.system_prompt,
|
||||
)
|
||||
|
||||
future = push_to_hub_revision(
|
||||
dummy_config, extra_ignore_patterns=["*.pt"]
|
||||
) # don't push the optimizer states
|
||||
|
||||
if is_slurm_available():
|
||||
dummy_config.benchmarks = args.benchmarks
|
||||
|
||||
def run_benchmark_callback(_):
|
||||
print(f"Checkpoint {global_step} pushed to hub.")
|
||||
run_benchmark_jobs(dummy_config, self.model_config)
|
||||
|
||||
future.add_done_callback(run_benchmark_callback)
|
||||
|
||||
|
||||
CALLBACKS = {
|
||||
"push_to_hub_revision": PushToHubRevisionCallback,
|
||||
}
|
||||
|
||||
|
||||
def get_callbacks(train_config, model_config) -> List[TrainerCallback]:
|
||||
callbacks = []
|
||||
for callback_name in train_config.callbacks:
|
||||
if callback_name not in CALLBACKS:
|
||||
raise ValueError(f"Callback {callback_name} not found in CALLBACKS.")
|
||||
callbacks.append(CALLBACKS[callback_name](model_config))
|
||||
|
||||
return callbacks
|
@ -1,106 +0,0 @@
|
||||
import subprocess
|
||||
from typing import TYPE_CHECKING, Dict, Union
|
||||
|
||||
from .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trl import GRPOConfig, SFTConfig, ModelConfig
|
||||
|
||||
import os
|
||||
|
||||
|
||||
# We need a special environment setup to launch vLLM from within Slurm training jobs.
|
||||
# - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105
|
||||
# - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269
|
||||
user_home_directory = os.path.expanduser("~")
|
||||
VLLM_SLURM_PREFIX = [
|
||||
"env",
|
||||
"-i",
|
||||
"bash",
|
||||
"-c",
|
||||
f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch ",
|
||||
]
|
||||
|
||||
|
||||
def register_lighteval_task(
|
||||
configs: Dict[str, str], eval_suite: str, task_name: str, task_list: str, num_fewshot: int = 0
|
||||
):
|
||||
"""Registers a LightEval task configuration.
|
||||
|
||||
- Core tasks can be added from this table: https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl
|
||||
- Custom tasks that require their own metrics / scripts, should be stored in scripts/evaluation/extended_lighteval_tasks
|
||||
|
||||
Args:
|
||||
configs (Dict[str, str]): The dictionary to store the task configuration.
|
||||
eval_suite (str, optional): The evaluation suite.
|
||||
task_name (str): The name of the task.
|
||||
task_list (str): The comma-separated list of tasks in the format "extended|{task_name}|{num_fewshot}|0" or "lighteval|{task_name}|{num_fewshot}|0".
|
||||
num_fewshot (int, optional): The number of few-shot examples. Defaults to 0.
|
||||
is_custom_task (bool, optional): Whether the task is a custom task. Defaults to False.
|
||||
"""
|
||||
# Format task list in lighteval format
|
||||
task_list = ",".join(f"{eval_suite}|{task}|{num_fewshot}|0" for task in task_list.split(","))
|
||||
configs[task_name] = task_list
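# Example (illustrative): register_lighteval_task(configs, "custom", "math_500", "math_500", 0)
# stores configs["math_500"] = "custom|math_500|0|0", the task string lighteval expects.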
|
||||
|
||||
|
||||
LIGHTEVAL_TASKS = {}
|
||||
|
||||
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "math_500", "math_500", 0)
|
||||
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime24", "aime24", 0)
|
||||
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime25", "aime25", 0)
|
||||
register_lighteval_task(LIGHTEVAL_TASKS, "custom", "gpqa", "gpqa:diamond", 0)
|
||||
register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb", "lcb:codegeneration", 0)
|
||||
|
||||
|
||||
def get_lighteval_tasks():
|
||||
return list(LIGHTEVAL_TASKS.keys())
|
||||
|
||||
|
||||
SUPPORTED_BENCHMARKS = get_lighteval_tasks()
|
||||
|
||||
|
||||
def run_lighteval_job(
|
||||
benchmark: str, training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig"
|
||||
) -> None:
|
||||
task_list = LIGHTEVAL_TASKS[benchmark]
|
||||
model_name = training_args.hub_model_id
|
||||
model_revision = training_args.hub_model_revision
|
||||
# For large models >= 30b params or those running the MATH benchmark, we need to shard them across the GPUs to avoid OOM
|
||||
num_gpus = get_gpu_count_for_vllm(model_name, model_revision)
|
||||
if get_param_count_from_repo_id(model_name) >= 30_000_000_000:
|
||||
tensor_parallel = True
|
||||
else:
|
||||
tensor_parallel = False
|
||||
|
||||
cmd = VLLM_SLURM_PREFIX.copy()
|
||||
cmd_args = [
|
||||
f"--gres=gpu:{num_gpus}",
|
||||
f"--job-name=or1_{benchmark}_{model_name.split('/')[-1]}_{model_revision}",
|
||||
"slurm/evaluate.slurm",
|
||||
benchmark,
|
||||
f'"{task_list}"',
|
||||
model_name,
|
||||
model_revision,
|
||||
f"{tensor_parallel}",
|
||||
f"{model_args.trust_remote_code}",
|
||||
]
|
||||
if training_args.system_prompt is not None:
|
||||
cmd_args.append(f"--system_prompt={training_args.system_prompt}")
|
||||
cmd[-1] += " " + " ".join(cmd_args)
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
def run_benchmark_jobs(training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig") -> None:
|
||||
benchmarks = training_args.benchmarks
|
||||
if len(benchmarks) == 1 and benchmarks[0] == "all":
|
||||
benchmarks = get_lighteval_tasks()
|
||||
# Evaluate on all supported benchmarks. Later we may want to include a `chat` option
|
||||
# that just evaluates on `ifeval` and `mt_bench` etc.
|
||||
|
||||
for benchmark in benchmarks:
|
||||
print(f"Launching benchmark `{benchmark}`")
|
||||
if benchmark in get_lighteval_tasks():
|
||||
run_lighteval_job(benchmark, training_args, model_args)
|
||||
else:
|
||||
raise ValueError(f"Unknown benchmark {benchmark}")
|
@ -1,131 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
from concurrent.futures import Future
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
from huggingface_hub import (
|
||||
create_branch,
|
||||
create_repo,
|
||||
get_safetensors_metadata,
|
||||
list_repo_commits,
|
||||
list_repo_files,
|
||||
list_repo_refs,
|
||||
repo_exists,
|
||||
upload_folder,
|
||||
)
|
||||
from trl import GRPOConfig, SFTConfig
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def push_to_hub_revision(training_args: SFTConfig | GRPOConfig, extra_ignore_patterns=[]) -> Future:
|
||||
"""Pushes the model to branch on a Hub repo."""
|
||||
|
||||
# Create a repo if it doesn't exist yet
|
||||
repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True)
|
||||
# Get initial commit to branch from
|
||||
initial_commit = list_repo_commits(training_args.hub_model_id)[-1]
|
||||
# Now create the branch we'll be pushing to
|
||||
create_branch(
|
||||
repo_id=training_args.hub_model_id,
|
||||
branch=training_args.hub_model_revision,
|
||||
revision=initial_commit.commit_id,
|
||||
exist_ok=True,
|
||||
)
|
||||
logger.info(f"Created target repo at {repo_url}")
|
||||
logger.info(f"Pushing to the Hub revision {training_args.hub_model_revision}...")
|
||||
ignore_patterns = ["checkpoint-*", "*.pth"]
|
||||
ignore_patterns.extend(extra_ignore_patterns)
|
||||
future = upload_folder(
|
||||
repo_id=training_args.hub_model_id,
|
||||
folder_path=training_args.output_dir,
|
||||
revision=training_args.hub_model_revision,
|
||||
commit_message=f"Add {training_args.hub_model_revision} checkpoint",
|
||||
ignore_patterns=ignore_patterns,
|
||||
run_as_future=True,
|
||||
)
|
||||
logger.info(f"Pushed to {repo_url} revision {training_args.hub_model_revision} successfully!")
|
||||
|
||||
return future
|
||||
|
||||
|
||||
def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig):
|
||||
"""Checks if a given Hub revision exists."""
|
||||
if repo_exists(training_args.hub_model_id):
|
||||
if training_args.push_to_hub_revision is True:
|
||||
# First check if the revision exists
|
||||
revisions = [rev.name for rev in list_repo_refs(training_args.hub_model_id).branches]
|
||||
# If the revision exists, we next check it has a README file
|
||||
if training_args.hub_model_revision in revisions:
|
||||
repo_files = list_repo_files(
|
||||
repo_id=training_args.hub_model_id, revision=training_args.hub_model_revision
|
||||
)
|
||||
if "README.md" in repo_files and training_args.overwrite_hub_revision is False:
|
||||
raise ValueError(
|
||||
f"Revision {training_args.hub_model_revision} already exists. "
|
||||
"Use --overwrite_hub_revision to overwrite it."
|
||||
)
|
||||
|
||||
|
||||
def get_param_count_from_repo_id(repo_id: str) -> int:
|
||||
"""Function to get model param counts from safetensors metadata or find patterns like 42m, 1.5b, 0.5m or products like 8x7b in a repo ID."""
|
||||
try:
|
||||
metadata = get_safetensors_metadata(repo_id)
|
||||
return list(metadata.parameter_count.values())[0]
|
||||
except Exception:
|
||||
# Pattern to match products (like 8x7b) and single values (like 42m)
|
||||
pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])"
|
||||
matches = re.findall(pattern, repo_id.lower())
|
||||
|
||||
param_counts = []
|
||||
for full_match, number1, _, _, number2, _, unit in matches:
|
||||
if number2: # If there's a second number, it's a product
|
||||
number = float(number1) * float(number2)
|
||||
else: # Otherwise, it's a single value
|
||||
number = float(number1)
|
||||
|
||||
if unit == "b":
|
||||
number *= 1_000_000_000 # Convert to billion
|
||||
elif unit == "m":
|
||||
number *= 1_000_000 # Convert to million
|
||||
|
||||
param_counts.append(number)
|
||||
|
||||
if len(param_counts) > 0:
|
||||
# Return the largest number
|
||||
return int(max(param_counts))
|
||||
else:
|
||||
# Return -1 if no match found
|
||||
return -1
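# Example (illustrative): when safetensors metadata is unavailable, the regex fallback parses the
# repo id instead: "qwen2.5-7b-instruct" -> 7_000_000_000, "mixtral-8x7b" -> 8 * 7 -> 56_000_000_000,
# and a repo id with no recognizable size pattern returns -1.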
|
||||
|
||||
|
||||
def get_gpu_count_for_vllm(model_name: str, revision: str = "main", num_gpus: int = 8) -> int:
|
||||
"""vLLM enforces a constraint that the number of attention heads must be divisible by the number of GPUs and 64 must be divisible by the number of GPUs.
|
||||
This function calculates the number of GPUs to use for decoding based on the number of attention heads in the model.
|
||||
"""
|
||||
config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
|
||||
# Get number of attention heads
|
||||
num_heads = config.num_attention_heads
|
||||
# Reduce num_gpus so that num_heads is divisible by num_gpus and 64 is divisible by num_gpus
|
||||
while num_heads % num_gpus != 0 or 64 % num_gpus != 0:
|
||||
logger.info(f"Reducing num_gpus from {num_gpus} to {num_gpus - 1} to make num_heads divisible by num_gpus")
|
||||
num_gpus -= 1
|
||||
return num_gpus
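# Example (illustrative): for a model with 28 attention heads (e.g. Qwen2.5-7B) and num_gpus=8,
# the loop steps 8 -> 7 -> 6 -> 5 -> 4 and returns 4, the largest count <= 8 dividing both 28 and 64.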
|
@ -1,23 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers.utils.import_utils import _is_package_available
|
||||
|
||||
|
||||
# Use same as transformers.utils.import_utils
|
||||
_e2b_available = _is_package_available("e2b")
|
||||
|
||||
|
||||
def is_e2b_available() -> bool:
|
||||
return _e2b_available
|
@ -1,26 +0,0 @@
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
from trl import ModelConfig
|
||||
|
||||
from ..configs import GRPOConfig, SFTConfig
|
||||
|
||||
|
||||
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
model_args: ModelConfig, training_args: SFTConfig | GRPOConfig, auto_set_chat_template: bool = True
|
||||
) -> PreTrainedTokenizer:
|
||||
"""Get the tokenizer for the model."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
)
|
||||
|
||||
if training_args.chat_template is not None:
|
||||
tokenizer.chat_template = training_args.chat_template
|
||||
elif auto_set_chat_template and tokenizer.get_chat_template() is None:
|
||||
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
|
||||
|
||||
return tokenizer
|
@ -1,11 +0,0 @@
|
||||
import os
|
||||
|
||||
|
||||
def init_wandb_training(training_args):
|
||||
"""
|
||||
Helper function for setting up Weights & Biases logging tools.
|
||||
"""
|
||||
if training_args.wandb_entity is not None:
|
||||
os.environ["WANDB_ENTITY"] = training_args.wandb_entity
|
||||
if training_args.wandb_project is not None:
|
||||
os.environ["WANDB_PROJECT"] = training_args.wandb_project
|
@ -1,317 +0,0 @@
|
||||
import unittest
|
||||
|
||||
from open_r1.rewards import (
|
||||
accuracy_reward,
|
||||
format_reward,
|
||||
get_cosine_scaled_reward,
|
||||
get_repetition_penalty_reward,
|
||||
len_reward,
|
||||
reasoning_steps_reward,
|
||||
)
|
||||
|
||||
|
||||
class TestRewards(unittest.TestCase):
|
||||
def test_accuracy_reward_correct_answer(self):
|
||||
"""Test accuracy_reward with a correct answer."""
|
||||
completion = [[{"content": r"\boxed{\frac{63}{400}}"}]]
|
||||
solution = [r"\frac{63}{400}"]
|
||||
|
||||
rewards = accuracy_reward(completion, solution)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
def test_accuracy_reward_wrong_answer(self):
|
||||
"""Test accuracy_reward with an incorrect answer."""
|
||||
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
|
||||
solution = [r"\frac{63}{400}"]
|
||||
|
||||
rewards = accuracy_reward(completion, solution)
|
||||
self.assertEqual(rewards[0], 0.0)
|
||||
|
||||
def test_format_reward_correct(self):
|
||||
"""Test format_reward with correct format."""
|
||||
completion = [[{"content": "<think>Some reasoning</think><answer>The answer</answer>"}]]
|
||||
rewards = format_reward(completion)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
def test_format_reward_incorrect(self):
|
||||
"""Test format_reward with incorrect format."""
|
||||
incorrect_formats = [
|
||||
"<think>Only thinking</think>",
|
||||
"<answer>Only answer</answer>",
|
||||
"No tags at all",
|
||||
"<think>Missing closing</think><answer>Missing closing",
|
||||
"<think>Wrong order</answer><answer>Wrong order</think>",
|
||||
]
|
||||
|
||||
for fmt in incorrect_formats:
|
||||
completion = [[{"content": fmt}]]
|
||||
rewards = format_reward(completion)
|
||||
self.assertEqual(rewards[0], 0.0)
|
||||
|
||||
def test_reasoning_steps_reward(self):
|
||||
"""Test reasoning_steps_reward with various formats."""
|
||||
test_cases = [
|
||||
# Full credit cases (3 or more steps)
|
||||
("Step 1: First step\nStep 2: Second step\nStep 3: Third step", 1.0),
|
||||
("First, we do this.\nSecond, we do that.\nFinally, we conclude.", 1.0),
|
||||
# Partial credit cases (less than 3 steps)
|
||||
("Step 1: Only step", 1 / 3),
|
||||
("First, we do this.\nFinally, we conclude.", 2 / 3),
|
||||
# No credit case
|
||||
("Just plain text without any clear steps", 0.0),
|
||||
]
|
||||
|
||||
for content, expected_reward in test_cases:
|
||||
completion = [[{"content": content}]]
|
||||
rewards = reasoning_steps_reward(completion)
|
||||
self.assertAlmostEqual(rewards[0], expected_reward)
|
||||
|
||||
def test_multiple_completions(self):
|
||||
"""Test handling multiple completions at once."""
|
||||
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
|
||||
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
|
||||
|
||||
rewards = accuracy_reward(completions, solutions)
|
||||
self.assertEqual(len(rewards), 2)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
self.assertEqual(rewards[1], 0.0)
|
||||
|
||||
def test_cosine_scaled_reward(self):
|
||||
"""Test cosine_scaled_reward with various cases."""
|
||||
# Test parameters
|
||||
test_params = {
|
||||
"min_value_wrong": -1.0,
|
||||
"max_value_wrong": -0.5,
|
||||
"min_value_correct": 0.5,
|
||||
"max_value_correct": 1.0,
|
||||
"max_len": 100,
|
||||
}
|
||||
|
||||
test_cases = [
|
||||
# Correct answers with different lengths
|
||||
(r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 20, 0.943), # Short correct answer
|
||||
(r"\boxed{\frac{63}{400}}", r"\frac{63}{400}", 80, 0.547), # Long correct answer
|
||||
# Wrong answers with different lengths
|
||||
(r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 20, -0.942), # Short wrong answer
|
||||
(r"\boxed{\frac{64}{400}}", r"\frac{63}{400}", 80, -0.547), # Long wrong answer
|
||||
]
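# Expected values follow a cosine schedule over completion length: short correct answers
# approach max_value_correct, long correct answers approach min_value_correct, while wrong
# answers move from min_value_wrong (short) toward max_value_wrong (long).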
|
||||
|
||||
for content, solution, content_len, expected_reward in test_cases:
|
||||
# Pad content to desired length
|
||||
padded_content = content + " " * (content_len - len(content))
|
||||
completion = [[{"content": padded_content}]]
|
||||
|
||||
rewards = get_cosine_scaled_reward(**test_params)(completion, [solution])
|
||||
self.assertAlmostEqual(rewards[0], expected_reward, places=2)
|
||||
|
||||
def test_format_reward_specific_multiline(self):
|
||||
"""Test format_reward with a specific multiline input."""
|
||||
inputs = "<think>\nI will count each distinct object in the image:\n1. Purple scooter\n2. Red bicycle\n3. Green motorcycle\n4. Gray sedan\n5. Yellow school bus\n6. Small green double-decker bus\n7. Small red car\n8. Small purple car\n9. Small gray dirt bike\n\nThere are 9 distinct objects in total.\n</think>\n<answer>9</answer>"
|
||||
completion = [[{"content": inputs}]]
|
||||
rewards = format_reward(completion)
|
||||
self.assertEqual(rewards[0], 1.0)
|
||||
|
||||
def test_same_length_responses(self):
|
||||
"""Test len_reward when all responses have the same length."""
|
||||
completions = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{64}{400}}"}]]
|
||||
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
self.assertEqual(rewards, [0.0, 0.0])
|
||||
|
||||
def test_different_lengths_correct_answers(self):
|
||||
"""Test len_reward with different length correct answers."""
|
||||
completions = [
|
||||
[{"content": r"\boxed{\frac{63}{400}}"}], # shorter
|
||||
[{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # longer
|
||||
]
|
||||
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
self.assertGreater(rewards[0], rewards[1]) # shorter answer should get higher reward
|
||||
self.assertAlmostEqual(rewards[0], 0.5) # shortest correct answer gets maximum reward
|
||||
|
||||
def test_different_lengths_incorrect_answers(self):
|
||||
"""Test len_reward with different length incorrect answers."""
|
||||
completions = [
|
||||
[{"content": r"\boxed{\frac{64}{400}}"}], # shorter
|
||||
[{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # longer
|
||||
]
|
||||
solutions = [r"\frac{63}{400}", r"\frac{63}{400}"]
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
self.assertLessEqual(rewards[0], 0.0) # incorrect answers should get non-positive rewards
|
||||
self.assertLessEqual(rewards[1], 0.0)
|
||||
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still be penalized less
|
||||
|
||||
def test_mixed_correctness(self):
|
||||
"""Test len_reward with mix of correct and incorrect answers of different lengths."""
|
||||
completions = [
|
||||
[{"content": r"\boxed{\frac{63}{400}}"}], # correct, shorter
|
||||
[{"content": r"\boxed{\frac{63}{400}} " + "x" * 10}], # correct, longer
|
||||
[{"content": r"\boxed{\frac{64}{400}}"}], # incorrect, shorter
|
||||
[{"content": r"\boxed{\frac{64}{400}} " + "x" * 10}], # incorrect, longer
|
||||
]
|
||||
solutions = [r"\frac{63}{400}"] * 4
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
|
||||
# Shortest correct answer should get positive reward
|
||||
self.assertGreater(rewards[0], 0.0)
|
||||
|
||||
# A longer correct answer might get a negative reward, and can even score below a shorter incorrect one:
|
||||
self.assertGreater(rewards[2], rewards[1])
|
||||
self.assertGreaterEqual(rewards[1], rewards[3])
|
||||
|
||||
# Incorrect answers should get non-positive rewards
|
||||
self.assertLessEqual(rewards[2], 0.0)
|
||||
self.assertLessEqual(rewards[3], 0.0)
|
||||
|
||||
# Shorter answers should get better rewards within their correctness category
|
||||
self.assertGreater(rewards[0], rewards[1]) # correct answers
|
||||
self.assertGreater(rewards[2], rewards[3]) # incorrect answers
|
||||
|
||||
def test_unparseable_solution(self):
|
||||
"""Test len_reward with unparseable solution."""
|
||||
completions = [[{"content": r"\boxed{answer}"}], [{"content": r"\boxed{answer} " + "x" * 10}]]
|
||||
solutions = ["unparseable_latex", "unparseable_latex"]
|
||||
|
||||
rewards = len_reward(completions, solutions)
|
||||
self.assertGreater(rewards[0], rewards[1]) # shorter answer should still get better reward
|
||||
self.assertAlmostEqual(rewards[0], 0.5) # treated as correct, shortest gets maximum reward
|
||||
|
||||
|
||||
class TestRepetitionPenaltyReward(unittest.TestCase):
|
||||
def test_positive_max_penalty_raises_value_error(self):
|
||||
with self.assertRaises(ValueError):
|
||||
get_repetition_penalty_reward(ngram_size=2, max_penalty=1.0)
|
||||
with self.assertRaisesRegex(ValueError, "max_penalty 1.5 should not be positive"):
|
||||
get_repetition_penalty_reward(ngram_size=2, max_penalty=1.5)
|
||||
|
||||
def test_no_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [[{"content": "this is a test sentence"}]]
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_full_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [[{"content": "this this this this this"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# (1 - 1/4) * -1 = -0.75
|
||||
self.assertEqual(rewards, [-0.75])
|
||||
|
||||
def test_partial_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [[{"content": "this is a this is a test"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# Unique 2-grams: (this, is), (is, a), (a, this), (a, test). 4 unique out of 6 total
|
||||
# (1 - 4/6) * -1 = -1/3 = -0.3333...
|
||||
self.assertAlmostEqual(rewards[0], -1 / 3)
|
||||
|
||||
def test_multiple_completions(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
|
||||
completions = [
|
||||
[{"content": "this is a test"}],
|
||||
[{"content": "test test test test"}],
|
||||
]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# Completion 1: (this, is, a), (is, a, test) -> 2 unique / 2 total -> (1 - 2/2) * -0.5 = 0
|
||||
# Completion 2: (test, test, test) -> 1 unique / 2 total -> (1 - 1/2) * -0.5 = -0.25
|
||||
self.assertAlmostEqual(rewards[0], 0.0)
|
||||
self.assertAlmostEqual(rewards[1], -0.25)
|
||||
|
||||
def test_empty_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [[{"content": ""}]]
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_different_ngram_size(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-2.0)
|
||||
completions = [[{"content": "this is a this is a test"}]]
|
||||
|
||||
rewards = reward_fn(completions)
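# 5 trigrams, 4 unique: (1 - 4/5) * -2.0 = -0.4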
|
||||
self.assertAlmostEqual(rewards[0], -0.4)
|
||||
|
||||
def test_mixed_case(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=2, max_penalty=-1.0)
|
||||
completions = [
|
||||
[{"content": "This is A Test"}],
|
||||
[{"content": "this IS a test"}],
|
||||
]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# both completions should produce the same reward, because the text gets lowercased
|
||||
self.assertAlmostEqual(rewards[0], rewards[1])
|
||||
|
||||
def test_one_word_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "word"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_two_word_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "two words"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_three_word_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "three different words"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_three_word_repetition_completion(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "word word word word"}]]
|
||||
|
||||
rewards = reward_fn(completions)
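# 2 trigrams, both (word, word, word): 1 unique out of 2, so (1 - 1/2) * -1 = -0.5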
|
||||
self.assertEqual(rewards, [-0.5])
|
||||
|
||||
def test_four_word_completion_with_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "one two one two"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# ngrams are (one, two, one) and (two, one, two): 2 unique out of 2 total, therefore (1 - 2/2) * -1 = 0.
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_five_word_completion_with_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-0.5)
|
||||
completions = [[{"content": "A B C A B"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
# ngrams are (A, B, C), (B, C, A), (C, A, B): 3 unique out of 3 total, so (1 - 3/3) * -0.5 = 0.
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
def test_six_word_completion_with_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "A B C A B C"}]]
|
||||
|
||||
rewards = reward_fn(completions)
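# 4 trigrams (A,B,C) (B,C,A) (C,A,B) (A,B,C): 3 unique out of 4, so (1 - 3/4) * -1 = -0.25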
|
||||
self.assertEqual(rewards, [-0.25])
|
||||
|
||||
def test_long_completion_with_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "A B C A B C E F G A B C A B C"}]]
|
||||
rewards = reward_fn(completions)
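# 13 trigrams, 8 unique: (1 - 8/13) * -1 ≈ -0.3846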
|
||||
self.assertAlmostEqual(rewards[0], -0.3846, places=4)
|
||||
|
||||
def test_long_completion_without_repetition(self):
|
||||
reward_fn = get_repetition_penalty_reward(ngram_size=3, max_penalty=-1.0)
|
||||
completions = [[{"content": "A B C D E F G H I J K L"}]]
|
||||
|
||||
rewards = reward_fn(completions)
|
||||
self.assertEqual(rewards, [0.0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
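
The tests above pin down the penalty formula: reward = (1 - unique_ngrams / total_ngrams) * max_penalty, with 0.0 for empty or too-short completions. A minimal sketch consistent with that behaviour (not necessarily the project's actual implementation) could look like:

```python
def get_repetition_penalty_reward(ngram_size: int, max_penalty: float):
    # Sketch matching the behaviour exercised by the tests above.
    if max_penalty > 0:
        raise ValueError(f"max_penalty {max_penalty} should not be positive")

    def reward_fn(completions):
        rewards = []
        for completion in completions:
            words = completion[0]["content"].lower().split()
            if len(words) < ngram_size:
                # Too short to form a single n-gram: no penalty.
                rewards.append(0.0)
                continue
            ngrams = list(zip(*[words[i:] for i in range(ngram_size)]))
            scaling = 1 - len(set(ngrams)) / len(ngrams)
            rewards.append(scaling * max_penalty)
        return rewards

    return reward_fn
```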
|
@ -1,17 +0,0 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.3
|
||||
hooks:
|
||||
- id: ruff
|
||||
types_or: [ python, pyi ]
|
||||
args: [ --fix ]
|
||||
- id: ruff-format
|
||||
types_or: [ python, pyi ]
|
||||
|
||||
# - repo: https://github.com/codespell-project/codespell
|
||||
# rev: v2.1.0
|
||||
# hooks:
|
||||
# - id: codespell
|
||||
# args:
|
||||
# - --ignore-words-list=nd,reacher,thist,ths,magent,ba
|
||||
# - --skip=docs/css/termynal.css,docs/js/termynal.js
|
@ -1,34 +0,0 @@
|
||||
cff-version: 1.2.0
|
||||
title: 'TRL: Transformer Reinforcement Learning'
|
||||
message: >-
|
||||
If you use this software, please cite it using the
|
||||
metadata from this file.
|
||||
type: software
|
||||
authors:
|
||||
- given-names: Leandro
|
||||
family-names: von Werra
|
||||
- given-names: Younes
|
||||
family-names: Belkada
|
||||
- given-names: Lewis
|
||||
family-names: Tunstall
|
||||
- given-names: Edward
|
||||
family-names: Beeching
|
||||
- given-names: Tristan
|
||||
family-names: Thrush
|
||||
- given-names: Nathan
|
||||
family-names: Lambert
|
||||
- given-names: Shengyi
|
||||
family-names: Huang
|
||||
- given-names: Kashif
|
||||
family-names: Rasul
|
||||
- given-names: Quentin
|
||||
family-names: Gallouédec
|
||||
repository-code: 'https://github.com/huggingface/trl'
|
||||
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
|
||||
keywords:
|
||||
- rlhf
|
||||
- deep-learning
|
||||
- pytorch
|
||||
- transformers
|
||||
license: Apache-2.0
|
||||
version: 0.15
|
@ -1,133 +0,0 @@
|
||||
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, caste, color, religion, or sexual
|
||||
identity and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the overall
|
||||
community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or advances of
|
||||
any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email address,
|
||||
without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
feedback@huggingface.co.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series of
|
||||
actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or permanent
|
||||
ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within the
|
||||
community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.1, available at
|
||||
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
|
||||
|
||||
Community Impact Guidelines were inspired by
|
||||
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
|
||||
[https://www.contributor-covenant.org/translations][translations].
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
|
||||
[Mozilla CoC]: https://github.com/mozilla/diversity
|
||||
[FAQ]: https://www.contributor-covenant.org/faq
|
||||
[translations]: https://www.contributor-covenant.org/translations
|
@ -1,459 +0,0 @@
|
||||
# How to contribute to TRL?
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
||||
contributions are not the only way to help the community. Answering questions, helping
|
||||
others, and improving the documentation are also immensely valuable.
|
||||
|
||||
It also helps us if you spread the word! Reference the library in blog posts
|
||||
about the awesome projects it made possible, shout out on Twitter every time it has
|
||||
helped you, or simply ⭐️ the repository to say thank you.
|
||||
|
||||
However you choose to contribute, please be mindful and respect our
|
||||
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
|
||||
|
||||
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
|
||||
|
||||
## Ways to contribute
|
||||
|
||||
There are several ways you can contribute to TRL:
|
||||
|
||||
* Fix outstanding issues with the existing code.
|
||||
* Submit issues related to bugs or desired new features.
|
||||
* Implement trainers for new post-training algorithms.
|
||||
* Contribute to the examples or the documentation.
|
||||
|
||||
If you don't know where to start, there is a special [Good First
|
||||
Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of
|
||||
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
|
||||
|
||||
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
|
||||
|
||||
> All contributions are equally valuable to the community. 🥰
|
||||
|
||||
Before you start contributing make sure you have installed all the dev tools:
|
||||
|
||||
```bash
|
||||
pip install -e .[dev]
|
||||
```
|
||||
|
||||
## Fixing outstanding issues
|
||||
|
||||
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
|
||||
|
||||
## Submitting a bug-related issue or feature request
|
||||
|
||||
Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback.
|
||||
|
||||
### Did you find a bug?
|
||||
|
||||
The TRL library is robust and reliable thanks to users who report the problems they encounter.
|
||||
|
||||
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
|
||||
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
|
||||
|
||||
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
|
||||
|
||||
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
|
||||
* A short, self-contained, code snippet that allows us to reproduce the bug in
|
||||
less than 30s.
|
||||
* The *full* traceback if an exception is raised.
|
||||
* Attach any other additional information, like screenshots, you think may help.
|
||||
|
||||
To get the OS and software versions automatically, run the following command:
|
||||
|
||||
```bash
|
||||
trl env
|
||||
```
|
||||
|
||||
### Do you want a new feature?
|
||||
|
||||
If there is a new feature you'd like to see in TRL, please open an issue and describe:
|
||||
|
||||
1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?
|
||||
|
||||
Whatever it is, we'd love to hear about it!
|
||||
|
||||
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
|
||||
3. Provide a *code snippet* that demonstrates the feature's usage.
|
||||
4. If the feature is related to a paper, please include a link.
|
||||
|
||||
If your issue is well written we're already 80% of the way there by the time you create it.
|
||||
|
||||
## Do you want to implement a new trainer?
|
||||
|
||||
New post-training methods are published frequently and those that satisfy the following criteria are good candidates to be integrated into TRL:
|
||||
|
||||
* **Simplicity:** Does the new method achieve similar performance as prior methods, but with less complexity? A good example is Direct Preference Optimization (DPO) [[Rafailov et al, 2023]](https://huggingface.co/papers/2305.18290), which provided a simpler and compelling alternative to RLHF methods.
|
||||
* **Efficiency:** Does the new method provide a significant improvement in training efficiency? A good example is Odds Ratio Preference Optimization (ORPO) [[Hong et al, 2024]](https://huggingface.co/papers/2403.07691), which utilizes a similar objective as DPO but requires half the GPU VRAM.
|
||||
|
||||
Methods that only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
|
||||
|
||||
If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:
|
||||
|
||||
* A short description of the method and a link to the paper.
|
||||
* Link to the implementation if it is open-sourced.
|
||||
* Link to model weights trained with the method if they are available.
|
||||
|
||||
Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration:
|
||||
|
||||
* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py)
|
||||
* RL-based optimisation: [`rloo_trainer.py`](./trl/trainer/rloo_trainer.py) and [`rloo_config.py`](./trl/trainer/rloo_config.py)
|
||||
* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py)
|
||||
|
||||
## Do you want to add documentation?
|
||||
|
||||
We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links, and any missing, unclear, or inaccurate content... We'll be happy to make the changes or help you contribute if you're interested!
|
||||
|
||||
## Submitting a pull request (PR)
|
||||
|
||||
Before writing code, we strongly advise you to search through the existing PRs or
|
||||
issues to make sure that nobody is already working on the same thing. If you are
|
||||
unsure, it is always a good idea to open an issue to get some feedback.
|
||||
|
||||
You will need basic `git` proficiency to be able to contribute to
|
||||
TRL. `git` is not the easiest tool to use but it has the greatest
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing:
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/trl) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
||||
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
||||
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
|
||||
```bash
|
||||
$ git clone git@github.com:<your Github handle>/trl.git
|
||||
$ cd trl
|
||||
$ git remote add upstream https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
||||
|
||||
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
||||
|
||||
```bash
|
||||
$ git checkout main
|
||||
$ git fetch upstream
|
||||
$ git merge upstream/main
|
||||
```
|
||||
|
||||
Once your `main` branch is synchronized, create a new branch from it:
|
||||
|
||||
```bash
|
||||
$ git checkout -b a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
**Do not** work on the `main` branch.
|
||||
|
||||
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
|
||||
|
||||
```bash
|
||||
$ pip install -e .[dev]
|
||||
```
|
||||
|
||||
(If TRL was already installed in the virtual environment, remove
|
||||
it with `pip uninstall trl` before reinstalling it.)
|
||||
|
||||
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
|
||||
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
passes. You should run the tests impacted by your changes like this (see
|
||||
below an explanation regarding the environment variable):
|
||||
|
||||
```bash
|
||||
$ pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
> For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
|
||||
> Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).
|
||||
|
||||
You can also run the full suite with the following command.
|
||||
|
||||
```bash
|
||||
$ make test
|
||||
```
|
||||
|
||||
TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
|
||||
|
||||
We provide a `precommit` target in the `Makefile` that simplifies this process by running all required checks and optimizations on only the files modified by your PR.
|
||||
|
||||
To apply these checks and corrections in one step, use:
|
||||
|
||||
```bash
|
||||
$ make precommit
|
||||
```
|
||||
|
||||
This command runs the following:
|
||||
- Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
|
||||
- Runs additional scripts such as adding copyright information.
|
||||
|
||||
If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
|
||||
|
||||
Once you're happy with your changes, add changed files using `git add` and
|
||||
make a commit with `git commit` to record your changes locally:
|
||||
|
||||
```bash
|
||||
$ git add modified_file.py
|
||||
$ git commit
|
||||
```
|
||||
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
|
||||
```bash
|
||||
$ git fetch upstream
|
||||
$ git rebase upstream/main
|
||||
```
|
||||
|
||||
Push the changes to your account using:
|
||||
|
||||
```bash
|
||||
$ git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
6. Once you are satisfied (**and the checklist below is happy too**), go to the
|
||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
||||
to the project maintainers for review.
|
||||
|
||||
7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
|
||||
|
||||
|
||||
### Checklist
|
||||
|
||||
1. The title of your pull request should be a summary of its contribution;
|
||||
2. If your pull request addresses an issue, please mention the issue number in
|
||||
the pull request description to make sure they are linked (and people
|
||||
consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
|
||||
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
||||
it from PRs ready to be merged;
|
||||
4. Make sure existing tests pass;
|
||||
5. Add high-coverage tests. No quality testing = no merge.
|
||||
|
||||
|
||||
### Tests
|
||||
|
||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
|
||||
the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
|
||||
|
||||
We use `pytest` to run the tests. From the root of the
|
||||
repository here's how to run tests with `pytest` for the library:
|
||||
|
||||
```bash
|
||||
$ python -m pytest -sv ./tests
|
||||
```
|
||||
|
||||
That's how `make test` is implemented (without the `pip install` line)!
|
||||
|
||||
You can specify a smaller set of tests to test only the feature
|
||||
you're working on.
|
||||
|
||||
### Default values guidelines
|
||||
|
||||
1. **Use defaults when appropriate**:
|
||||
|
||||
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
|
||||
|
||||
2. **Prioritize proven defaults**:
|
||||
|
||||
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
|
||||
|
||||
3. **Ensure safety and predictability**:
|
||||
|
||||
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
|
||||
|
||||
4. **Balance consistency and flexibility**:
|
||||
|
||||
Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3.
|
||||
|
||||
5. **Opt-in for new features**:
|
||||
|
||||
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
|
||||
|
||||
### Writing documentation
|
||||
|
||||
High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.
|
||||
|
||||
To illustrate what good documentation looks like, here’s an example of a well-documented function:
|
||||
|
||||
````python
|
||||
def replicate_str(string: str, n: int, sep: str = " ") -> str:
|
||||
r"""
|
||||
Replicate a string `n` times with a separator.
|
||||
|
||||
Args:
|
||||
string (`str`):
|
||||
String to replicate.
|
||||
n (`int`):
|
||||
Number of times to replicate the string.
|
||||
sep (`str`, *optional*, defaults to `" "`):
|
||||
Separator to use between each replication.
|
||||
|
||||
Returns:
|
||||
`str`: The replicated string.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
>>> replicate_str("hello", 3)
|
||||
"hello hello hello"
|
||||
>>> replicate_str("hello", 3, sep=", ")
|
||||
"hello, hello, hello"
|
||||
```
|
||||
"""
|
||||
return sep.join([string] * n)
|
||||
````
|
||||
|
||||
* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability.
|
||||
* **Definite Articles:** Removed definite articles where possible to streamline language. (E.g., changed "The string to replicate" to "String to replicate")
|
||||
* **Type Annotations:**
|
||||
* Always include type definitions, indicating if a parameter is optional and specifying the default value.
|
||||
* Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
|
||||
E.g., for arguments that can't be `None` and aren't required:
|
||||
|
||||
```python
|
||||
foo (`int`, *optional*, defaults to `4`):
|
||||
```
|
||||
|
||||
For arguments that can be `None` and are required:
|
||||
|
||||
```python
|
||||
foo (`Optional[int]`):
|
||||
```
|
||||
|
||||
For arguments that can be `None` and aren't required:
|
||||
|
||||
```python
|
||||
foo (`Optional[int]`, *optional*, defaults to `None`):
|
||||
```
|
||||
|
||||
* **String Defaults:**
|
||||
* Ensured that default string values are wrapped in double quotes:
|
||||
|
||||
```python
|
||||
defaults to `"foo"`
|
||||
```
|
||||
|
||||
* **Dictionary Typing:**
|
||||
* Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs.
|
||||
* **Default Value Formatting:**
|
||||
* Consistently surrounded default values with backticks for improved formatting:
|
||||
|
||||
```python
|
||||
defaults to `4`
|
||||
```
|
||||
|
||||
* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.
|
||||
|
||||
```python
|
||||
def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]:
|
||||
r"""
|
||||
Calculates basic statistics for a given dataset.
|
||||
|
||||
Args:
|
||||
> Data inputs
|
||||
|
||||
data (`list[float]`):
|
||||
A list of numerical values to analyze.
|
||||
|
||||
> Configuration parameters
|
||||
|
||||
precision (`int`, *optional*, defaults to `2`):
|
||||
Number of decimal places to round the results.
|
||||
include_variance (`bool`, *optional*, defaults to `False`):
|
||||
Whether to include the variance of the dataset in the results.
|
||||
|
||||
Returns:
|
||||
`dict[str, float]`:
|
||||
A dictionary containing calculated statistics such as mean, median, and optionally variance.
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
### Deprecation and backward compatibility
|
||||
|
||||
Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
|
||||
|
||||
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
|
||||
|
||||
- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
|
||||
- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
warnings.warn(
|
||||
"The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
|
||||
"Please use the `Trainer.bar` class instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
```
|
||||
|
||||
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
|
||||
|
||||
- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
|
||||
|
||||
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
|
||||
|
||||
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
|
||||
|
||||
### Working with warnings
|
||||
|
||||
Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.
|
||||
|
||||
#### Definitions
|
||||
|
||||
- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
|
||||
- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
|
||||
|
||||
#### Choosing the right message
|
||||
|
||||
- **Correct → No warning**:
|
||||
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
|
||||
|
||||
- **Correct but deserves attention → No warning, possibly a log message**:
|
||||
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
|
||||
|
||||
```python
|
||||
logger.info("This is an informational message about a rare but correct operation.")
|
||||
```
|
||||
|
||||
- **Correct but very likely a mistake → Warning with option to disable**:
|
||||
In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
|
||||
|
||||
```python
|
||||
def my_function(foo, bar, _warn=True):
|
||||
if foo == bar:
|
||||
if _warn:
|
||||
warnings.warn("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
|
||||
# Do something
|
||||
```
|
||||
|
||||
- **Supported but not correct → Warning**:
|
||||
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
|
||||
|
||||
```python
|
||||
def my_function(foo, bar):
|
||||
if foo and bar:
|
||||
warnings.warn("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
|
||||
# Do something
|
||||
```
|
||||
|
||||
- **Not supported → Exception**:
|
||||
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
|
||||
|
||||
```python
|
||||
def my_function(foo, bar):
|
||||
if foo and bar:
|
||||
raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
|
||||
```
|
||||
|
||||
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
|
@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
@ -1,6 +0,0 @@
include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
include trl/templates/*.md
@ -1,32 +0,0 @@
.PHONY: test precommit common_tests slow_tests test_examples tests_gpu

check_dirs := examples tests trl

ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands

test:
	python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' ./tests/

precommit:
	pre-commit run --all-files
	python scripts/add_copyrights.py

tests_gpu:
	python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)

slow_tests:
	python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)

test_examples:
	touch temp_results_sft_tests.txt
	for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
		TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
		echo $$?','$${file} >> temp_results_sft_tests.txt; \
	done

	touch temp_results_dpo_tests.txt
	for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
		TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
		echo $$?','$${file} >> temp_results_dpo_tests.txt; \
	done
@ -1,206 +0,0 @@
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png" alt="TRL Banner">
|
||||
</div>
|
||||
|
||||
<hr> <br>
|
||||
|
||||
<h3 align="center">
|
||||
<p>A comprehensive library to post-train foundation models</p>
|
||||
</h3>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
|
||||
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
|
||||
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
|
||||
</p>
|
||||
|
||||
## Overview
|
||||
|
||||
TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups.
|
||||
|
||||
## Highlights
|
||||
|
||||
- **Efficient and scalable**:
|
||||
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
|
||||
- Full integration with [`PEFT`](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
|
||||
- Integrates [Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
|
||||
|
||||
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune and interact with models without needing to write code.
|
||||
|
||||
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`ORPOTrainer`](https://huggingface.co/docs/trl/orpo_trainer) and more.
|
||||
|
||||
- **AutoModels**: Use pre-defined model classes like [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) to simplify reinforcement learning (RL) with LLMs.
|
||||
|
||||
## Installation
|
||||
|
||||
### Python Package
|
||||
|
||||
Install the library using `pip`:
|
||||
|
||||
```bash
|
||||
pip install trl
|
||||
```
|
||||
|
||||
### From source
|
||||
|
||||
If you want to use the latest features before an official release, you can install TRL from source:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
### Repository
|
||||
|
||||
If you want to use the examples you can clone the repository with the following command:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
## Command Line Interface (CLI)
|
||||
|
||||
You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:
|
||||
|
||||
**SFT:**
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/Capybara \
|
||||
--output_dir Qwen2.5-0.5B-SFT
|
||||
```
|
||||
|
||||
**DPO:**
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name argilla/Capybara-Preferences \
|
||||
--output_dir Qwen2.5-0.5B-DPO
|
||||
```
|
||||
|
||||
**Chat:**
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
|
||||
```
|
||||
|
||||
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
|
||||
|
||||
## How to use
|
||||
|
||||
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
|
||||
|
||||
### `SFTTrainer`
|
||||
|
||||
Here is a basic example of how to use the `SFTTrainer`:
|
||||
|
||||
```python
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `RewardTrainer`
|
||||
|
||||
Here is a basic example of how to use the `RewardTrainer`:
|
||||
|
||||
```python
|
||||
from trl import RewardConfig, RewardTrainer
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
|
||||
)
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
|
||||
trainer = RewardTrainer(
|
||||
args=training_args,
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `GRPOTrainer`
|
||||
|
||||
`GRPOTrainer` implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: rewards completions that are close to 20 characters
|
||||
def reward_len(completions, **kwargs):
|
||||
return [-abs(20 - len(completion)) for completion in completions]
|
||||
|
||||
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_len,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `DPOTrainer`
|
||||
|
||||
`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
If you want to contribute to `trl` or customize it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
pip install -e .[dev]
|
||||
```
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{vonwerra2022trl,
|
||||
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
|
||||
title = {TRL: Transformer Reinforcement Learning},
|
||||
year = {2020},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/huggingface/trl}}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
This repository's source code is available under the [Apache-2.0 License](LICENSE).
|
@ -1,58 +0,0 @@
|
||||
#!/bin/bash
|
||||
# This script runs a DPO example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_dpo/"
|
||||
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# Set your number of GPUs here
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/trl/scripts/dpo.py \
|
||||
--model_name_or_path $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
@ -1,59 +0,0 @@
|
||||
#!/bin/bash
|
||||
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_sft/"
|
||||
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
DATASET_NAME="stanfordnlp/imdb"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# Set your number of GPUs here
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/trl/scripts/sft.py \
|
||||
--model_name $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
@ -1,66 +0,0 @@
|
||||
# Builds GPU docker image of PyTorch
|
||||
# Uses multi-staged approach to reduce size
|
||||
# Stage 1
|
||||
# Use base conda image to reduce time
|
||||
FROM continuumio/miniconda3:latest AS compile-image
|
||||
# Specify py version
|
||||
ENV PYTHON_VERSION=3.10
|
||||
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget software-properties-common git-lfs && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Install audio-related libraries
|
||||
RUN apt-get update && \
|
||||
apt install -y ffmpeg
|
||||
|
||||
RUN apt install -y libsndfile1-dev
|
||||
RUN git lfs install
|
||||
|
||||
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
# We don't install pytorch here yet since CUDA isn't available
|
||||
# instead we use the direct torch wheel
|
||||
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
||||
# Activate our bash shell
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# Stage 2
|
||||
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
||||
COPY --from=compile-image /opt/conda /opt/conda
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
||||
|
||||
# Install apt libs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Activate the conda env and install transformers + accelerate from source
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install -U --no-cache-dir \
|
||||
librosa \
|
||||
"soundfile>=0.12.1" \
|
||||
scipy \
|
||||
transformers \
|
||||
accelerate \
|
||||
peft \
|
||||
trl[test]@git+https://github.com/huggingface/trl
|
||||
|
||||
RUN source activate trl && \
|
||||
pip freeze | grep trl
|
||||
|
||||
RUN echo "source activate trl" >> ~/.profile
|
||||
|
||||
# Activate the virtualenv
|
||||
CMD ["/bin/bash"]
|
@ -1,66 +0,0 @@
|
||||
# Builds GPU docker image of PyTorch
|
||||
# Uses multi-staged approach to reduce size
|
||||
# Stage 1
|
||||
# Use base conda image to reduce time
|
||||
FROM continuumio/miniconda3:latest AS compile-image
|
||||
# Specify py version
|
||||
ENV PYTHON_VERSION=3.10
|
||||
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget software-properties-common git-lfs && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Install audio-related libraries
|
||||
RUN apt-get update && \
|
||||
apt install -y ffmpeg
|
||||
|
||||
RUN apt install -y libsndfile1-dev
|
||||
RUN git lfs install
|
||||
|
||||
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
# We don't install pytorch here yet since CUDA isn't available
|
||||
# instead we use the direct torch wheel
|
||||
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
||||
# Activate our bash shell
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# Stage 2
|
||||
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
||||
COPY --from=compile-image /opt/conda /opt/conda
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
||||
|
||||
# Install apt libs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Activate the conda env and install transformers + accelerate from source
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install -U --no-cache-dir \
|
||||
librosa \
|
||||
"soundfile>=0.12.1" \
|
||||
scipy \
|
||||
git+https://github.com/huggingface/transformers \
|
||||
git+https://github.com/huggingface/accelerate \
|
||||
git+https://github.com/huggingface/peft \
|
||||
trl[test]@git+https://github.com/huggingface/trl
|
||||
|
||||
RUN source activate trl && \
|
||||
pip freeze | grep transformers
|
||||
|
||||
RUN echo "source activate trl" >> ~/.profile
|
||||
|
||||
# Activate the virtualenv
|
||||
CMD ["/bin/bash"]
|
@ -1,108 +0,0 @@
|
||||
- sections:
|
||||
- local: index
|
||||
title: TRL
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: quickstart
|
||||
title: Quickstart
|
||||
title: Getting started
|
||||
- sections:
|
||||
- local: dataset_formats
|
||||
title: Dataset Formats
|
||||
- local: how_to_train
|
||||
title: Training FAQ
|
||||
- local: logging
|
||||
title: Understanding Logs
|
||||
title: Conceptual Guides
|
||||
- sections:
|
||||
- local: clis
|
||||
title: Command Line Interface (CLI)
|
||||
- local: customization
|
||||
title: Customizing the Training
|
||||
- local: reducing_memory_usage
|
||||
title: Reducing Memory Usage
|
||||
- local: speeding_up_training
|
||||
title: Speeding Up Training
|
||||
- local: use_model
|
||||
title: Using Trained Models
|
||||
title: How-to guides
|
||||
- sections:
|
||||
- local: deepspeed_integration
|
||||
title: DeepSpeed
|
||||
- local: liger_kernel_integration
|
||||
title: Liger Kernel
|
||||
- local: peft_integration
|
||||
title: PEFT
|
||||
- local: unsloth_integration
|
||||
title: Unsloth
|
||||
title: Integrations
|
||||
- sections:
|
||||
- local: example_overview
|
||||
title: Example Overview
|
||||
- local: community_tutorials
|
||||
title: Community Tutorials
|
||||
- local: sentiment_tuning
|
||||
title: Sentiment Tuning
|
||||
- local: using_llama_models
|
||||
title: Training StackLlama
|
||||
- local: detoxifying_a_lm
|
||||
title: Detoxifying a Language Model
|
||||
- local: learning_tools
|
||||
title: Learning to Use Tools
|
||||
- local: multi_adapter_rl
|
||||
title: Multi Adapter RLHF
|
||||
title: Examples
|
||||
- sections:
|
||||
- sections: # Sorted alphabetically
|
||||
- local: alignprop_trainer
|
||||
title: AlignProp
|
||||
- local: bco_trainer
|
||||
title: BCO
|
||||
- local: cpo_trainer
|
||||
title: CPO
|
||||
- local: ddpo_trainer
|
||||
title: DDPO
|
||||
- local: dpo_trainer
|
||||
title: DPO
|
||||
- local: online_dpo_trainer
|
||||
title: Online DPO
|
||||
- local: gkd_trainer
|
||||
title: GKD
|
||||
- local: grpo_trainer
|
||||
title: GRPO
|
||||
- local: kto_trainer
|
||||
title: KTO
|
||||
- local: nash_md_trainer
|
||||
title: Nash-MD
|
||||
- local: orpo_trainer
|
||||
title: ORPO
|
||||
- local: ppo_trainer
|
||||
title: PPO
|
||||
- local: prm_trainer
|
||||
title: PRM
|
||||
- local: reward_trainer
|
||||
title: Reward
|
||||
- local: rloo_trainer
|
||||
title: RLOO
|
||||
- local: sft_trainer
|
||||
title: SFT
|
||||
- local: iterative_sft_trainer
|
||||
title: Iterative SFT
|
||||
- local: xpo_trainer
|
||||
title: XPO
|
||||
title: Trainers
|
||||
- local: models
|
||||
title: Model Classes
|
||||
- local: best_of_n
|
||||
title: Best of N Sampling
|
||||
- local: judges
|
||||
title: Judges
|
||||
- local: callbacks
|
||||
title: Callbacks
|
||||
- local: data_utils
|
||||
title: Data Utilities
|
||||
- local: text_environments
|
||||
title: Text Environments
|
||||
- local: script_utils
|
||||
title: Script Utilities
|
||||
title: API
|
@ -1,93 +0,0 @@
|
||||
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
||||
|
||||
[](https://huggingface.co/models?other=alignprop,trl)
|
||||
|
||||
## The why
|
||||
|
||||
If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than doing policy gradient algorithm like DDPO.
|
||||
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
|
||||
|
||||
<div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/reward_tuning.png"/></div>
|
||||
|
||||
|
||||
## Getting started with `examples/scripts/alignprop.py`
|
||||
|
||||
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
|
||||
|
||||
**Note:** one A100 GPU is recommended to get this running. For a lower-memory setting, consider setting `truncated_backprop_rand` to `False`; with the default settings this will do truncated backpropagation with K=1.
|
||||
|
||||
Almost every configuration parameter has a default. Only one command-line flag is required to get things up and running: a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens), which is used to upload the model to the Hugging Face Hub after fine-tuning. Run the following command to get started:
|
||||
|
||||
```bash
|
||||
python alignprop.py --hf_user_access_token <token>
|
||||
```
|
||||
|
||||
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
|
||||
|
||||
The following are things to keep in mind in general while configuring the trainer, beyond the use case of the example script (the code checks these for you as well):
|
||||
|
||||
- For the configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`), the first number should be greater than or equal to 0, while the second number should be less than or equal to the number of diffusion timesteps (`sample_num_steps`)
|
||||
- For the configurable absolute truncation backprop step (`--alignprop_config.truncated_backprop_timestep=49`), the number should be less than the number of diffusion timesteps (`sample_num_steps`); it only matters when `truncated_backprop_rand` is set to `False` (see the sketch below)
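
As a rough illustration, here is a minimal sketch of how these two constraints might be expressed on the config object. The field names (`sample_num_steps`, `truncated_backprop_rand`, `truncated_rand_backprop_minmax`, `truncated_backprop_timestep`) are assumed to mirror the flags quoted above and the values are only examples, so double-check them against your TRL version:

```python
from trl import AlignPropConfig

config = AlignPropConfig(
    sample_num_steps=50,                     # number of diffusion timesteps
    truncated_backprop_rand=True,            # randomized truncation (the default behaviour)
    truncated_rand_backprop_minmax=(0, 50),  # must satisfy 0 <= min and max <= sample_num_steps
    # truncated_backprop_timestep=49,        # only used when truncated_backprop_rand=False; must be < sample_num_steps
)
```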
|
||||
|
||||
## Setting up the image logging hook function
|
||||
|
||||
Expect the function to be given a dictionary with keys
|
||||
```python
|
||||
['image', 'prompt', 'prompt_metadata', 'rewards']
|
||||
|
||||
```
|
||||
and `image`, `prompt`, `prompt_metadata`, `rewards` are batched.
|
||||
You are free to log however you want; the use of `wandb` or `tensorboard` is recommended.
|
||||
|
||||
### Key terms
|
||||
|
||||
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
|
||||
- `prompt` : The prompt is the text that is used to generate the image
|
||||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
||||
- `image` : The image generated by the Stable Diffusion model
|
||||
|
||||
Example code for logging sampled images with `wandb` is given below.
|
||||
|
||||
```python
|
||||
# for logging these images to wandb
|
||||
|
||||
def image_outputs_hook(image_data, global_step, accelerate_logger):
|
||||
# For the sake of this example, we only care about the last batch
|
||||
# hence we extract the last element of the list
|
||||
result = {}
|
||||
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
|
||||
for i, image in enumerate(images):
|
||||
pil = Image.fromarray(
|
||||
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
)
|
||||
pil = pil.resize((256, 256))
|
||||
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Using the finetuned model
|
||||
|
||||
Assuming you're done with all the epochs and have pushed your model to the Hub, you can use the finetuned model as follows
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipeline.to("cuda")
|
||||
|
||||
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
|
||||
|
||||
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
||||
results = pipeline(prompts)
|
||||
|
||||
for prompt, image in zip(prompts,results.images):
|
||||
image.save(f"dump/{prompt}.png")
|
||||
```
|
||||
|
||||
## Credits
|
||||
|
||||
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
||||
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).
|
@ -1,100 +0,0 @@
|
||||
# BCO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=bco,trl)
|
||||
|
||||
TRL supports the Binary Classifier Optimization (BCO).
|
||||
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
|
||||
For a full example have a look at [`examples/scripts/bco.py`].
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
|
||||
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
## Expected model format
|
||||
The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `BCOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
|
||||
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
|
||||
|
||||
```py
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
bco_trainer = BCOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
bco_trainer.train()
|
||||
```
|
||||
|
||||
## Underlying Distribution matching (UDM)
|
||||
|
||||
In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts.
|
||||
Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts.
|
||||
If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM.
|
||||
|
||||
Choose an embedding model and tokenizer:
|
||||
|
||||
```py
|
||||
embedding_model = AutoModel.from_pretrained(your_model_id)
|
||||
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)
|
||||
|
||||
# customize this function depending on your embedding model
|
||||
def embed_prompt(input_ids, attention_mask, model):
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
embedding_model = Accelerator().prepare_model(self.embedding_model)
|
||||
embedding_func = partial(embed_prompt, model=embedding_model)
|
||||
```
|
||||
|
||||
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
|
||||
|
||||
```py
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
prompt_sample_size=512,
|
||||
)
|
||||
|
||||
bco_trainer = BCOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
processing_class=tokenizer,
|
||||
embedding_func=embedding_func,
|
||||
embedding_tokenizer=self.embedding_tokenizer,
|
||||
)
|
||||
|
||||
bco_trainer.train()
|
||||
```
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss

MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.

This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
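As a rough sketch of how this looks when loading an MoE policy for the [`BCOTrainer`], both options can be passed straight to `from_pretrained`, which forwards them to the model config (the checkpoint name and coefficient value below are only illustrative):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    output_router_logits=True,   # include the load-balancing auxiliary loss in the model output
    router_aux_loss_coef=0.001,  # weight of the auxiliary loss in the total loss
)
```
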
## BCOTrainer
|
||||
|
||||
[[autodoc]] BCOTrainer
|
||||
|
||||
## BCOConfig
|
||||
|
||||
[[autodoc]] BCOConfig
|
@ -1,72 +0,0 @@
|
||||
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
|
||||
|
||||
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
|
||||
As to how it fares against RL-based fine-tuning, please look in the `examples` directory for a comparison example.
|
||||
|
||||
## Usage
|
||||
|
||||
To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries
|
||||
|
||||
```python
|
||||
|
||||
from transformers import pipeline, AutoTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from trl.core import LengthSampler
|
||||
from trl.extras import BestOfNSampler
|
||||
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
|
||||
reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
# callable that takes a list of raw text and returns a list of corresponding reward scores
|
||||
def queries_to_scores(list_of_strings):
|
||||
return [output["score"] for output in reward_pipe(list_of_strings)]
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler)
|
||||
|
||||
|
||||
```
|
||||
|
||||
And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method
|
||||
|
||||
```python
|
||||
|
||||
best_of_n.generate(query_tensors, device=device, **gen_kwargs)
|
||||
|
||||
```
|
||||
The default sample size is 4, but you can change it at the time of instance initialization like so
|
||||
|
||||
```python
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8)
|
||||
|
||||
```
|
||||
|
||||
The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization
|
||||
|
||||
```python
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2)
|
||||
|
||||
```
|
||||
|
||||
There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
|
||||
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
|
||||
|
||||
```python
|
||||
|
||||
from transformers import GenerationConfig
|
||||
|
||||
generation_config = GenerationConfig(min_length=-1, top_k=0, top_p=1.0, do_sample=True, pad_token_id=tokenizer.eos_token_id)
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config)
|
||||
|
||||
best_of_n.generate(query_tensors, device=device)
|
||||
|
||||
```
|
||||
|
||||
Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query
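
For example, a minimal sketch that fixes both at initialization (the `seed` and `sample_size` keyword names are assumed from the description above, so verify them against your TRL version; `model`, `tokenizer`, `queries_to_scores` and `output_length_sampler` are defined as in the first example on this page):

```python
best_of_n = BestOfNSampler(
    model,
    tokenizer,
    queries_to_scores,
    length_sampler=output_length_sampler,
    sample_size=8,  # number of candidate generations per query
    seed=42,        # makes the sampling repeatable
)
```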
|
||||
|
||||
|
@ -1,21 +0,0 @@
# Callbacks

## SyncRefModelCallback

[[autodoc]] SyncRefModelCallback

## RichProgressCallback

[[autodoc]] RichProgressCallback

## WinRateCallback

[[autodoc]] WinRateCallback

## LogCompletionsCallback

[[autodoc]] LogCompletionsCallback

## MergeModelCallback

[[autodoc]] MergeModelCallback

@ -1,176 +0,0 @@
|
||||
# Command Line Interfaces (CLIs)
|
||||
|
||||
You can use TRL to fine-tune your language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO), or even chat with your model, using the TRL CLIs.
|
||||
|
||||
Currently supported CLIs are:
|
||||
|
||||
#### Training commands
|
||||
|
||||
- `trl dpo`: fine-tune a LLM with DPO
|
||||
- `trl grpo`: fine-tune a LLM with GRPO
|
||||
- `trl kto`: fine-tune a LLM with KTO
|
||||
- `trl sft`: fine-tune a LLM with SFT
|
||||
|
||||
#### Other commands
|
||||
|
||||
- `trl chat`: quickly spin up an LLM fine-tuned for chatting
|
||||
- `trl env`: get the system information
|
||||
|
||||
## Fine-tuning with the CLI
|
||||
|
||||
Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
|
||||
|
||||
Before using the `sft` or `dpo` commands make sure to run:
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
|
||||
|
||||
We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command.
|
||||
|
||||
```yaml
|
||||
model_name_or_path:
|
||||
Qwen/Qwen2.5-0.5B
|
||||
dataset_name:
|
||||
stanfordnlp/imdb
|
||||
report_to:
|
||||
none
|
||||
learning_rate:
|
||||
0.0001
|
||||
lr_scheduler_type:
|
||||
cosine
|
||||
```
|
||||
|
||||
Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
|
||||
|
||||
```bash
|
||||
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
|
||||
```
|
||||
|
||||
This will force the use of `cosine_with_restarts` for `lr_scheduler_type`.
|
||||
|
||||
### Supported Arguments
|
||||
|
||||
We support all arguments from `transformers.TrainingArguments`; for loading your model, we support all arguments from `~trl.ModelConfig`:
|
||||
|
||||
[[autodoc]] ModelConfig
|
||||
|
||||
You can pass any of these arguments either to the CLI or the YAML file.
|
||||
|
||||
### Supervised Fine-tuning (SFT)
|
||||
|
||||
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
|
||||
```
|
||||
|
||||
The SFT CLI is based on the `trl/scripts/sft.py` script.
|
||||
|
||||
### Direct Preference Optimization (DPO)
|
||||
|
||||
To use the DPO CLI, you need to have a dataset in the TRL format such as
|
||||
|
||||
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
|
||||
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
|
||||
|
||||
These datasets always have at least three columns `prompt, chosen, rejected`:
|
||||
|
||||
* `prompt` is a list of strings.
|
||||
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
* `rejected` is the rejected response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
|
||||
|
||||
To do a quick start, you can run the following command:
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
|
||||
```
|
||||
|
||||
|
||||
The DPO CLI is based on the `trl/scripts/dpo.py` script.
|
||||
|
||||
|
||||
#### Custom preference dataset
|
||||
|
||||
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
|
||||
|
||||
```bash
|
||||
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
|
||||
```
|
||||
|
||||
## Chat interface
|
||||
|
||||
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><Qwen/Qwen1.5-0.5B-Chat>:</span></strong>
|
||||
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use
|
||||
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,
|
||||
and scalability. Ultimately, it depends on personal preference, needs, and goals.
|
||||
</code></pre>
|
||||
|
||||
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
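
If you are not sure whether a checkpoint ships a chat template, a quick sanity check along these lines can help (a minimal sketch; the model name is just the one used in the example above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
# The chat CLI formats each turn with the tokenizer's chat template, so it must be defined.
print(tokenizer.chat_template is not None)
```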
|
||||
|
||||
Besides talking to the model there are a few commands you can use:
|
||||
|
||||
- `clear`: clears the current conversation and starts a new one
|
||||
- `example {NAME}`: load example named `{NAME}` from the config and use it as the user input
|
||||
- `set {SETTING_NAME}={SETTING_VALUE};`: change the system prompt or generation settings (multiple settings are separated by a `;`).
|
||||
- `reset`: same as clear but also resets the generation configs to defaults if they have been changed by `set`
|
||||
- `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||
- `exit`: closes the interface
|
||||
|
||||
## Getting the system information
|
||||
|
||||
You can get the system information by running the following command:
|
||||
|
||||
```bash
|
||||
trl env
|
||||
```
|
||||
|
||||
This will print out the system information including the GPU information, the CUDA version, the PyTorch version, the transformers version, and the TRL version, and any optional dependencies that are installed.
|
||||
|
||||
```txt
|
||||
Copy-paste the following information when reporting an issue:
|
||||
|
||||
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
|
||||
- Python version: 3.11.9
|
||||
- PyTorch version: 2.4.1
|
||||
- CUDA device: NVIDIA H100 80GB HBM3
|
||||
- Transformers version: 4.45.0.dev0
|
||||
- Accelerate version: 0.34.2
|
||||
- Accelerate config:
|
||||
- compute_environment: LOCAL_MACHINE
|
||||
- distributed_type: DEEPSPEED
|
||||
- mixed_precision: no
|
||||
- use_cpu: False
|
||||
- debug: False
|
||||
- num_processes: 4
|
||||
- machine_rank: 0
|
||||
- num_machines: 1
|
||||
- rdzv_backend: static
|
||||
- same_network: True
|
||||
- main_training_function: main
|
||||
- enable_cpu_affinity: False
|
||||
- deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
|
||||
- downcast_bf16: no
|
||||
- tpu_use_cluster: False
|
||||
- tpu_use_sudo: False
|
||||
- tpu_env: []
|
||||
- Datasets version: 3.0.0
|
||||
- HF Hub version: 0.24.7
|
||||
- TRL version: 0.12.0.dev0+acb4d70
|
||||
- bitsandbytes version: 0.41.1
|
||||
- DeepSpeed version: 0.15.1
|
||||
- Diffusers version: 0.30.3
|
||||
- Liger-Kernel version: 0.3.0
|
||||
- LLM-Blender version: 0.0.2
|
||||
- OpenAI version: 1.46.0
|
||||
- PEFT version: 0.12.0
|
||||
```
|
||||
|
||||
This information is required when reporting an issue.
|
@ -1,31 +0,0 @@
|
||||
# Community Tutorials
|
||||
|
||||
Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.
|
||||
|
||||
# Language Models
|
||||
|
||||
| Task | Class | Description | Author | Tutorial | Colab |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) |
|
||||
| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) |
|
||||
| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
|
||||
| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
|
||||
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
|
||||
| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
|
||||
|
||||
<Youtube id="cnGyyM0vOes" />
|
||||
|
||||
# Vision Language Models
|
||||
|
||||
| Task | Class | Description | Author | Tutorial | Colab |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) |
|
||||
| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) |
|
||||
| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) |
|
||||
| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) |
|
||||
| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) |
|
||||
|
||||
## Contributing
|
||||
|
||||
If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
|
@ -1,108 +0,0 @@
|
||||
# CPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=cpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat.
|
||||
|
||||
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_cpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import CPOConfig, CPOTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10)
|
||||
trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_cpo.py
|
||||
```
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py)
|
||||
|
||||
To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch examples/scripts/cpo.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--num_train_epochs 1 \
|
||||
--logging_steps 25 \
|
||||
--output_dir Qwen2-0.5B-CPO
|
||||
```
|
||||
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
|
||||
|
||||
## CPO variants
|
||||
|
||||
### Simple Preference Optimization (SimPO)
|
||||
|
||||
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. It can be enabled simply by setting `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`].
|
||||
|
||||
### CPO-SimPO
|
||||
|
||||
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`].
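
A minimal sketch of both variants side by side (the `loss_type` and `cpo_alpha` fields come from the [`CPOConfig`] described above; the concrete values and output directories are only examples):

```python
from trl import CPOConfig

# SimPO: SimPO loss only, no BC regularizer.
simpo_args = CPOConfig(output_dir="Qwen2-0.5B-SimPO", loss_type="simpo", cpo_alpha=0.0)

# CPO-SimPO: SimPO loss combined with the BC regularizer via a non-zero cpo_alpha.
cpo_simpo_args = CPOConfig(output_dir="Qwen2-0.5B-CPO-SimPO", loss_type="simpo", cpo_alpha=0.5)
```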
|
||||
|
||||
## Loss functions

The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:

| `loss_type=`          | Description |
| --------------------- | ----------- |
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"`             | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"`               | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting, and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike DPO, which only sums them). |
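In practice, switching between these losses is just a matter of setting `loss_type` (and, where relevant, `beta`) in the config; a brief sketch with illustrative values:

```python
from trl import CPOConfig

# IPO loss: beta is the reciprocal of the gap between the chosen/rejected log-likelihood ratios
ipo_args = CPOConfig(output_dir="Qwen2-0.5B-CPO-IPO", loss_type="ipo", beta=0.1)

# Hinge loss: beta is the reciprocal of the margin
hinge_args = CPOConfig(output_dir="Qwen2-0.5B-CPO-Hinge", loss_type="hinge", beta=0.5)
```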
### For Mixture of Experts Models: Enabling the auxiliary loss

MOEs are most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.

This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
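For example, for a Mixtral-style policy model both settings can be passed when loading the model; a sketch, where the checkpoint name and coefficient are illustrative and the kwargs are forwarded to the model config by `from_pretrained`:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",  # illustrative MoE checkpoint
    output_router_logits=True,               # expose router logits so the auxiliary loss is added
    router_aux_loss_coef=0.001,              # weight of the load-balancing loss in the total loss
)
```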
## CPOTrainer

[[autodoc]] CPOTrainer

## CPOConfig

[[autodoc]] CPOConfig
@ -1,163 +0,0 @@
# Training customization

TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques. Note: although these examples use the `DPOTrainer`, the customization applies to most (if not all) trainers.

## Train on multiple GPUs / nodes

The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running

```bash
accelerate config
```

and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running:

```bash
accelerate launch your_script.py
```

We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:

```shell
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
```

Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.
### Distributed training with DeepSpeed

All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:

```shell
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script
```

Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example:

```python
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
    with ds_plugin.zero3_init_context_manager(enable=False):
        sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
else:
    sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```

Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
## Use different optimizers and schedulers

By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import optim
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    optimizers=(optimizer, None),
)
trainer.train()
```
### Add a learning rate scheduler

You can also play with your training by adding learning rate schedulers.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch import optim
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    optimizers=(optimizer, lr_scheduler),
)
trainer.train()
```
## Memory efficient fine-tuning by sharing layers

Another tool you can use for more memory-efficient fine-tuning is to share layers between the reference model and the model you want to train.

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import create_reference_model, DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
ref_model = create_reference_model(model, num_shared_layers=6)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```
## Pass 8-bit reference models

Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory-efficient fine-tuning.

Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
```
## Use the CUDA cache optimizer

When training large models, you can better manage the CUDA cache by clearing it iteratively. To do so, simply pass `optimize_cuda_cache=True` to `DPOConfig`:

```python
training_args = DPOConfig(..., optimize_cuda_cache=True)
```
@ -1,37 +0,0 @@
# Data Utilities

## is_conversational

[[autodoc]] is_conversational

## apply_chat_template

[[autodoc]] apply_chat_template

## maybe_apply_chat_template

[[autodoc]] maybe_apply_chat_template

## maybe_convert_to_chatml

[[autodoc]] maybe_convert_to_chatml

## extract_prompt

[[autodoc]] extract_prompt

## maybe_extract_prompt

[[autodoc]] maybe_extract_prompt

## unpair_preference_dataset

[[autodoc]] unpair_preference_dataset

## maybe_unpair_preference_dataset

[[autodoc]] maybe_unpair_preference_dataset

## pack_examples

[[autodoc]] pack_examples
@ -1,938 +0,0 @@
|
||||
# Dataset formats and types
|
||||
|
||||
This guide provides an overview of the dataset formats and types supported by each trainer in TRL.
|
||||
|
||||
## Overview of the dataset formats and types
|
||||
|
||||
- The *format* of a dataset refers to how the data is structured, typically categorized as either *standard* or *conversational*.
|
||||
- The *type* is associated with the specific task the dataset is designed for, such as *prompt-only* or *preference*. Each type is characterized by its columns, which vary according to the task, as shown in the table.
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th>Type \ Format</th>
|
||||
<th>Standard</th>
|
||||
<th>Conversational</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Language modeling</td>
|
||||
<td>
|
||||
<pre><code>{"text": "The sky is blue."}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"messages": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Prompt-only</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is"}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}]}</code></pre>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Prompt-completion</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is",
|
||||
"completion": " blue."}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]}</code></pre>
|
||||
</td>
|
||||
</tr>
<tr>
|
||||
<td>Preference</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is",
|
||||
"chosen": " blue.",
|
||||
"rejected": " green."}</code></pre>
|
||||
or, with implicit prompt:
|
||||
<pre><code>{"chosen": "The sky is blue.",
|
||||
"rejected": "The sky is green."}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "assistant", "content": "It is green."}]}</code></pre>
|
||||
or, with implicit prompt:
|
||||
<pre><code>{"chosen": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is green."}]}</code></pre>
|
||||
</td>
|
||||
</tr>
<tr>
<td>Unpaired preference</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "The sky is",
|
||||
"completion": " blue.",
|
||||
"label": True}</code></pre>
|
||||
</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is green."}],
|
||||
"label": False}</code></pre>
|
||||
</td>
|
||||
</tr>
<tr>
<td>Stepwise supervision</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
|
||||
"completions": ["The fractional part of 9.8 is 0.8.",
|
||||
"The fractional part of 9.11 is 0.11.",
|
||||
"0.11 is greater than 0.8.",
|
||||
"Hence, 9.11 > 9.8."],
|
||||
"labels": [True, True, False, False]}</code></pre>
|
||||
</td>
|
||||
<td></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### Formats
|
||||
|
||||
#### Standard
|
||||
|
||||
The standard dataset format typically consists of plain text strings. The columns in the dataset vary depending on the task. This is the format expected by TRL trainers. Below are examples of standard dataset formats for different tasks:
|
||||
|
||||
```python
|
||||
# Language modeling
|
||||
language_modeling_example = {"text": "The sky is blue."}
|
||||
# Preference
|
||||
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
|
||||
# Unpaired preference
|
||||
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
|
||||
```
|
||||
|
||||
#### Conversational
|
||||
|
||||
Conversational datasets are used for tasks involving dialogues or chat interactions between users and assistants. Unlike standard dataset formats, these contain sequences of messages where each message has a `role` (e.g., `"user"` or `"assistant"`) and `content` (the message text).
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello, how are you?"},
|
||||
{"role": "assistant", "content": "I'm doing great. How can I help you today?"},
|
||||
{"role": "user", "content": "I'd like to show off how chat templating works!"},
|
||||
]
|
||||
```
|
||||
|
||||
Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks:
|
||||
|
||||
```python
|
||||
# Prompt-completion
|
||||
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]}
|
||||
# Preference
|
||||
preference_example = {
|
||||
"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "assistant", "content": "It is green."}],
|
||||
}
|
||||
```
|
||||
|
||||
Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
|
||||
|
||||
### Types
|
||||
|
||||
#### Language modeling
|
||||
|
||||
A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
language_modeling_example = {"text": "The sky is blue."}
|
||||
# Conversational format
|
||||
language_modeling_example = {"messages": [
|
||||
{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}
|
||||
]}
|
||||
```
|
||||
|
||||
#### Prompt-only
|
||||
|
||||
In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
prompt_only_example = {"prompt": "The sky is"}
|
||||
# Conversational format
|
||||
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
```
|
||||
|
||||
For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).
|
||||
|
||||
<Tip>
|
||||
|
||||
While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from trl import apply_chat_template
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
|
||||
|
||||
# Example for prompt-only type
|
||||
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
apply_chat_template(prompt_only_example, tokenizer)
|
||||
# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
|
||||
|
||||
# Example for language modeling type
|
||||
lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
apply_chat_template(lm_example, tokenizer)
|
||||
# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
|
||||
```
|
||||
|
||||
- The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
|
||||
- In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
|
||||
|
||||
</Tip>
|
||||
|
||||
#### Prompt-completion
|
||||
|
||||
A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
|
||||
# Conversational format
|
||||
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]}
|
||||
```
|
||||
|
||||
For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216).
|
||||
|
||||
#### Preference
|
||||
|
||||
A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
|
||||
Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
## Explicit prompt (recommended)
|
||||
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
|
||||
# Implicit prompt
|
||||
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
|
||||
|
||||
# Conversational format
|
||||
## Explicit prompt (recommended)
|
||||
preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "assistant", "content": "It is green."}]}
|
||||
## Implicit prompt
|
||||
preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is green."}]}
|
||||
```
|
||||
|
||||
For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c).
|
||||
|
||||
Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
|
||||
|
||||
#### Unpaired preference
|
||||
|
||||
An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
|
||||
# Conversational format
|
||||
unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}],
|
||||
"label": True}
|
||||
```
|
||||
|
||||
For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf).
|
||||
|
||||
#### Stepwise supervision
|
||||
|
||||
A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process.
|
||||
|
||||
```python
|
||||
stepwise_example = {
|
||||
"prompt": "Which number is larger, 9.8 or 9.11?",
|
||||
"completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."],
|
||||
"labels": [True, False]
|
||||
}
|
||||
```
|
||||
|
||||
For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e).
|
||||
|
||||
## Which dataset type to use?
|
||||
|
||||
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
|
||||
|
||||
| Trainer | Expected dataset type |
|
||||
| ----------------------- | ------------------------------------------------------------------------------------------------------ |
|
||||
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
|
||||
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
|
||||
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
|
||||
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`PPOTrainer`] | Tokenized language modeling |
|
||||
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
|
||||
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
|
||||
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
|
||||
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
|
||||
<Tip>
|
||||
|
||||
TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
|
||||
For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Working with conversational datasets in TRL
|
||||
|
||||
Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format.
|
||||
Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.
|
||||
|
||||
### Converting a conversational dataset into a standard dataset
|
||||
|
||||
To convert a conversational dataset into a standard dataset, you need to _apply a chat template_ to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
|
||||
|
||||
For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
|
||||
|
||||
In TRL, the method you apply to convert the dataset will vary depending on the task. Fortunately, TRL provides a helper function called [`apply_chat_template`] to simplify this process. Here's an example of how to use it:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from trl import apply_chat_template
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
|
||||
|
||||
example = {
|
||||
"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]
|
||||
}
|
||||
|
||||
apply_chat_template(example, tokenizer)
|
||||
# Output:
|
||||
# {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n', 'completion': 'It is blue.<|end|>\n<|endoftext|>'}
|
||||
```
|
||||
|
||||
Alternatively, you can use the [`~datasets.Dataset.map`] method to apply the template across an entire dataset:
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import apply_chat_template
|
||||
|
||||
dataset_dict = {
|
||||
"prompt": [[{"role": "user", "content": "What color is the sky?"}],
|
||||
[{"role": "user", "content": "Where is the sun?"}]],
|
||||
"completion": [[{"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "assistant", "content": "In the sky."}]]
|
||||
}
|
||||
|
||||
dataset = Dataset.from_dict(dataset_dict)
|
||||
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
|
||||
# Output:
|
||||
# {'prompt': ['<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n',
|
||||
# '<|user|>\nWhere is the sun?<|end|>\n<|assistant|>\n'],
|
||||
# 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
|
||||
For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
|
||||
|
||||
```python
|
||||
apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
|
||||
# Output:
|
||||
# {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
|
||||
# 'completion': 'It is blue.<|im_end|>\n'}
|
||||
```
|
||||
|
||||
Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Using any dataset with TRL: preprocessing and conversion
|
||||
|
||||
Many datasets come in formats tailored to specific tasks, which might not be directly compatible with TRL. To use such datasets with TRL, you may need to preprocess and convert them into the required format.
|
||||
|
||||
To make this easier, we provide a set of [example scripts](https://github.com/huggingface/trl/tree/main/examples/datasets) that cover common dataset conversions.
|
||||
|
||||
### Example: UltraFeedback dataset
|
||||
|
||||
Let’s take the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback) as an example. Here's a preview of the dataset:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/openbmb/UltraFeedback/embed/viewer/default/train"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
As shown above, the dataset format does not match the expected structure. It’s not in a conversational format, the column names differ, and the results pertain to different models (e.g., Bard, GPT-4) and aspects (e.g., "helpfulness", "honesty").
|
||||
|
||||
By using the provided conversion script [`examples/datasets/ultrafeedback.py`](https://github.com/huggingface/trl/tree/main/examples/datasets/ultrafeedback.py), you can transform this dataset into an unpaired preference type, and push it to the Hub:
|
||||
|
||||
```sh
|
||||
python examples/datasets/ultrafeedback.py --push_to_hub --repo_id trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness
|
||||
```
|
||||
|
||||
Once converted, the dataset will look like this:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Now, you can use this dataset with TRL!
|
||||
|
||||
By adapting the provided scripts or creating your own, you can convert any dataset into a format compatible with TRL.
|
||||
|
||||
## Utilities for converting dataset types
|
||||
|
||||
This section provides example code to help you convert between different dataset types. While some conversions can be performed after applying the chat template (i.e., in the standard format), we recommend performing the conversion before applying the chat template to ensure it works consistently.
|
||||
|
||||
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
|
||||
|
||||
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
|
||||
| ------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------- | -------------------- |
|
||||
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
|
||||
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
|
||||
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
|
||||
|
||||
### From prompt-completion to language modeling dataset
|
||||
|
||||
To convert a prompt-completion dataset into a language modeling dataset, concatenate the prompt and the completion.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky."],
|
||||
})
|
||||
|
||||
def concat_prompt_completion(example):
|
||||
return {"text": example["prompt"] + example["completion"]}
|
||||
|
||||
dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From prompt-completion to prompt-only dataset
|
||||
|
||||
To convert a prompt-completion dataset into a prompt-only dataset, remove the completion.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky."],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns("completion")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is'}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to language modeling dataset
|
||||
|
||||
To convert a preference with implicit prompt dataset into a language modeling dataset, remove the rejected, and rename the column `"chosen"` to `"text"`.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": ["The sky is blue.", "The sun is in the sky."],
|
||||
"rejected": ["The sky is green.", "The sun is in the sea."],
|
||||
})
|
||||
|
||||
dataset = dataset.rename_column("chosen", "text").remove_columns("rejected")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to prompt-completion dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a prompt-completion dataset, extract the prompt with [`extract_prompt`], remove the rejected, and rename the column `"chosen"` to `"completion"`.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
dataset = dataset.map(extract_prompt).remove_columns("rejected").rename_column("chosen", "completion")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}], 'completion': [{'role': 'assistant', 'content': 'It is blue.'}]}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to prompt-only dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a prompt-only dataset, extract the prompt with [`extract_prompt`], and remove the rejected and the chosen.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
dataset = dataset.map(extract_prompt).remove_columns(["chosen", "rejected"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}]}
|
||||
```
|
||||
|
||||
### From implicit to explicit prompt preference dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into a preference dataset with explicit prompt, extract the prompt with [`extract_prompt`].
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
dataset = dataset.map(extract_prompt)
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
||||
'chosen': [{'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'rejected': [{'role': 'assistant', 'content': 'It is green.'}]}
|
||||
```
|
||||
|
||||
### From preference with implicit prompt to unpaired preference dataset
|
||||
|
||||
To convert a preference dataset with implicit prompt into an unpaired preference dataset, extract the prompt with [`extract_prompt`], and unpair the dataset with [`unpair_preference_dataset`].
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import extract_prompt, unpair_preference_dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"chosen": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "user", "content": "Where is the sun?"}, {"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
dataset = dataset.map(extract_prompt)
|
||||
dataset = unpair_preference_dataset(dataset)
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
||||
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'label': True}
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
||||
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
||||
This can be ensured by checking the absolute rating of each completion, e.g., from a reward model.
|
||||
|
||||
</Tip>
|
||||
|
||||
### From preference to language modeling dataset
|
||||
|
||||
To convert a preference dataset into a language modeling dataset, remove the rejected completions and concatenate the prompt and the chosen completion into the `"text"` column.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"chosen": [" blue.", " in the sky."],
|
||||
"rejected": [" green.", " in the sea."],
|
||||
})
|
||||
|
||||
def concat_prompt_chosen(example):
|
||||
return {"text": example["prompt"] + example["chosen"]}
|
||||
|
||||
dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From preference to prompt-completion dataset
|
||||
|
||||
To convert a preference dataset into a prompt-completion dataset, remove the rejected, and rename the column `"chosen"` to `"completion"`.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"chosen": [" blue.", " in the sky."],
|
||||
"rejected": [" green.", " in the sea."],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns("rejected").rename_column("chosen", "completion")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is', 'completion': ' blue.'}
|
||||
```
|
||||
|
||||
### From preference to prompt-only dataset
|
||||
|
||||
To convert a preference dataset into a prompt-only dataset, remove the rejected and the chosen.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is"],
|
||||
"chosen": [" blue.", " in the sky."],
|
||||
"rejected": [" green.", " in the sea."],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns(["chosen", "rejected"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is'}
|
||||
```
|
||||
|
||||
### From explicit to implicit prompt preference dataset
|
||||
|
||||
To convert a preference dataset with explicit prompt into a preference dataset with implicit prompt, concatenate the prompt to both chosen and rejected, and remove the prompt.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": [
|
||||
[{"role": "user", "content": "What color is the sky?"}],
|
||||
[{"role": "user", "content": "Where is the sun?"}],
|
||||
],
|
||||
"chosen": [
|
||||
[{"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
def concat_prompt_to_completions(example):
|
||||
return {"chosen": example["prompt"] + example["chosen"], "rejected": example["prompt"] + example["rejected"]}
|
||||
|
||||
dataset = dataset.map(concat_prompt_to_completions, remove_columns="prompt")
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'chosen': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'rejected': [{'role': 'user', 'content': 'What color is the sky?'}, {'role': 'assistant', 'content': 'It is green.'}]}
|
||||
```
|
||||
|
||||
### From preference to unpaired preference dataset
|
||||
|
||||
To convert a preference dataset into an unpaired preference dataset, unpair the dataset with [`unpair_preference_dataset`].
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import unpair_preference_dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": [
|
||||
[{"role": "user", "content": "What color is the sky?"}],
|
||||
[{"role": "user", "content": "Where is the sun?"}],
|
||||
],
|
||||
"chosen": [
|
||||
[{"role": "assistant", "content": "It is blue."}],
|
||||
[{"role": "assistant", "content": "In the sky."}],
|
||||
],
|
||||
"rejected": [
|
||||
[{"role": "assistant", "content": "It is green."}],
|
||||
[{"role": "assistant", "content": "In the sea."}],
|
||||
],
|
||||
})
|
||||
|
||||
dataset = unpair_preference_dataset(dataset)
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': [{'role': 'user', 'content': 'What color is the sky?'}],
|
||||
'completion': [{'role': 'assistant', 'content': 'It is blue.'}],
|
||||
'label': True}
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
||||
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
||||
This can be ensured by checking the absolute rating of each completion, e.g., from a reward model.
|
||||
|
||||
</Tip>
|
||||
|
||||
### From unpaired preference to language modeling dataset
|
||||
|
||||
To convert an unpaired preference dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column, and remove the prompt, completion and label columns.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
||||
"label": [True, True, False, False],
|
||||
})
|
||||
|
||||
def concatenate_prompt_completion(example):
|
||||
return {"text": example["prompt"] + example["completion"]}
|
||||
|
||||
dataset = dataset.filter(lambda x: x["label"]).map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'The sky is blue.'}
|
||||
```
|
||||
|
||||
### From unpaired preference to prompt-completion dataset
|
||||
|
||||
To convert an unpaired preference dataset into a prompt-completion dataset, filter for good labels, then remove the label columns.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
||||
"label": [True, True, False, False],
|
||||
})
|
||||
|
||||
dataset = dataset.filter(lambda x: x["label"]).remove_columns(["label"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is', 'completion': ' blue.'}
|
||||
```
|
||||
|
||||
### From unpaired preference to prompt-only dataset
|
||||
|
||||
To convert an unpaired preference dataset into a prompt-only dataset, remove the completion and the label columns.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
|
||||
"completion": [" blue.", " in the sky.", " green.", " in the sea."],
|
||||
"label": [True, True, False, False],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns(["completion", "label"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'The sky is'}
|
||||
```
|
||||
|
||||
### From stepwise supervision to language modeling dataset
|
||||
|
||||
To convert a stepwise supervision dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["Blue light", "Water"],
|
||||
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
||||
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
||||
"labels": [[True, False], [True, True]],
|
||||
})
|
||||
|
||||
def concatenate_prompt_completions(example):
|
||||
completion = "".join(example["completions"])
|
||||
return {"text": example["prompt"] + completion}
|
||||
|
||||
dataset = dataset.filter(lambda x: all(x["labels"])).map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'text': 'Blue light scatters more in the atmosphere, so the sky is green.'}
|
||||
```
|
||||
|
||||
### From stepwise supervision to prompt-completion dataset
|
||||
|
||||
To convert a stepwise supervision dataset into a prompt-completion dataset, join the good completions and remove the labels.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["Blue light", "Water"],
|
||||
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
||||
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
||||
"labels": [[True, False], [True, True]],
|
||||
})
|
||||
|
||||
def join_completions(example):
|
||||
completion = "".join(example["completions"])
|
||||
return {"completion": completion}
|
||||
|
||||
dataset = dataset.filter(lambda x: all(x["labels"])).map(join_completions, remove_columns=["completions", "labels"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.'}
|
||||
```
|
||||
|
||||
### From stepwise supervision to prompt-only dataset
|
||||
|
||||
To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["Blue light", "Water"],
|
||||
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
||||
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
||||
"labels": [[True, False], [True, True]],
|
||||
})
|
||||
|
||||
dataset = dataset.remove_columns(["completions", "labels"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'Blue light'}
|
||||
```
|
||||
|
||||
### From stepwise supervision to unpaired preference dataset
|
||||
|
||||
To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels.
|
||||
|
||||
The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
|
||||
dataset = Dataset.from_dict({
|
||||
"prompt": ["Blue light", "Water"],
|
||||
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
|
||||
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
|
||||
"labels": [[True, False], [True, True]],
|
||||
})
|
||||
|
||||
def merge_completions_and_labels(example):
|
||||
return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])}
|
||||
|
||||
dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "labels"])
|
||||
```
|
||||
|
||||
```python
|
||||
>>> dataset[0]
|
||||
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.', 'label': False}
|
||||
```
|
||||
|
||||
## Vision datasets

Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.

A conversational vision dataset differs from a standard conversational dataset in two key ways:

1. The dataset must contain the key `images` with the image data.
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.

Example:

```python
# Textual dataset:
"content": "What color is the sky?"

# Vision dataset:
"content": [
    {"type": "image"},
    {"type": "text", "text": "What color is the sky in the image?"}
]
```
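Putting the two points together, a single record of a conversational vision preference dataset might look like the following sketch; the synthetic `image` object and the field values are illustrative, and the record follows the preference type described earlier:

```python
from PIL import Image

image = Image.new("RGB", (64, 64), color="blue")  # stand-in for a real photo

vision_preference_example = {
    "images": [image],  # the image referenced by the {"type": "image"} placeholder below
    "prompt": [{"role": "user",
                "content": [{"type": "image"},
                            {"type": "text", "text": "What color is the sky in the image?"}]}],
    "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
    "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "It is green."}]}],
}
```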
An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset). Below is an embedded view of the dataset's training data, allowing you to explore it directly:

<iframe
  src="https://huggingface.co/datasets/trl-lib/rlaif-v/embed/viewer/default/train"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>
@ -1,131 +0,0 @@
# Denoising Diffusion Policy Optimization

[](https://huggingface.co/models?other=ddpo,trl)

## The why

| Before | After DDPO finetuning |
| --- | --- |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_squirrel.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_crab.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_starfish.png"/></div> |
## Getting started with Stable Diffusion finetuning with reinforcement learning
|
||||
|
||||
The machinery for fine-tuning Stable Diffusion models with reinforcement learning makes heavy use of Hugging Face's `diffusers`
library, so getting started requires some familiarity with two of its core concepts: pipelines and schedulers.
Right out of the box, the `diffusers` library provides neither a `Pipeline` nor a `Scheduler` instance that is suitable for fine-tuning with reinforcement learning, so some adjustments need to be made.
|
||||
|
||||
This library provides a pipeline interface that must be implemented in order to use the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: only the Stable Diffusion architecture is supported at this point.**
There is a default implementation of this interface that you can use out of the box; assuming it is sufficient, or simply to get things moving, refer to the training example alongside this guide.

The point of the interface is to fuse the pipeline and the scheduler into a single object, so that all of the constraints live in one place. The interface was designed in the hope of catering to pipelines and schedulers beyond the examples in this repository at the time of writing. The scheduler step is also a method of this pipeline interface; this may seem redundant given that the raw scheduler is accessible through the interface, but it is the only way to constrain the scheduler step output to an output type suited to the algorithm at hand (DDPO).
|
||||
|
||||
For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)
|
||||
|
||||
Note that the default implementation has both a LoRA and a non-LoRA code path. LoRA is enabled by default and can be turned off by passing the corresponding flag. LoRA-based training is faster, and the LoRA-related hyperparameters responsible for model convergence aren't as finicky as they are for non-LoRA training.
|
||||
|
||||
In addition, you are expected to provide a reward function and a prompt function. The reward function is used to score the generated images, and the prompt function is used to produce the prompts from which the images are generated.
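As a rough sketch (not the exact API contract, which is documented in the default implementation linked below), these two callables could look like the following, assuming the prompt function returns a `(prompt, metadata)` pair and the reward function maps batched `(images, prompts, prompt_metadata)` to `(rewards, reward_metadata)`:

```python
import random

ANIMALS = ["squirrel", "crab", "starfish", "whale"]


def prompt_fn():
    # Return one prompt plus arbitrary metadata that is carried through to the reward function.
    return random.choice(ANIMALS), {}


def reward_fn(images, prompts, prompt_metadata):
    # Toy reward: prefer brighter images (images is a batched CHW tensor here).
    # In practice you would plug in an aesthetic or CLIP-based scorer instead.
    rewards = images.float().mean(dim=(1, 2, 3))
    return rewards, {}
```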
|
||||
|
||||
## Getting started with `examples/scripts/ddpo.py`
|
||||
|
||||
The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
|
||||
|
||||
**Note:** one A100 GPU is recommended to get this running. Anything below an A100 will not be able to run this example script, and even if it does (with relatively smaller parameters), the results will most likely be poor.
|
||||
|
||||
Almost every configuration parameter has a default. Only one command-line flag is required to get things up and running: a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the fine-tuned model to the Hugging Face Hub. Run the following command to get started:
|
||||
|
||||
```bash
|
||||
python ddpo.py --hf_user_access_token <token>
|
||||
```
|
||||
|
||||
To see the documentation of the script's options, please run `python ddpo.py --help`
|
||||
|
||||
The following constraints should be kept in mind when configuring the trainer, beyond the use case of the example script (the code checks them for you as well); a configuration satisfying them is sketched after the list:
|
||||
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
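For instance, a configuration satisfying all three constraints could be sketched as follows (parameter names are taken from the flags above; the values are only placeholders):

```python
from trl import DDPOConfig

# 6 >= 3, 6 % 3 == 0, and 6 is divisible by the gradient accumulation steps (1)
# as well as by an assumed accelerator process count of 2.
ddpo_config = DDPOConfig(
    sample_batch_size=6,
    train_batch_size=3,
    train_gradient_accumulation_steps=1,
)
```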
|
||||
|
||||
## Setting up the image logging hook function
|
||||
|
||||
Expect the function to be given a list of lists of the form
|
||||
```python
|
||||
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
|
||||
|
||||
```
|
||||
where `image`, `prompt`, `prompt_metadata`, `rewards`, and `reward_metadata` are batched.
The last list in the list of lists represents the last sample batch, which is the one you are most likely to want to log.
While you are free to log however you want, using `wandb` or `tensorboard` is recommended.
|
||||
|
||||
### Key terms
|
||||
|
||||
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
|
||||
- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
|
||||
- `prompt` : The prompt is the text that is used to generate the image
|
||||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model consists of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup, where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
||||
- `image` : The image generated by the Stable Diffusion model
|
||||
|
||||
Example code for logging sampled images with `wandb` is given below.
|
||||
|
||||
```python
import numpy as np
from PIL import Image


# for logging these images to wandb
def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only care about the last batch
    # hence we extract the last element of the list
    result = {}
    images, prompts, _, rewards, _ = image_data[-1]
    for i, image in enumerate(images):
        # convert the CHW float tensor into an HWC uint8 PIL image
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        # key each image by a truncated prompt and its reward
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )
```
|
||||
|
||||
### Using the finetuned model
|
||||
|
||||
Assuming you're done with all the epochs and have pushed your model to the Hub, you can use the fine-tuned model as follows:
|
||||
|
||||
```python
|
||||
|
||||
import torch
|
||||
from trl import DefaultDDPOStableDiffusionPipeline
|
||||
|
||||
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
# memory optimization
|
||||
pipeline.vae.to(device, torch.float16)
|
||||
pipeline.text_encoder.to(device, torch.float16)
|
||||
pipeline.unet.to(device, torch.float16)
|
||||
|
||||
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
||||
results = pipeline(prompts)
|
||||
|
||||
for prompt, image in zip(prompts,results.images):
|
||||
image.save(f"{prompt}.png")
|
||||
|
||||
```
|
||||
|
||||
## Credits
|
||||
|
||||
This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models with Reinforcement Learning by Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301).
|
||||
|
||||
## DDPOTrainer
|
||||
|
||||
[[autodoc]] DDPOTrainer
|
||||
|
||||
## DDPOConfig
|
||||
|
||||
[[autodoc]] DDPOConfig
|
||||
|
@ -1,7 +0,0 @@
|
||||
# DeepSpeed Integration
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
|
||||
</Tip>
|
@ -1,187 +0,0 @@
|
||||
# Detoxifying a Language Model using PPO
|
||||
|
||||
Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" an LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) with Proximal Policy Optimization (PPO) to steer it toward less toxic generations.
|
||||
|
||||
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
|
||||
|
||||
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
|
||||
|
||||
| File | Description | Colab link |
|
||||
|---|---| --- |
|
||||
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
|
||||
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
|
||||
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
|
||||
|
||||
## Context
|
||||
|
||||
Language models are trained on large volumes of text from the internet which also includes a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it.
|
||||
|
||||
### Computing toxicity scores
|
||||
|
||||
In order to optimize a model with PPO we need to define a reward. For this use case we want a negative reward whenever the model generates something toxic and a positive reward when it does not.
Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), a RoBERTa model fine-tuned to classify text as "neutral" or "toxic", as our toxicity classifier.
|
||||
One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one.
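As a rough sketch of how such a classifier can be turned into a scalar reward (the exact reward used for training is discussed in the "Reward function" section below):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_id)
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_id)


def compute_rewards(generated_texts):
    # Score a batch of generations and use the raw logit of the "neutral" class
    # (index 0, as used later in this guide) as the reward.
    inputs = toxicity_tokenizer(generated_texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = toxicity_model(**inputs).logits.float()
    return logits[:, 0].tolist()
```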
|
||||
|
||||
### Selection of models
|
||||
|
||||
We selected the following models for our experiments to show that TRL can be easily scaled to 10B parameters models:
|
||||
|
||||
* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters)
|
||||
* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
|
||||
* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)
|
||||
|
||||
For the smallest model, we chose `EleutherAI/gpt-neo-125M` because it has been shown to be the "most toxic" compared to other models. We ran a toxicity evaluation with the `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of the `allenai/real-toxicity-prompts` dataset. Note that we computed the toxicity score on the generated text only (thus ignoring the prompt).
|
||||
|
||||
| Model | Mean toxicity score |
|
||||
|---|---|
|
||||
| `gpt2` | 0.01602 |
|
||||
| `facebook/opt-350m` | 0.01628 |
|
||||
| `bigscience/bloom-560m` | 0.00767 |
|
||||
| `EleutherAI/gpt-neo-125M` | **0.02016** |
|
||||
|
||||
## Designing the problem
|
||||
|
||||
When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge.
|
||||
|
||||
### Pre-processing the dataset
|
||||
|
||||
The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score.
|
||||
|
||||
A `prompt` example:
|
||||
```
|
||||
{ "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 }
|
||||
```
|
||||
And its `continuation` value:
|
||||
```
|
||||
{ "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 }
|
||||
```
|
||||
|
||||
We want to increase the chance that the model generates toxic text so that we get more learning signal. For this reason we pre-process the dataset to keep only prompts whose toxicity score is greater than a threshold. We can do this in a few lines of code:
|
||||
```python
|
||||
train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")
|
||||
|
||||
def filter_fn(sample):
|
||||
toxicity = sample["prompt"]["toxicity"]
|
||||
return toxicity is not None and toxicity > 0.3
|
||||
|
||||
train_dataset = train_dataset.filter(filter_fn, batched=False)
|
||||
```
|
||||
|
||||
### Reward function
|
||||
|
||||
The reward function is one of the most important parts of training a model with reinforcement learning, as it is what tells the model whether it is doing well or not.
We tried various combinations: the softmax of the "neutral" label, the log of the toxicity score, and the raw logits of the "neutral" label. We found that convergence was much smoother with the raw logits of the "neutral" label.
|
||||
```python
|
||||
logits = toxicity_model(**toxicity_inputs).logits.float()
|
||||
rewards = (logits[:, 0]).tolist()
|
||||
```
|
||||
|
||||
### Impact of input prompts length
|
||||
|
||||
We found that training the model with a short or a long context (5 to 8 tokens for the short context and 15 to 20 tokens for the long one) does not impact the convergence of the model; however, when trained with longer prompts, the model tends to generate more toxic outputs.
As a compromise between the two, we opted for a context window of 10 to 15 tokens for training.
|
||||
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-long-vs-short-context.png">
|
||||
</div>
|
||||
|
||||
### How to deal with OOM issues
|
||||
|
||||
Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
|
||||
|
||||
- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is pure `bfloat16` training, which is different from mixed-precision training. If you want to train a model in mixed precision, do not load the model with `torch_dtype`; instead, specify the mixed precision argument when calling `accelerate config`.
|
||||
|
||||
- Use shared layers: since the PPO algorithm requires both the active and the reference model to be on the same device, we decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying the `num_shared_layers` argument when calling `create_reference_model()`. For example, if you want to share the first 6 layers of the model, you can do it like this:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-shared-layers.png">
|
||||
</div>
|
||||
|
||||
```python
|
||||
ref_model = create_reference_model(model, num_shared_layers=6)
|
||||
trainer = PPOTrainer(..., ref_model=ref_model)
|
||||
```
|
||||
|
||||
In the example above, this means that the first 6 layers of the model are frozen (since these layers are shared between the active model and the reference model).
|
||||
|
||||
- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).
|
||||
|
||||
## Training the model!
|
||||
|
||||
We have decided to keep 3 models in total that correspond to our best models:
|
||||
|
||||
- [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox)
|
||||
- [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox)
|
||||
- [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox)
|
||||
|
||||
We used a different learning rate for each model, and found that the largest models were quite hard to train and can easily enter a collapse mode if the learning rate is not chosen correctly (i.e. if it is too high):
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-collapse-mode.png">
|
||||
</div>
|
||||
|
||||
The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-final-run-2.png">
|
||||
</div>
|
||||
|
||||
As you can see, the model converges nicely, but obviously we don't observe a very large improvement over the first step, as the original model is not trained to generate toxic content.

We also observed that training with a larger `mini_batch_size` leads to smoother convergence and better results on the test set:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-mbs-run.png">
|
||||
</div>
|
||||
|
||||
## Results
|
||||
|
||||
We tested our models on a new dataset, [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic). We feed each model a toxic prompt from it (a sample with the label "toxic"), generate 30 new tokens as is done in the training loop, and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity).
We report the toxicity scores of 400 sampled examples, compute their mean and standard deviation, and report the results in the table below:
|
||||
|
||||
| Model | Mean toxicity score | Std toxicity score |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
|
||||
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
|
||||
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
|
||||
| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** |
|
||||
|
||||
<div class="column" style="text-align:center">
|
||||
<figure>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-final-barplot.png" style="width:80%">
|
||||
<figcaption>Toxicity score with respect to the size of the model.</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
Below are few generation examples of `gpt-j-6b-detox` model:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-toxicity-examples.png">
|
||||
</div>
|
||||
|
||||
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
|
||||
|
||||
### Discussions
|
||||
|
||||
The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for the `gpt-neo-2.7B` model but less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model, starting with training with a larger `mini_batch_size` and probably allowing back-propagation through more layers (i.e. using fewer shared layers).
|
||||
|
||||
To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful.
|
||||
|
||||
### Limitations
|
||||
|
||||
We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use.
|
||||
|
||||
## What is next?
|
||||
|
||||
You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms).
|
@ -1,283 +0,0 @@
|
||||
# DPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=dpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
|
||||
## Overview
|
||||
|
||||
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
|
||||
|
||||
The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
|
||||
|
||||
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):
|
||||
|
||||
1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with pairs of chosen (positive) and rejected (negative) generations for a given prompt.
|
||||
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
|
||||
|
||||
This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
|
||||
|
||||

|
||||
|
||||
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_dpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
|
||||
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dpo.py
|
||||
```
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-DPO>:</span></strong>
|
||||
The best programming language for specific applications can vary depending on the use case and knowledge level of the programmer. Here are some general factors that can be used as input to choose the best programming language:
|
||||
|
||||
<strong><span style="color: green;">1</span></strong> Ease of use: Some programming languages are more user-friendly than others, such as Python, Java, or Ruby. Python is popular due to its simplicity and great scalability.
|
||||
<strong><span style="color: green;">2</span></strong> Versatility: The ability to work with a wide range of data structures and frameworks can define the language as versatile.
|
||||
<strong><span style="color: green;">3</span></strong> Ease of learning: Different programming languages have different learning curves, so users must be willing to take some time to master one.
|
||||
<strong><span style="color: green;">4</span></strong> Community support: The broader community of developers and enthusiasts in the selected programming language can provide great support and resources.
|
||||
<strong><span style="color: green;">5</span></strong> Reusability: Languages that emphasize code reuse and can be easily modifiable can be more suitable for software development.
|
||||
|
||||
The best programming language based on these factors is subjective and depends on what the programmer intends to accomplish.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
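For reference, here is a minimal sketch of a standard-format preference row with an explicit prompt versus an implicit one (the strings are only illustrative):

```python
# Explicit prompt (recommended): the prompt lives in its own column.
explicit_example = {
    "prompt": "The sky is",
    "chosen": " blue.",
    "rejected": " green.",
}

# Implicit prompt: the prompt is repeated at the start of both completions,
# and the trainer extracts it automatically.
implicit_example = {
    "chosen": "The sky is blue.",
    "rejected": "The sky is green.",
}
```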
|
||||
|
||||
### Special considerations for vision-language models
|
||||
|
||||
The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section.
|
||||
|
||||
Additionally, unlike standard text-based models where a `tokenizer` is used, for VLMs, you should replace the `tokenizer` with a `processor`.
|
||||
|
||||
```diff
|
||||
- model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
+ model = AutoModelForVision2Seq.from_pretrained(model_id)
|
||||
|
||||
- tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
+ processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
- processing_class=tokenizer,
|
||||
+ processing_class=processor,
|
||||
)
|
||||
```
|
||||
|
||||
For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
|
||||
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)
|
||||
|
||||
To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch trl/scripts/dpo.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--num_train_epochs 1 \
|
||||
--logging_steps 25 \
|
||||
--output_dir Qwen2-0.5B-DPO
|
||||
```
|
||||
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
|
||||
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
|
||||
- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
|
||||
## Loss functions
|
||||
|
||||
The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported:
|
||||
|
||||
| `loss_type=` | Description |
|
||||
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
|
||||
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
|
||||
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
|
||||
| `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. |
|
||||
| `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. |
|
||||
| `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) |
|
||||
| `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`BCOTrainer`]. |
|
||||
| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
|
||||
| `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
|
||||
| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
|
||||
| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). |
|
||||
|
||||
### Label smoothing
|
||||
|
||||
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
|
||||
|
||||
### Syncing the reference model
|
||||
|
||||
The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback, set `sync_ref_model=True` in the [`DPOConfig`].
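A minimal sketch using the parameters named above (the values are placeholders, not tuned recommendations):

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="Qwen2-0.5B-DPO",
    sync_ref_model=True,        # enable the TR-DPO callback
    ref_model_mixup_alpha=0.6,  # weight used when mixing policy and reference weights
    ref_model_sync_steps=512,   # sync the reference model every N optimization steps
)
```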
|
||||
|
||||
### RPO loss
|
||||
|
||||
The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
|
||||
|
||||
### WPO loss
|
||||
|
||||
The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MoEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MoEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
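A sketch of what this looks like for a Mixtral-style model (the coefficient shown is simply the default mentioned above):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

# Return the router logits so the load-balancing auxiliary loss is added to the total loss...
model.config.output_router_logits = True
# ...and control how strongly that auxiliary loss contributes.
model.config.router_aux_loss_coef = 0.001
```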
|
||||
|
||||
## Accelerate DPO fine-tuning using `unsloth`
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some benchmarks for DPO are listed below:
|
||||
|
||||
| GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| -------- | --------- | ---------- | --- | --------------------- | --------- | ------------ |
|
||||
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
|
||||
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
```diff
|
||||
from datasets import load_dataset
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
- from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
+ from unsloth import FastLanguageModel
|
||||
|
||||
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
+ model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
+ model = FastLanguageModel.get_peft_model(model)
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
- training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
|
||||
+ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10, bf16=True)
|
||||
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
|
||||
```
|
||||
|
||||
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
|
||||
|
||||
## Reference model considerations with PEFT
|
||||
|
||||
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
|
||||
|
||||
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
|
||||
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
|
||||
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
|
||||
|
||||
### Downsides to merging QLoRA before DPO (approach 2)
|
||||
|
||||
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
|
||||
|
||||
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
|
||||
|
||||
### Using option 3 - load the adapter twice
|
||||
|
||||
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
# Load the base model.
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/mixtral-8x7b-v0.1",
|
||||
load_in_4bit=True,
|
||||
quantization_config=bnb_config,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
# Load the adapter.
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
"/path/to/peft",
|
||||
is_trainable=True,
|
||||
adapter_name="train",
|
||||
)
|
||||
# Load the adapter a second time, with a different name, which will be our reference model.
|
||||
model.load_adapter("/path/to/peft", adapter_name="reference")
|
||||
|
||||
# Initialize the trainer, without a ref_model param.
|
||||
training_args = DPOConfig(
|
||||
model_adapter_name="train",
|
||||
ref_adapter_name="reference",
|
||||
)
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
args=training_args,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
## DPOTrainer
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
|
||||
## DPOConfig
|
||||
|
||||
[[autodoc]] DPOConfig
|
||||
|
||||
## DataCollatorForPreference
|
||||
|
||||
[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference
|
@ -1,78 +0,0 @@
|
||||
# Examples
|
||||
|
||||
|
||||
## Introduction
|
||||
|
||||
The examples should work in any of the following settings (with the same script):
|
||||
- single GPU
|
||||
- multiple GPUs (using PyTorch distributed mode)
- multiple GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
|
||||
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
|
||||
|
||||
To run it in each of these various modes, first initialize the accelerate
|
||||
configuration with `accelerate config`
|
||||
|
||||
**NOTE:** to train with a 4-bit or 8-bit model, please run
|
||||
|
||||
```bash
|
||||
pip install --upgrade trl[quantization]
|
||||
```
|
||||
|
||||
|
||||
## Accelerate Config
|
||||
For all the examples, you'll need to generate a 🤗 Accelerate config file with:
|
||||
|
||||
```shell
|
||||
accelerate config # will prompt you to define the training configuration
|
||||
```
|
||||
|
||||
Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
|
||||
|
||||
# Maintained Examples
|
||||
|
||||
Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.
|
||||
|
||||
| File | Description |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
|
||||
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
|
||||
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
|
||||
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
|
||||
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
|
||||
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
|
||||
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
|
||||
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
|
||||
|
||||
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
|
||||
|
||||
| File | Description |
|
||||
| --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
|
||||
|
||||
|
||||
We also have some other examples that are less maintained but can be used as a reference:
|
||||
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
|
||||
|
||||
|
||||
## Distributed training
|
||||
|
||||
All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.)
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision).
|
||||
|
||||
### Distributed training with DeepSpeed
|
||||
|
||||
Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--config_file` with the path to the DeepSpeed config file such as `examples/accelerate_configs/deepspeed_zero1.yaml`):
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
|
||||
```
|
@ -1,98 +0,0 @@
|
||||
# Generalized Knowledge Distillation Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=gkd,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
|
||||
|
||||
|
||||
The key aspects of GKD are:
|
||||
1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences.
|
||||
2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher.
|
||||
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Lewis Tunstall](https://huggingface.co/lewtun).
|
||||
|
||||
## Usage tips
|
||||
|
||||
The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
|
||||
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD, where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and receives token-specific feedback on these sequences from the teacher. For values in between [0, 1], each batch is randomly trained with either the on-policy or the supervised objective, with probability `lmbda` of using the on-policy one.
|
||||
* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised fine-tuning on teacher-generated outputs). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
|
||||
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
|
||||
|
||||
The authors find that on-policy data (high `lmbda`) performs better, and that the optimal `beta` varies depending on the task and evaluation method.
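As a hedged illustration of how these knobs combine (the values are placeholders, not tuned recommendations):

```python
from trl import GKDConfig

# Sketch only: mostly on-policy student generations with a mildly forward-KL-leaning JSD.
training_args = GKDConfig(
    output_dir="gkd-model",
    lmbda=0.9,     # ~90% of batches use student-generated sequences
    beta=0.1,      # generalized JSD interpolation, closer to forward KL
    seq_kd=False,  # no sequence-level KD on teacher generations
)
```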
|
||||
|
||||
> [!WARNING]
|
||||
> Make sure that `attn_implementation="flash_attention_2"` when training [Gemma models](https://huggingface.co/models?other=gemma2). Otherwise you will encounter NaNs in the logits due to the [soft capping technique](https://huggingface.co/blog/gemma2#soft-capping-and-attention-implementations) adopted by this architecture.
|
||||
|
||||
The basic API is as follows:
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import GKDConfig, GKDTrainer
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
)
|
||||
|
||||
NUM_DUMMY_SAMPLES = 100
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
# The model to optimise
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
# The teacher model to calculate the KL divergence against
|
||||
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
|
||||
|
||||
train_dataset = Dataset.from_dict(
|
||||
{
|
||||
"messages": [
|
||||
[
|
||||
{"role": "user", "content": "Hi, how are you?"},
|
||||
{"role": "assistant", "content": "I'm great thanks"},
|
||||
]
|
||||
]
|
||||
* NUM_DUMMY_SAMPLES
|
||||
}
|
||||
)
|
||||
eval_dataset = Dataset.from_dict(
|
||||
{
|
||||
"messages": [
|
||||
[
|
||||
{"role": "user", "content": "What colour is the sky?"},
|
||||
{"role": "assistant", "content": "The sky is blue"},
|
||||
]
|
||||
]
|
||||
* NUM_DUMMY_SAMPLES
|
||||
}
|
||||
)
|
||||
|
||||
training_args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
|
||||
trainer = GKDTrainer(
|
||||
model=model,
|
||||
teacher_model=teacher_model,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Expected dataset type
|
||||
|
||||
The dataset should be formatted as a list of "messages", where each conversation is a list of message dictionaries with the following keys:
|
||||
* `role`: either `system`, `assistant` or `user`
|
||||
* `content`: the message content
|
||||
|
||||
|
||||
## GKDTrainer
|
||||
|
||||
[[autodoc]] GKDTrainer
|
||||
|
||||
## GKDConfig
|
||||
|
||||
[[autodoc]] GKDConfig
|
@ -1,253 +0,0 @@
|
||||
# GRPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=grpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.
|
||||
|
||||
This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec).
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (completion column is ignored!). You can view the data in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/tldr/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model.
|
||||
|
||||
```python
|
||||
# train_grpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Define the reward function, which rewards completions that are close to 20 characters
|
||||
def reward_len(completions, **kwargs):
|
||||
return [-abs(20 - len(completion)) for completion in completions]
|
||||
|
||||
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_len,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_grpo.py
|
||||
```
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 1 day.
|
||||
|
||||

|
||||
|
||||
## Looking deeper into the GRPO method
|
||||
|
||||
GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.
|
||||
|
||||

|
||||
|
||||
### Generating completions
|
||||
|
||||
At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)).
|
||||
|
||||
### Computing the advantage
|
||||
|
||||
For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
|
||||
|
||||
$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
|
||||
|
||||
This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
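As a minimal sketch of this normalization (illustrative values, not TRL's internal code):

```python
import torch

# Rewards for G = 4 completions sampled from the same prompt
rewards = torch.tensor([0.2, 0.8, 0.5, 0.9])

# Group-relative advantage: standardize the rewards within the group
advantages = (rewards - rewards.mean()) / rewards.std()
print(advantages)  # completions above the group mean get positive advantages
```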
|
||||
|
||||
### Estimating the KL divergence
|
||||
|
||||
KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows:
|
||||
|
||||
$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,
|
||||
$$
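A per-token sketch of this estimator, assuming you already have the log-probabilities of both models (this is not the exact TRL implementation):

```python
import torch

def approx_kl(logp_model: torch.Tensor, logp_ref: torch.Tensor) -> torch.Tensor:
    """k3 estimator: ratio - log(ratio) - 1, with ratio = pi_ref / pi_theta."""
    log_ratio = logp_ref - logp_model
    return torch.exp(log_ratio) - log_ratio - 1  # always >= 0, zero when the models agree
```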
|
||||
|
||||
### Computing the loss
|
||||
|
||||
The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
|
||||
$$
|
||||
|
||||
where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
|
||||
|
||||
In the original paper, this formulation is generalized to account for multiple updates after each generation by leveraging the **clipped surrogate objective**:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
|
||||
$$
|
||||
|
||||
where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\).
|
||||
In TRL though, as in the original paper, we only do one update per generation, so we can simplify the loss to the first form.
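For intuition, the simplified per-token objective can be sketched as follows (illustrative pseudocode: padding masks and the per-sequence averaging are omitted, and `beta` is a placeholder value):

```python
import torch

def grpo_loss(logp, logp_ref, advantages, beta=0.1):
    # logp, logp_ref: per-token log-probs under the policy and the reference model
    # advantages: per-token group-relative advantages (broadcast per sequence)
    ratio = torch.exp(logp - logp.detach())  # equals 1, but lets gradients flow through logp
    kl = torch.exp(logp_ref - logp) - (logp_ref - logp) - 1
    per_token = ratio * advantages - beta * kl
    return -per_token.mean()  # averaged over all tokens here for brevity
```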
|
||||
|
||||
## Logged metrics
|
||||
|
||||
The GRPO Trainer logs the following metrics:
|
||||
|
||||
- `completion_length`: The average completion length.
|
||||
- `reward/{reward_func_name}`: The reward computed by each reward function.
|
||||
- `reward`: The average reward.
|
||||
- `reward_std`: The average standard deviation within reward groups.
|
||||
- `kl`: The average KL divergence between the model and the reference model calculated on completions.
|
||||
|
||||
## Customization
|
||||
|
||||
### Speed up training with vLLM-powered generation
|
||||
|
||||
Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, pass `use_vllm=True` in the training arguments.
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(..., use_vllm=True)
|
||||
```
|
||||
|
||||
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
|
||||
|
||||
### Using a custom reward function
|
||||
|
||||
The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
|
||||
|
||||
1. **Input arguments**:
|
||||
- The function must accept the following as keyword arguments:
|
||||
- `prompts` (contains the prompts),
|
||||
- `completions` (contains the generated completions),
|
||||
- All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
|
||||
|
||||
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
|
||||
- Depending on the dataset format, the input will vary:
|
||||
- For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
|
||||
- For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.
|
||||
|
||||
2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion.
|
||||
|
||||
#### Example 1: Reward longer completions
|
||||
|
||||
Below is an example of a reward function for a standard format that rewards longer completions:
|
||||
|
||||
```python
|
||||
def reward_func(completions, **kwargs):
|
||||
"""Reward function that gives higher scores to longer completions."""
|
||||
return [float(len(completion)) for completion in completions]
|
||||
```
|
||||
|
||||
You can test it as follows:
|
||||
|
||||
```python
|
||||
>>> prompts = ["The sky is", "The sun is"]
|
||||
>>> completions = [" blue.", " in the sky."]
|
||||
>>> print(reward_func(prompts=prompts, completions=completions))
|
||||
[6.0, 12.0]
|
||||
```
|
||||
|
||||
#### Example 2: Reward completions with specific format
|
||||
|
||||
Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
|
||||
It is designed for conversational format, where prompts and completions consist of structured messages.
|
||||
|
||||
```python
|
||||
import re
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
"""Reward function that checks if the completion has a specific format."""
|
||||
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
|
||||
completion_contents = [completion[0]["content"] for completion in completions]
|
||||
matches = [re.match(pattern, content) for content in completion_contents]
|
||||
return [1.0 if match else 0.0 for match in matches]
|
||||
```
|
||||
|
||||
You can test this function as follows:
|
||||
|
||||
```python
|
||||
>>> prompts = [
|
||||
... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
|
||||
... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
|
||||
... ]
|
||||
>>> completions = [
|
||||
... [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
|
||||
... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
|
||||
... ]
|
||||
>>> format_reward_func(prompts=prompts, completions=completions)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
|
||||
#### Example 3: Reward completions based on a reference
|
||||
|
||||
Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
|
||||
This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`.
|
||||
|
||||
```python
|
||||
import re
|
||||
|
||||
def reward_func(completions, ground_truth, **kwargs):
|
||||
# Regular expression to capture content inside \boxed{}
|
||||
matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
|
||||
contents = [match.group(1) if match else "" for match in matches]
|
||||
# Reward 1 if the content is the same as the ground truth, 0 otherwise
|
||||
return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
|
||||
```
|
||||
|
||||
You can test this function as follows:
|
||||
|
||||
```python
|
||||
>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
|
||||
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
|
||||
>>> ground_truth = ["2", "5"]
|
||||
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
|
||||
#### Passing the reward function to the trainer
|
||||
|
||||
To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:
|
||||
|
||||
```python
|
||||
from trl import GRPOTrainer
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
reward_funcs=reward_func,
|
||||
...,
|
||||
)
|
||||
```
|
||||
|
||||
If you have multiple reward functions, you can pass them as a list:
|
||||
|
||||
```python
|
||||
from trl import GRPOTrainer
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
reward_funcs=[reward_func1, reward_func2],
|
||||
...,
|
||||
)
|
||||
```
|
||||
and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config, as sketched below.
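For example, assuming the same setup as above, a hedged sketch of weighting two reward functions might look like this:

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    reward_weights=[0.8, 0.2],  # weight reward_func1 four times as heavily as reward_func2
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[reward_func1, reward_func2],
    args=training_args,
    train_dataset=dataset,  # assumes a dataset loaded as in the quick start above
)
```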
|
||||
|
||||
Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details.
|
||||
|
||||
## GRPOTrainer
|
||||
|
||||
[[autodoc]] GRPOTrainer
|
||||
|
||||
## GRPOConfig
|
||||
|
||||
[[autodoc]] GRPOConfig
|
@ -1,65 +0,0 @@
|
||||
# Training FAQ
|
||||
|
||||
## What Metrics Should I Look at?
|
||||
|
||||
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
|
||||
|
||||
To address this, we recommend focusing on two key metrics first:
|
||||
|
||||
**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
|
||||
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
|
||||
|
||||
However, there are more metrics that can be useful for debugging; check out the [logging section](logging).
|
||||
|
||||
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
|
||||
|
||||
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
|
||||
|
||||
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kl-example.png">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
|
||||
</div>
|
||||
|
||||
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
|
||||
|
||||
## What Is the Concern with Negative KL Divergence?
|
||||
|
||||
If you generate text by purely sampling from the model distribution, things generally work fine. However, when you use the `generate` method there are a few caveats: depending on the settings it does not always purely sample, which can cause the KL divergence to go negative. Essentially, when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL divergence. This can happen in several cases:
|
||||
|
||||
- **top-k sampling**: the model can smooth out the probability distribution, causing the top-k tokens to have a smaller probability than under the reference model while still being selected
|
||||
- **min_length**: this suppresses the EOS token until `min_length` is reached, so the model can assign a very low log-probability to the EOS token and very high probabilities to all other tokens until `min_length` is reached
|
||||
|
||||
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed as `R = r - beta * KL`, so if the model learns how to drive the KL divergence negative, it effectively receives a positive reward. In many cases it is much easier to exploit such a bug in the generation than to actually learn the reward function. In addition, the KL term can become arbitrarily negative, so the actual reward can end up very small compared to it.
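A toy illustration of the problem (made-up numbers):

```python
beta = 0.1        # KL penalty coefficient
r = 0.05          # reward from the reward model
kl = -0.5         # negative KL obtained by exploiting the generation settings
R = r - beta * kl
print(R)          # 0.1 -- the "penalty" term now adds reward instead of subtracting it
```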
|
||||
|
||||
So how should you generate text for PPO training? Let's have a look!
|
||||
|
||||
## How to generate text for training?
|
||||
|
||||
In order to avoid the KL issues described above, we recommend using the following settings:
|
||||
|
||||
```python
|
||||
generation_kwargs = {
|
||||
"min_length": -1, # don't ignore the EOS token (see above)
|
||||
"top_k": 0.0, # no top-k sampling
|
||||
"top_p": 1.0, # no nucleus sampling
|
||||
"do_sample": True, # yes, we want to sample
|
||||
"pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
|
||||
"max_new_tokens": 32, # specify how many tokens you want to generate at most
|
||||
}
|
||||
```
|
||||
|
||||
With these settings we usually don't encounter any issues. You can also experiment with other settings, but if you run into problems with negative KL divergence, try going back to these and see if the problems persist.
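With the (legacy) PPO API these kwargs are typically forwarded to the trainer's generation helper; treat the exact call below as a sketch and check it against the signature in your TRL version:

```python
# query_tensor: a tokenized prompt; ppo_trainer and tokenizer as set up earlier in your script.
response_tensor = ppo_trainer.generate(query_tensor, **generation_kwargs)
print(tokenizer.decode(response_tensor[0]))
```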
|
||||
|
||||
## How can you debug your own use-case?
|
||||
|
||||
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
|
||||
|
||||
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once makes it difficult to identify where a potential problem comes from. For example, you can start by replacing the model in the example and, once you have figured out the best hyperparameters, switch to your dataset and reward model.
|
||||
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
|
||||
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
|
||||
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut off generations too soon. These things are very hard to see in the metrics but very obvious if you look at the generations.
|
||||
- **Inspect the reward model**: If your reward is not improving over time, there may be an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check whether simple positive and negative examples really get different rewards. You can also look at the distribution of your dataset. Finally, the reward might be dominated by the query, which the model can't affect, so you may need to normalize it (e.g. reward of query+response minus reward of the query).
|
||||
|
||||
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
|
@ -1,74 +0,0 @@
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png">
|
||||
</div>
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
TRL is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step.
|
||||
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
|
||||
|
||||
## Learn
|
||||
|
||||
Learn post-training with TRL and other libraries in 🤗 [smol course](https://github.com/huggingface/smol-course).
|
||||
|
||||
## API documentation
|
||||
|
||||
- [Model Classes](models): *A brief overview of what each public model class does.*
|
||||
- [`SFTTrainer`](sft_trainer): *Supervised fine-tuning of your model made easy with `SFTTrainer`*
|
||||
- [`RewardTrainer`](reward_trainer): *Easily train your reward model using `RewardTrainer`.*
|
||||
- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using the PPO algorithm*
|
||||
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
|
||||
- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.*
|
||||
- [`TextEnvironment`](text_environments): *Text environment to train your model using tools with RL.*
|
||||
|
||||
## Examples
|
||||
|
||||
- [Sentiment Tuning](sentiment_tuning): *Fine-tune your model to generate positive movie content*
|
||||
- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
|
||||
- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
|
||||
- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset*
|
||||
- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
|
||||
- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*
|
||||
|
||||
|
||||
## Blog posts
|
||||
|
||||
<div class="mt-10">
|
||||
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on July 10, 2024</p>
|
||||
<p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/putting_rl_back_in_rlhf_with_rloo">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/putting_rl_back_in_rlhf_with_rloo/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on June 12, 2024</p>
|
||||
<p class="text-gray-700">Putting RL back in RLHF</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/166_trl_ddpo/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on September 29, 2023</p>
|
||||
<p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/157_dpo_trl/dpo_thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on August 8, 2023</p>
|
||||
<p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/138_stackllama/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on April 5, 2023</p>
|
||||
<p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/133_trl_peft/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on March 9, 2023</p>
|
||||
<p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on December 9, 2022</p>
|
||||
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
@ -1,39 +0,0 @@
|
||||
# Installation
|
||||
You can install TRL either from PyPI or from source:
|
||||
|
||||
## PyPI
|
||||
Install the library with pip or [uv](https://docs.astral.sh/uv/):
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="uv">
|
||||
|
||||
uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions.
|
||||
|
||||
```bash
|
||||
uv pip install trl
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="pip">
|
||||
|
||||
```bash
|
||||
pip install trl
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Source
|
||||
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
If you want the development install you can replace the pip install with the following:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
```
|
@ -1,57 +0,0 @@
|
||||
# Iterative Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=iterative-sft,trl)
|
||||
|
||||
|
||||
Iterative fine-tuning is a training method that enables you to perform custom actions (for example, generation and filtering) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
|
||||
|
||||
## Usage
|
||||
|
||||
To get started quickly, instantiate a model and a tokenizer.
|
||||
|
||||
```python
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import IterativeSFTTrainer

# model_name: path or Hub id of the checkpoint you want to fine-tune
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
trainer = IterativeSFTTrainer(
|
||||
model,
|
||||
tokenizer
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
You have the choice to either provide a list of strings or a list of tensors to the step function.
|
||||
|
||||
#### Using a list of tensors as input:
|
||||
|
||||
```python
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
|
||||
```
|
||||
|
||||
#### Using a list of strings as input:
|
||||
|
||||
```python
|
||||
|
||||
inputs = {
|
||||
"texts": texts
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
|
||||
```
|
||||
|
||||
For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence-to-sequence models, you will have to provide your own `labels` or `text_labels`.
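A hedged sketch of the encoder-decoder case (tensor names are placeholders; check the keyword names against the `step` signature of your TRL version):

```python
# For a seq2seq model the decoder targets must be passed explicitly.
inputs = {
    "input_ids": input_ids,            # encoder inputs
    "attention_mask": attention_mask,
    "labels": labels,                  # tokenized target sequences
}

trainer.step(**inputs)
```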
|
||||
|
||||
## IterativeTrainer
|
||||
|
||||
[[autodoc]] IterativeSFTTrainer
|
@ -1,89 +0,0 @@
|
||||
# Judges
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
TRL Judges is an experimental API which is subject to change at any time.
|
||||
|
||||
</Tip>
|
||||
|
||||
TRL provides judges to easily compare two completions.
|
||||
|
||||
Make sure to have installed the required dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install trl[judges]
|
||||
```
|
||||
|
||||
## Using the provided judges
|
||||
|
||||
TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
|
||||
|
||||
```python
|
||||
from trl import HfPairwiseJudge
|
||||
|
||||
judge = HfPairwiseJudge()
|
||||
judge.judge(
|
||||
prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
|
||||
completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]],
|
||||
) # Outputs: [0, 1]
|
||||
```
|
||||
|
||||
## Define your own judge
|
||||
|
||||
To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`BaseRankJudge`] and implement the [`BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`BasePairwiseJudge`] and implement the [`BasePairwiseJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`BaseJudge`] and implement the [`BaseJudge.judge`] method.
|
||||
|
||||
As an example, let's define a pairwise judge that prefers shorter completions:
|
||||
|
||||
```python
|
||||
from trl import BasePairwiseJudge
|
||||
|
||||
class PrefersShorterJudge(BasePairwiseJudge):
|
||||
def judge(self, prompts, completions, shuffle_order=False):
|
||||
return [0 if len(completion[0]) <= len(completion[1]) else 1 for completion in completions]  # prefer the shorter completion
|
||||
```
|
||||
|
||||
You can then use this judge as follows:
|
||||
|
||||
```python
|
||||
judge = PrefersShorterJudge()
|
||||
judge.judge(
|
||||
prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
|
||||
completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]],
|
||||
) # Outputs: [0, 1]
|
||||
```
|
||||
|
||||
## Provided judges
|
||||
|
||||
### PairRMJudge
|
||||
|
||||
[[autodoc]] PairRMJudge
|
||||
|
||||
### HfPairwiseJudge
|
||||
|
||||
[[autodoc]] HfPairwiseJudge
|
||||
|
||||
### OpenAIPairwiseJudge
|
||||
|
||||
[[autodoc]] OpenAIPairwiseJudge
|
||||
|
||||
### AllTrueJudge
|
||||
|
||||
[[autodoc]] AllTrueJudge
|
||||
|
||||
## Base classes
|
||||
|
||||
### BaseJudge
|
||||
|
||||
[[autodoc]] BaseJudge
|
||||
|
||||
### BaseBinaryJudge
|
||||
|
||||
[[autodoc]] BaseBinaryJudge
|
||||
|
||||
### BaseRankJudge
|
||||
|
||||
[[autodoc]] BaseRankJudge
|
||||
|
||||
### BasePairwiseJudge
|
||||
|
||||
[[autodoc]] BasePairwiseJudge
|
@ -1,139 +0,0 @@
|
||||
# KTO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=kto,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela).
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive.
|
||||
|
||||
The official code can be found in [ContextualAI/HALOs](https://github.com/ContextualAI/HALOs).
|
||||
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Younes Belkada](https://huggingface.co/ybelkada), [Lewis Tunstall](https://huggingface.co/lewtun) and Pablo Vicente.
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the KTO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [KTO Mix 14k](https://huggingface.co/datasets/trl-lib/kto-mix-14k). You can view the data in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/kto-mix-14k/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_kto.py
|
||||
from datasets import load_dataset
|
||||
from trl import KTOConfig, KTOTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")
|
||||
|
||||
training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO", logging_steps=10)
|
||||
trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_kto.py
|
||||
```
|
||||
|
||||
Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-KTO>:</span></strong>
|
||||
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
|
||||
|
||||
Here are some other factors to consider when choosing a programming language for a project:
|
||||
|
||||
<strong><span style="color: green;">1</span> JavaScript</strong>: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.
|
||||
<strong><span style="color: green;">2</span> Java</strong>: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.
|
||||
<strong><span style="color: green;">3</span> C++</strong>: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.
|
||||
<strong><span style="color: green;">4</span> Python</strong>: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones.
|
||||
|
||||
The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate.
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py)
|
||||
|
||||
To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [KTO Mix 14k dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch trl/scripts/kto.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/kto-mix-14k \
|
||||
--num_train_epochs 1 \
|
||||
--logging_steps 25 \
|
||||
--output_dir Qwen2-0.5B-KTO
|
||||
```
|
||||
|
||||
## Usage tips
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
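A minimal sketch of enabling this (the model name and values are illustrative):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",   # any MoE checkpoint whose config supports these options
    output_router_logits=True,        # include the load-balancing auxiliary loss in the total loss
    router_aux_loss_coef=0.001,       # default weighting; raise it to balance experts more aggressively
)
```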
|
||||
|
||||
|
||||
### Batch size recommendations
|
||||
|
||||
Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor.
|
||||
|
||||
### Learning rate recommendations
|
||||
|
||||
Each choice of `beta` has a maximum learning rate it can tolerate before learning performance degrades. For the default setting of `beta = 0.1`, the learning rate should typically not exceed `1e-6` for most models. As `beta` decreases, the learning rate should also be reduced accordingly. In general, we strongly recommend keeping the learning rate between `5e-7` and `5e-6`. Even with small datasets, we advise against using a learning rate outside this range. Instead, opt for more epochs to achieve better results.
|
||||
|
||||
### Imbalanced data
|
||||
|
||||
The `desirable_weight` and `undesirable_weight` of the [`KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
|
||||
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3.
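For example, with 10,000 desirable and 4,000 undesirable examples, keeping `desirable_weight=1.0` and choosing `undesirable_weight` between roughly 1.875 and 2.5 keeps the weighted ratio inside the recommended 1:1 to 4:3 band (a sketch, not a prescription):

```python
from trl import KTOConfig

training_args = KTOConfig(
    output_dir="Qwen2-0.5B-KTO",
    desirable_weight=1.0,
    undesirable_weight=2.2,  # 1.0 * 10_000 : 2.2 * 4_000 ≈ 1.14 : 1, inside the 1:1 to 4:3 range
    learning_rate=5e-7,      # within the 5e-7 to 5e-6 range recommended above
)
```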
|
||||
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
- `logps/chosen`: the mean log probabilities of the chosen completions
|
||||
- `logps/rejected`: the mean log probabilities of the rejected completions
|
||||
- `logits/chosen`: the mean logits of the chosen completions
|
||||
- `logits/rejected`: the mean logits of the rejected completions
|
||||
- `kl`: the KL divergence between the policy model and the reference model
|
||||
|
||||
## KTOTrainer
|
||||
|
||||
[[autodoc]] KTOTrainer
|
||||
|
||||
## KTOConfig
|
||||
|
||||
[[autodoc]] KTOConfig
|
@ -1,233 +0,0 @@
|
||||
# Learning Tools (Experimental 🧪)
|
||||
|
||||
Using Large Language Models (LLMs) with tools has been a popular topic recently, with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach an LLM to use tools with reinforcement learning.
|
||||
|
||||
|
||||
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
|
||||
|
||||
| File | Description |
|
||||
|---|---|
|
||||
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
|
||||
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
|
||||
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
|
||||
</Tip>
|
||||
|
||||
|
||||
## Learning to Use a Calculator
|
||||
|
||||
|
||||
The rough idea is as follows:
|
||||
|
||||
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parses a text calculation like `"14 + 34"` and returns the calculated number:
|
||||
```python
|
||||
from transformers import AutoTokenizer, load_tool
|
||||
tool = load_tool("ybelkada/simple-calculator")
|
||||
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
|
||||
```
|
||||
2. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
|
||||
3. Create a prompt on how to use the tools
|
||||
```python
|
||||
# system prompt
|
||||
prompt = """\
|
||||
What is 13.1-3?
|
||||
|
||||
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
|
||||
|
||||
Result=10.1<submit>
|
||||
|
||||
What is 4*3?
|
||||
|
||||
<request><SimpleCalculatorTool>4*3<call>12<response>
|
||||
|
||||
Result=12<submit>
|
||||
|
||||
What is 12.1+1?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
|
||||
|
||||
Result=13.1<submit>
|
||||
|
||||
What is 12.1-20?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
|
||||
|
||||
Result=-7.9<submit>"""
|
||||
```
|
||||
4. Create a `trl.TextEnvironment` with the model
|
||||
```python
|
||||
env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
{"SimpleCalculatorTool": tool_fn},
|
||||
reward_fn,
|
||||
prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
```
|
||||
5. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.
|
||||

|
||||
6. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`.
|
||||
|
||||
## Experiment results
|
||||
|
||||
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
|
||||
|
||||
```
|
||||
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
|
||||
--command "python examples/research_projects/tools/calculator.py" \
|
||||
--num-seeds 10 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 8 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
```
|
||||
|
||||
We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
|
||||
```
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
'wandb?tag=calculator_final&cl=calculator_mask' \
|
||||
--env-ids trl \
|
||||
--check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename static/0compare \
|
||||
--scan-history
|
||||
```
|
||||
|
||||

|
||||
|
||||
As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.
|
||||
|
||||
|
||||
## (Early Experiments 🧪): learning to use a wiki tool for question answering
|
||||
|
||||
The [ToolFormer](https://huggingface.co/papers/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
|
||||
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
**Note that many settings are different so the results are not directly comparable.**
|
||||
</Tip>
|
||||
|
||||
|
||||
|
||||
|
||||
### Building a search index
|
||||
|
||||
Since [ToolFormer](https://huggingface.co/papers/2302.04761) did not open-source its code, we first needed to replicate the search index. The paper mentions that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).
|
||||
|
||||
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
|
||||
|
||||
```python
|
||||
from pyserini.search.lucene import LuceneSearcher
|
||||
import json
|
||||
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
|
||||
def search(query):
|
||||
hits = searcher.search(query, k=1)
|
||||
hit = hits[0]
|
||||
contents = json.loads(hit.raw)['contents']
|
||||
return contents
|
||||
print(search("tennis racket"))
|
||||
```
|
||||
```
|
||||
Racket (sports equipment)
|
||||
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
|
||||
|
||||
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
|
||||
...
|
||||
```
|
||||
|
||||
We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.
|
||||
|
||||

|
||||
|
||||
### Experiment settings
|
||||
|
||||
We use the following settings:
|
||||
|
||||
* use the `bigcode/starcoderbase` model as the base model
|
||||
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only use the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
|
||||
* test whether the response contains the answer string; if so, give a reward of 1, otherwise give a reward of 0.
|
||||
* note that this is a simplified evaluation criterion. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors check whether the first 20 words of the response contain the correct answer.
|
||||
* use the following prompt that demonstrates the usage of the wiki tool.
|
||||
```python
|
||||
prompt = """\
|
||||
Answer the following question:
|
||||
|
||||
Q: In which branch of the arts is Patricia Neary famous?
|
||||
A: Ballets
|
||||
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
|
||||
Result=Ballets<submit>
|
||||
|
||||
Q: Who won Super Bowl XX?
|
||||
A: Chicago Bears
|
||||
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
|
||||
Result=Chicago Bears<submit>
|
||||
|
||||
Q: """
|
||||
```
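For illustration, here is a minimal sketch of the exact-match reward described above; the function name and the batched list-of-strings interface are assumptions for the example, not necessarily the exact signature used in the experiment:

```python
# minimal sketch of the reward described above: 1 if the generated text
# contains the gold answer string, 0 otherwise (names/signature are illustrative)
def exact_match_reward(responses: list[str], answers: list[str]) -> list[float]:
    rewards = []
    for response, answer in zip(responses, answers):
        rewards.append(1.0 if answer.lower() in response.lower() else 0.0)
    return rewards

print(exact_match_reward(["Result=Chicago Bears<submit>"], ["Chicago Bears"]))  # [1.0]
```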
|
||||
|
||||
|
||||
### Result and Discussion
|
||||
|
||||
|
||||
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly trend upward, but one of the experiments crashed.
|
||||
|
||||

|
||||
|
||||
The Weights & Biases report is available [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
|
||||
|
||||
|
||||
Note that the correct-answer rate of the trained model is on the low end, which could be due to the following reasons:
|
||||
|
||||
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"`, if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search should return "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles."
|
||||
|
||||
|
||||

|
||||
|
||||
* **unnecessarily long responses**: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act":
|
||||
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
|
||||
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
|
||||
|
||||

|
||||
|
||||
|
||||
## (Early Experiments 🧪): solving math puzzles with python interpreter
|
||||
|
||||
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
|
||||
|
||||
```python
|
||||
prompt = """\
|
||||
Example of using a Python API to solve math questions.
|
||||
|
||||
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
|
||||
|
||||
<request><PythonInterpreter>
|
||||
def solution():
|
||||
money_initial = 23
|
||||
bagels = 5
|
||||
bagel_cost = 3
|
||||
money_spent = bagels * bagel_cost
|
||||
money_left = money_initial - money_spent
|
||||
result = money_left
|
||||
return result
|
||||
print(solution())
|
||||
<call>8<response>
|
||||
|
||||
Result = 8 <submit>
|
||||
|
||||
Q: """
|
||||
```
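To make the `<PythonInterpreter>` step concrete, below is a minimal, hedged sketch of what such a tool could look like: it executes the submitted snippet and captures whatever it prints. This is an illustrative stand-in, not the exact tool used in the experiment, and it runs untrusted code without sandboxing:

```python
import contextlib
import io

# illustrative only: execute a code snippet and return its stdout as the tool response
def python_interpreter_tool(code: str) -> str:
    buffer = io.StringIO()
    try:
        with contextlib.redirect_stdout(buffer):
            exec(code, {})  # no sandboxing here; a real tool should isolate execution
    except Exception as error:
        return f"Error: {error}"
    return buffer.getvalue().strip()

print(python_interpreter_tool("print(23 - 5 * 3)"))  # -> "8"
```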
|
||||
|
||||
|
||||
The training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
|
||||
|
||||

|
@ -1,7 +0,0 @@
|
||||
# Liger Kernel Integration
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
|
||||
</Tip>
|
@ -1,74 +0,0 @@
|
||||
# Logging
|
||||
|
||||
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
|
||||
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to wandb or tensorboard.
|
||||
|
||||
Upon initialization, pass one of these two options to the [`PPOConfig`]:
|
||||
|
||||
```
|
||||
training_args = PPOConfig(..., report_to="wandb") # or "tensorboard"
|
||||
```
|
||||
|
||||
If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
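For example, a minimal sketch with the TensorBoard backend might look as follows (the `output_dir` value is a placeholder, and on older TRL versions the `project_kwargs` route described above applies instead):

```python
from trl import PPOConfig

# minimal sketch: send PPO logs to TensorBoard instead of Weights & Biases
training_args = PPOConfig(output_dir="ppo-tensorboard-demo", report_to="tensorboard")
```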
|
||||
|
||||
## PPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data:
|
||||
|
||||
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
|
||||
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model.
|
||||
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias `ppo/std_scores`, which is used to specifically monitor the reward model.
|
||||
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
|
||||
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
|
||||
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
|
||||
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
|
||||
1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
|
||||
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
|
||||
|
||||
Training stats:
|
||||
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
|
||||
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
|
||||
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
|
||||
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
|
||||
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
|
||||
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
|
||||
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
|
||||
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
|
||||
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
|
||||
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
|
||||
1. `ppo/val/vpred`: The predicted values from the value function.
|
||||
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
|
||||
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
|
||||
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
|
||||
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
|
||||
|
||||
|
||||
Stats on queries, responses, and logprobs:
|
||||
1. `tokens/queries_len_mean`: The average length of the queries tokens.
|
||||
1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
|
||||
1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
|
||||
1. `tokens/responses_len_mean`: The average length of the responses tokens.
|
||||
1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
|
||||
1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
|
||||
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
|
||||
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
|
||||
|
||||
|
||||
|
||||
### Crucial values
|
||||
During training, many values are logged, here are the most important ones:
|
||||
|
||||
1. `env/reward_mean`, `env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
|
||||
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
|
||||
|
||||
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
|
||||
|
||||
1. `ppo/loss/value`: it will spike or go to NaN when training is not going well.
|
||||
1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
|
||||
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
|
||||
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
|
||||
1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
|
@ -1,28 +0,0 @@
|
||||
# Models
|
||||
|
||||
With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder model architectures in transformers such as GPT-2, OPT, and GPT-Neo. In addition, with `AutoModelForSeq2SeqLMWithValueHead` you can use encoder-decoder architectures such as T5. TRL also requires reference models which are frozen copies of the model that is trained. With `create_reference_model` you can easily create a frozen copy and also share layers between the two models to save memory.
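As a small sketch of the workflow described above (the model name and the number of shared layers are arbitrary choices for illustration):

```python
from trl import AutoModelForCausalLMWithValueHead, create_reference_model

# wrap a small causal LM with a value head, then create a frozen reference copy
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = create_reference_model(model, num_shared_layers=6)  # share the first 6 layers to save memory
```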
|
||||
|
||||
## PreTrainedModelWrapper
|
||||
|
||||
[[autodoc]] PreTrainedModelWrapper
|
||||
|
||||
## AutoModelForCausalLMWithValueHead
|
||||
|
||||
|
||||
[[autodoc]] AutoModelForCausalLMWithValueHead
|
||||
- __init__
|
||||
- forward
|
||||
- generate
|
||||
- _init_weights
|
||||
|
||||
## AutoModelForSeq2SeqLMWithValueHead
|
||||
|
||||
[[autodoc]] AutoModelForSeq2SeqLMWithValueHead
|
||||
- __init__
|
||||
- forward
|
||||
- generate
|
||||
- _init_weights
|
||||
|
||||
## create_reference_model
|
||||
|
||||
[[autodoc]] create_reference_model
|
@ -1,100 +0,0 @@
|
||||
# Multi Adapter RL (MARL) - a single base model for everything
|
||||
|
||||
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they potentially face issues.
|
||||
|
||||
## Requirements
|
||||
|
||||
You just need to install `peft`, and optionally `bitsandbytes` if you want to use 8-bit base models for more memory-efficient fine-tuning.
|
||||
|
||||
## Summary
|
||||
|
||||
This approach consists of three stages, summarized as follows:
|
||||
|
||||
1- Train a base model on the target domain (e.g. [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb)) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL.
|
||||
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py)
|
||||
3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
|
||||
|
||||
Make sure to use the same model (i.e. same architecture and same weights) for stages 2 & 3.
|
||||
|
||||
## Quickstart
|
||||
|
||||
Let us assume you have trained your reward adapter on the `llama-7b` model using `RewardTrainer` and pushed the weights to the Hub under `trl-lib/llama-7b-hh-rm-adapter`.
|
||||
When doing PPO, before passing the model to `PPOTrainer`, create your model as follows:
|
||||
|
||||
```python
|
||||
model_name = "huggyllama/llama-7b"
|
||||
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
|
||||
|
||||
# PPO adapter
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_name,
|
||||
peft_config=lora_config,
|
||||
reward_adapter=rm_adapter_id,
|
||||
)
|
||||
|
||||
...
|
||||
trainer = PPOTrainer(
|
||||
model=model,
|
||||
...
|
||||
)
|
||||
|
||||
...
|
||||
```
|
||||
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
|
||||
|
||||
```python
|
||||
rewards = trainer.model.compute_reward_score(**inputs)
|
||||
```
|
||||
|
||||
## Advanced usage
|
||||
|
||||
### Control on the adapter name
|
||||
|
||||
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
|
||||
In this case, you want to be able to control which adapter to activate again after retrieving the reward. To do so, simply pass the appropriate adapter name to the `ppo_adapter_name` argument when calling `compute_reward_score`.
|
||||
|
||||
```python
|
||||
adapter_name_policy_1 = "policy_1"
|
||||
rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1)
|
||||
...
|
||||
```
|
||||
|
||||
### Using 4-bit and 8-bit base models
|
||||
|
||||
For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32).
|
||||
Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`):
|
||||
```python
|
||||
model_name = "llama-7b"
|
||||
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
|
||||
|
||||
# PPO adapter
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_name,
|
||||
peft_config=lora_config,
|
||||
reward_adapter=rm_adapter_id,
|
||||
load_in_8bit=True,
|
||||
)
|
||||
|
||||
...
|
||||
trainer = PPOTrainer(
|
||||
model=model,
|
||||
...
|
||||
)
|
||||
...
|
||||
```
|
@ -1,159 +0,0 @@
|
||||
# Nash-MD Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=nash-md,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences.
|
||||
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec).
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the Nash-MD method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_nash_md.py
|
||||
from datasets import load_dataset
|
||||
from trl import NashMDConfig, NashMDTrainer, PairRMJudge
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
judge = PairRMJudge()
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
|
||||
|
||||
training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD", logging_steps=10)
|
||||
trainer = NashMDTrainer(
|
||||
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_nash_md.py
|
||||
```
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 3 hours.
|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-NashMD>:</span></strong>
|
||||
The best programming language depends on personal preference, the complexity of the project, and the specific requirements of the task. Some programming languages that are often recommended include Python, Java, and JavaScript, and there are many other languages to choose from depending on individual needs.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
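For reference, a prompt-only example can be as simple as the following (the texts are made up; see the dataset formats guide linked above for the full specification):

```python
# standard (plain-text) prompt-only example
standard_example = {"prompt": "What is the capital of France?"}

# conversational prompt-only example; the chat template is applied automatically
conversational_example = {"prompt": [{"role": "user", "content": "What is the capital of France?"}]}
```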
|
||||
|
||||
## Usage tips
|
||||
|
||||
### Use a reward model
|
||||
|
||||
Instead of a judge, you can choose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
|
||||
|
||||
```diff
|
||||
- from trl import PairRMJudge
|
||||
+ from transformers import AutoModelForSequenceClassification
|
||||
|
||||
- judge = PairRMJudge()
|
||||
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
|
||||
|
||||
trainer = NashMDTrainer(
|
||||
...
|
||||
- judge=judge,
|
||||
+ reward_model=reward_model,
|
||||
)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Encourage EOS token generation
|
||||
|
||||
We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`NashMDConfig`]:
|
||||
|
||||
```python
|
||||
training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
|
||||
```
|
||||
|
||||
### Logging Completions
|
||||
|
||||
To better understand your model’s behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`].
|
||||
|
||||
```python
|
||||
trainer = NashMDTrainer(..., eval_dataset=eval_dataset)
|
||||
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
|
||||
trainer.add_callback(completions_callback)
|
||||
```
|
||||
|
||||
This callback logs the model's generated completions directly to Weights & Biases.
|
||||
|
||||

|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the Nash-MD method. The script is available in [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py)
|
||||
|
||||
To test the Nash-MD script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
|
||||
|
||||
```bash
|
||||
python examples/scripts/nash_md.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--judge pair_rm \
|
||||
--dataset_name trl-lib/ultrafeedback-prompt \
|
||||
--learning_rate 5.0e-7 \
|
||||
--logging_steps 25 \
|
||||
--output_dir Qwen2.5-0.5B-NashMD-PairRM \
|
||||
--warmup_ratio 0.1 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Logged metrics
|
||||
|
||||
The logged metrics are as follows:
|
||||
|
||||
* `loss/kl`: The mean KL divergence between the model and reference data.
|
||||
* `objective/entropy`: The mean entropy of the model and reference data.
|
||||
* `loss/score`: The mean reinforce score loss.
|
||||
* `rewards/chosen`: The mean scores (according to the reward model) of the model completions.
|
||||
* `rewards/rejected`: The mean scores (according to the reward model) of the mixture completions.
|
||||
* `rewards/probabilities`: The mean probability (according to the reward model or judge) of the model completions chosen vs the mixture completion.
|
||||
* `rewards/accuracies`: The accuracies of the Nash-MD's implicit reward model.
|
||||
* `rewards/margins`: The mean reward margin (according to reward model) between the chosen and mixture completions.
|
||||
* `logps/chosen`: The mean log probabilities of the chosen completions.
|
||||
* `logps/rejected`: The mean log probabilities of the reference completions.
|
||||
* `val/model_contain_eos_token`: The number of times the model's output contains the eos token.
|
||||
* `val/ref_contain_eos_token`: The number of times the mixture's output contains the eos token.
|
||||
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].
|
||||
* `mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].
|
||||
|
||||
## NashMDTrainer
|
||||
|
||||
[[autodoc]] NashMDTrainer
|
||||
|
||||
## NashMDConfig
|
||||
|
||||
[[autodoc]] NashMDConfig
|
@ -1,278 +0,0 @@
|
||||
# Online DPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=online-dpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator.
|
||||
|
||||
This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching).
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the online DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model and [`PairRMJudge`] as a judge. We use the prompts from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the prompts in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_online_dpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
judge = PairRMJudge()
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
|
||||
|
||||
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", logging_steps=10)
|
||||
trainer = OnlineDPOTrainer(
|
||||
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_online_dpo.py
|
||||
```
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in both the reward for rejected and chosen completions indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-OnlineDPO>:</span></strong>
|
||||
The best programming language depends on your specific needs and priorities. Some people prefer imperative programming languages (like Haskell or Lisp), while others prefer functional programming languages (like Scala or Python). It's important to consider your work style, programming environment, and project requirements when choosing a programming language.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
## Usage tips
|
||||
|
||||
### Use a reward model
|
||||
|
||||
Instead of a judge, you can choose to use a reward model -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use. Below is a code example showing how to replace a judge with the [trl-lib/Qwen2-0.5B-Reward](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward) model:
|
||||
|
||||
```diff
|
||||
- from trl import PairRMJudge
|
||||
+ from transformers import AutoModelForSequenceClassification
|
||||
|
||||
- judge = PairRMJudge()
|
||||
+ reward_model = AutoModelForSequenceClassification.from_pretrained("trl-lib/Qwen2-0.5B-Reward", num_labels=1)
|
||||
+ reward_tokenizer = AutoTokenizer.from_pretrained("trl-lib/Qwen2-0.5B-Reward")
|
||||
|
||||
trainer = OnlineDPOTrainer(
|
||||
...
|
||||
- judge=judge,
|
||||
+ reward_model=reward_model,
|
||||
+ reward_processing_class=reward_tokenizer,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
### Encourage EOS token generation
|
||||
|
||||
When using a reward model, we may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`OnlineDPOConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`OnlineDPOConfig`]:
|
||||
|
||||
```python
|
||||
training_args = OnlineDPOConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
|
||||
```
|
||||
|
||||
### Logging Completions
|
||||
|
||||
To better understand your model’s behavior during training, you can log sample completions periodically using the [`LogCompletionsCallback`].
|
||||
|
||||
```python
|
||||
trainer = OnlineDPOTrainer(..., eval_dataset=eval_dataset)
|
||||
completions_callback = LogCompletionsCallback(trainer, num_prompts=8)
|
||||
trainer.add_callback(completions_callback)
|
||||
```
|
||||
|
||||
This callback logs the model's generated completions directly to Weights & Biases.
|
||||
|
||||

|
||||
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the online DPO method. The script is available in [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py)
|
||||
|
||||
To test the online DPO script with the [Qwen2.5 0.5B model](https://huggingface.co/trl-lib/Qwen/Qwen2.5-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback), run the following command:
|
||||
|
||||
```bash
|
||||
python examples/scripts/dpo_online.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--judge pair_rm \
|
||||
--dataset_name trl-lib/ultrafeedback-prompt \
|
||||
--learning_rate 5.0e-7 \
|
||||
--logging_steps 25 \
|
||||
--output_dir Qwen2.5-0.5B-Online-DPO-PairRM \
|
||||
--warmup_ratio 0.1 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Logged metrics
|
||||
|
||||
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9)
|
||||
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model.
|
||||
* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `scores - non_score_reward`. The `rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up.
|
||||
* `objective/scores`: The mean scores returned by the reward model.
|
||||
* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
|
||||
* `rewards/chosen`: The mean reward (according to online DPO's implicit reward model) of the chosen completions.
|
||||
* `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions.
|
||||
* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model.
|
||||
* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions.
|
||||
* `logps/chosen`: The mean log probabilities of the chosen completions.
|
||||
* `logps/rejected`: The mean log probabilities of the rejected completions.
|
||||
* `val/contain_eos_token`: The fraction of completions which contain an EOS token.
|
||||
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`OnlineDPOConfig`].
|
||||
|
||||
## Benchmark experiments
|
||||
|
||||
To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
|
||||
```
|
||||
# 1B Online DPO experiment
|
||||
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
|
||||
examples/scripts/dpo_online.py \
|
||||
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
|
||||
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--learning_rate 5.0e-7 \
|
||||
--output_dir pythia-1b-deduped-tldr-online-dpo \
|
||||
--beta 0.1 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 3 \
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--logging_steps 20 \
|
||||
--save_steps 0.1 \
|
||||
--push_to_hub
|
||||
|
||||
# 2.8B Online DPO experiment
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
|
||||
examples/scripts/dpo_online.py \
|
||||
--model_name_or_path trl-lib/pythia-2.8b-deduped-tldr-sft \
|
||||
--reward_model_path trl-lib/pythia-2.8b-deduped-tldr-rm \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--learning_rate 5.0e-7 \
|
||||
--output_dir pythia-2.8b-deduped-tldr-online-dpo \
|
||||
--beta 0.1 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 3 \
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--bf16 \
|
||||
--logging_steps 20 \
|
||||
--save_steps 0.1 \
|
||||
--push_to_hub
|
||||
|
||||
# 6.9B Online DPO experiment
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
|
||||
examples/scripts/dpo_online.py \
|
||||
--model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \
|
||||
--reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--learning_rate 5.0e-7 \
|
||||
--output_dir pythia-6.9b-deduped-tldr-online-dpo \
|
||||
--beta 0.1 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--num_train_epochs 3 \
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--bf16 \
|
||||
--gradient_checkpointing \
|
||||
--logging_steps 20 \
|
||||
--save_steps 0.1 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Checkpoints and experiment tracking are available at:
|
||||
|
||||
- [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
|
||||
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
|
||||
|
||||
|
||||
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
|
||||
For more information on how to use judges, see [Judges](judges).
|
||||
|
||||
```bash
|
||||
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 33.00%
|
||||
python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 41.50%
|
||||
python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 62.60%
|
||||
python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 74.20%
|
||||
```
|
||||
|
||||
We can then plot the RLHF scaling chart.
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
results = {
|
||||
"SFT": {1.0e9: 0.21, 2.8e9: 0.27, 6.9e9: 0.316},
|
||||
"online-dpo": {1.0e9: 0.542, 2.8e9: 0.746, 6.9e9: 0.796},
|
||||
"offline-dpo": {1.0e9: 0.422, 2.8e9: 0.517, 6.9e9: 0.701},
|
||||
}
|
||||
|
||||
|
||||
plt.plot(results["SFT"].keys(), results["SFT"].values(), label="SFT", marker="o")
|
||||
plt.plot(results["online-dpo"].keys(), results["online-dpo"].values(), label="Online-dpo with RM judge", marker="o")
|
||||
plt.plot(results["offline-dpo"].keys(), results["offline-dpo"].values(), label="Offline-dpo", marker="o")
|
||||
plt.axhline(y=0.5, color="black", linestyle="-.", label="Human reference summary")
|
||||
plt.xscale("log")
|
||||
plt.xlabel("Model size")
|
||||
plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)")
|
||||
plt.title("DPO scaling by model size")
|
||||
plt.legend()
|
||||
plt.xlim(5e8, 1.2e10)
|
||||
plt.xticks([1e9, 3e9, 1e10], ["1B", "3B", "10B"])
|
||||
plt.grid(True, which="both", ls="--", c="0.7")
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||

|
||||
|
||||
The online DPO checkpoints achieve increasingly higher win rates as we scale up the model size. This is a good sign that the online DPO implementation is working as intended.
|
||||
|
||||
## OnlineDPOTrainer
|
||||
|
||||
[[autodoc]] OnlineDPOTrainer
|
||||
|
||||
## OnlineDPOConfig
|
||||
|
||||
[[autodoc]] OnlineDPOConfig
|
@ -1,129 +0,0 @@
|
||||
# ORPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=orpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
|
||||
## Overview
|
||||
|
||||
Odds Ratio Preference Optimization (ORPO) was introduced in [ORPO: Monolithic Preference Optimization without Reference Model](https://huggingface.co/papers/2403.07691) by [Jiwoo Hong](https://huggingface.co/JW17), [Noah Lee](https://huggingface.co/nlee-208), and [James Thorne](https://huggingface.co/j6mes).
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> While recent preference alignment algorithms for language models have demonstrated promising results, supervised fine-tuning (SFT) remains imperative for achieving successful convergence. In this paper, we study the crucial role of SFT within the context of preference alignment, emphasizing that a minor penalty for the disfavored generation style is sufficient for preference-aligned SFT. Building on this foundation, we introduce a straightforward and innovative reference model-free monolithic odds ratio preference optimization algorithm, ORPO, eliminating the necessity for an additional preference alignment phase. We demonstrate, both empirically and theoretically, that the odds ratio is a sensible choice for contrasting favored and disfavored styles during SFT across the diverse sizes from 125M to 7B. Specifically, fine-tuning Phi-2 (2.7B), Llama-2 (7B), and Mistral (7B) with ORPO on the UltraFeedback alone surpasses the performance of state-of-the-art language models with more than 7B and 13B parameters: achieving up to 12.20% on AlpacaEval_{2.0} (Figure 1), 66.19% on IFEval (instruction-level loose, Table 6), and 7.32 in MT-Bench (Figure 12). We release code and model checkpoints for Mistral-ORPO-alpha (7B) and Mistral-ORPO-beta (7B).
|
||||
|
||||
It studies the crucial role of SFT within the context of preference alignment. Using preference data, the method posits that a minor penalty for the disfavored generation, together with a strong adaptation signal to the chosen response via a simple log-odds-ratio term appended to the NLL loss, is sufficient for preference-aligned SFT.
|
||||
|
||||
ORPO is thus a reference-model-free preference optimization algorithm that eliminates the need for an additional preference alignment phase, saving compute and memory.
|
||||
|
||||
The official code can be found in [xfactlab/orpo](https://github.com/xfactlab/orpo).
|
||||
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Lewis Tunstall](https://huggingface.co/lewtun) and [Alvaro Bartolome](https://huggingface.co/alvarobartt).
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the ORPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
Below is the script to train the model:
|
||||
|
||||
```python
|
||||
# train_orpo.py
|
||||
from datasets import load_dataset
|
||||
from trl import ORPOConfig, ORPOTrainer
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
||||
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
training_args = ORPOConfig(output_dir="Qwen2-0.5B-ORPO", logging_steps=10)
|
||||
trainer = ORPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Execute the script using the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_orpo.py
|
||||
```
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-ORPO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-ORPO>:</span></strong>
|
||||
It's challenging to determine the best programming language as no one language is perfect, as the complexity of a task and the type of project are significant factors. Some popular languages include Java, Python, JavaScript, and
|
||||
C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.
|
||||
|
||||
Here are some other factors to consider when choosing a programming language for a project:
|
||||
|
||||
<strong><span style="color: green;">• Language proficiency:</span></strong> A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
|
||||
<strong><span style="color: green;">• Ease of use:</span></strong> There are tools and libraries available to make programming more accessible, so developers should choose a language that can help them get started easier.
|
||||
<strong><span style="color: green;">• Code readability:</span></strong> A clear and concise codebase should be easy to read and understand, especially when working with large projects.
|
||||
<strong><span style="color: green;">• Tool and framework support:</span></strong> There are numerous libraries available for Python, Java, and JavaScript, along with tools like IDEs and static code analysis tools.
|
||||
<strong><span style="color: green;">• Accessibility:</span></strong> Some languages and tools have features that make them more accessible to developers with disabilities, such as support for screen readers.
|
||||
<strong><span style="color: green;">• Version control:</span></strong> As your projects grow and complexity increases, version control tools can be beneficial for tracking changes.
|
||||
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset type
|
||||
|
||||
ORPO requires a [preference dataset](dataset_formats#preference). The [`ORPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
Although the [`ORPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
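For reference, an explicit-prompt preference example looks roughly like this (the texts are made up; see the linked dataset formats guide for details):

```python
# standard preference example with an explicit prompt
example = {
    "prompt": "The sky is",
    "chosen": " blue.",
    "rejected": " green.",
}
```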
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the ORPO method. The script is available in [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py)
|
||||
|
||||
To test the ORPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch examples/scripts/orpo.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--num_train_epochs 1 \
|
||||
--logging_steps 25 \
|
||||
--output_dir Qwen2-0.5B-ORPO
|
||||
```
|
||||
|
||||
## Usage tips
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
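As a hedged sketch (the tiny model dimensions below are arbitrary, chosen only so the snippet runs quickly):

```python
from transformers import MixtralConfig, MixtralForCausalLM

# enable the load-balancing auxiliary loss for a tiny, illustrative MoE model
config = MixtralConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    num_local_experts=4,
    num_experts_per_tok=2,
    output_router_logits=True,    # include router logits so the auxiliary loss is computed
    router_aux_loss_coef=0.001,   # weight of the auxiliary loss in the total loss
)
model = MixtralForCausalLM(config)
```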
|
||||
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
- `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
|
||||
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
- `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
|
||||
- `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))` (see the sketch after this list)
|
||||
- `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
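A minimal sketch of how the two odds-ratio quantities above relate, using made-up per-sequence mean log-probabilities; this mirrors the definitions rather than TRL's exact implementation:

```python
import torch
import torch.nn.functional as F

# made-up per-sequence mean log-probabilities of chosen and rejected responses
logp_chosen = torch.tensor([-0.9, -1.2])
logp_rejected = torch.tensor([-1.6, -1.4])

def log_odds(logp):
    # log(p / (1 - p)), computed from log p
    return logp - torch.log1p(-torch.exp(logp))

log_odds_chosen = log_odds(logp_chosen) - log_odds(logp_rejected)
log_odds_ratio = F.logsigmoid(log_odds_chosen)   # corresponds to log(sigmoid(log_odds_chosen))
print(log_odds_chosen.mean().item(), log_odds_ratio.mean().item())
```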
|
||||
|
||||
## ORPOTrainer
|
||||
|
||||
[[autodoc]] ORPOTrainer
|
||||
|
||||
## ORPOConfig
|
||||
|
||||
[[autodoc]] ORPOConfig
|
@ -1,144 +0,0 @@
|
||||
# Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA)
|
||||
|
||||
The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory-efficient manner. Most of the PEFT methods supported in the `peft` library can be used, but note that some methods, such as prompt tuning, are not supported.
|
||||
For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685).
|
||||
|
||||
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
|
||||
|
||||
| File | Task | Description | Colab link |
|
||||
|---|---|---|---|
|
||||
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
|
||||
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
|
||||
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
|
||||
|
||||
## Installation
|
||||
Note: peft is in active development, so we install directly from its GitHub repository.
|
||||
Peft also relies on the latest version of transformers.
|
||||
|
||||
```bash
|
||||
pip install trl[peft]
|
||||
pip install bitsandbytes loralib
|
||||
pip install git+https://github.com/huggingface/transformers.git@main
|
||||
#optional: wandb
|
||||
pip install wandb
|
||||
```
|
||||
|
||||
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
|
||||
|
||||
## How to use it?
|
||||
|
||||
Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
model_id = "edbeeching/gpt-neo-125M-imdb"
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_id,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
```
|
||||
And if you want to load your model in 8bit precision:
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
load_in_8bit=True,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
```
|
||||
... or in 4bit precision:
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
peft_config=lora_config,
|
||||
load_in_4bit=True,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
## Launch scripts
|
||||
|
||||
The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
|
||||
|
||||
```bash
|
||||
accelerate config # will prompt you to define the training configuration
|
||||
accelerate launch examples/scripts/ppo.py --use_peft # launches training
|
||||
```

## Using `trl` + `peft` and Data Parallelism

You can scale up to as many GPUs as you want, as long as you are able to fit the training process on a single device. The only tweak you need to apply is to load the model as follows:

```python
from peft import LoraConfig
...

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    peft_config=lora_config,
)
```

And if you want to load your model in 8bit precision:

```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    peft_config=lora_config,
    load_in_8bit=True,
)
```

... or in 4bit precision:

```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    peft_config=lora_config,
    load_in_4bit=True,
)
```

Finally, make sure that the rewards are computed on the correct device as well; for that, you can use `ppo_trainer.model.current_device`.
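
As a minimal sketch (assuming `ppo_trainer` and a list of reward `scores` already exist in your script, as in the PPO examples above):

```python
import torch

# Build the reward tensors on the trainer's current device before calling `ppo_trainer.step`.
device = ppo_trainer.model.current_device
rewards = [torch.tensor(score, device=device) for score in scores]
```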

## Naive pipeline parallelism (NPP) for large models (>60B models)

The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B). This is a simple way to parallelize the model across multiple GPUs: the model and the adapters are loaded across several GPUs, and the activations and gradients are naively communicated between them. This supports `int8` models as well as other `dtype` models.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-npp.png">
</div>

### How to use NPP?

Simply load your model with a custom `device_map` argument in `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model.

Also make sure to have the `lm_head` module on the first GPU device, as it may throw an error if it is not on the first device. At the time of writing, you need to install the `main` branch of `accelerate`: `pip install git+https://github.com/huggingface/accelerate.git@main` and `peft`: `pip install git+https://github.com/huggingface/peft.git@main`.
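
As a rough sketch (not an official recipe), you can let `accelerate` infer a split and then inspect where each module landed; the model name here is only a placeholder, and the module names in the resulting map depend on the architecture:

```python
from trl import AutoModelForCausalLMWithValueHead

# Let accelerate infer a device map across the visible GPUs; for a hand-written map,
# pass a dict mapping module names to device indices instead.
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "edbeeching/gpt-neo-125M-imdb",  # placeholder; NPP targets much larger models
    device_map="auto",
    peft_config=lora_config,
)

# Check that `lm_head` ended up on the first device (device 0).
print(pretrained_model.pretrained_model.hf_device_map)
```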

### Launch scripts

Although the `trl` library is powered by `accelerate`, you should run your training script in a single process. Note that we do not support Data Parallelism together with NPP yet.

```bash
python PATH_TO_SCRIPT
```

## Fine-tuning Llama-2 model

You can easily fine-tune the Llama2 model using `SFTTrainer` and the official script! For example, to fine-tune Llama2-7B on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):

```bash
python trl/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
```
@ -1,239 +0,0 @@

# PPO Trainer

[](https://huggingface.co/models?other=ppo,trl)

TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347).

References:
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031)

## Get started

To quickly check that the trainer works, launch the following command, which trains a PPO model with a dummy reward model.

```bash
python examples/scripts/ppo/ppo.py \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --dataset_train_split descriptiveness \
    --learning_rate 3e-6 \
    --num_ppo_epochs 1 \
    --num_mini_batches 1 \
    --output_dir models/minimal/ppo \
    --per_device_train_batch_size 64 \
    --gradient_accumulation_steps 1 \
    --total_episodes 10000 \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --sft_model_path EleutherAI/pythia-1b-deduped \
    --reward_model_path EleutherAI/pythia-1b-deduped \
    --missing_eos_penalty 1.0
```

## Explanation of the logged metrics

The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35).

* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward` (see the short sketch after this list).
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to `policy/clipfrac_avg` but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: The current learning rate used by the optimizer.
* `episode`: The current global step or episode count in the training process.
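
As a small, purely illustrative sketch of how the two reward-related metrics relate (following the definitions above; the tensor values are made up):

```python
import torch

# Illustrative values only: per-token KL for a batch of 2 responses and their reward-model scores.
beta = 0.05
per_token_kl = torch.tensor([[0.1, 0.2, 0.3], [0.0, 0.4, 0.1]])
score = torch.tensor([1.5, 0.7])

non_score_reward = beta * per_token_kl.sum(1)   # -> objective/non_score_reward (mean over the batch)
rlhf_reward = score - non_score_reward          # -> objective/rlhf_reward (mean over the batch)
print(non_score_reward.mean().item(), rlhf_reward.mean().item())
```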

## Cookbook

* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high, like 2.0 or 1000.0, or too small, like 0.1, it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint: `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml` (a combined example command follows this list).
* Usage TIP: We recommend using the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
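
For instance, a sketch that combines the memory tips above with the "Get started" script (paths and flags come from the example above; the batch size and accumulation values are illustrative):

```bash
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/ppo/ppo.py \
    --dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
    --dataset_train_split descriptiveness \
    --output_dir models/minimal/ppo \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --sft_model_path EleutherAI/pythia-1b-deduped \
    --reward_model_path EleutherAI/pythia-1b-deduped \
    --missing_eos_penalty 1.0 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 8
```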

## What is my model doing exactly?

To help you understand what your model is doing, we periodically log some sample completions from the model. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), the logged completions look like the following, allowing you to see the model's responses at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.

![]()

In the logs, the sampled generations look like:

```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query ┃ model response ┃ score ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │
│ │ I don't know how to get rid of │ │
│ TITLE: How do you get someone │ those feelings. I'm │ │
│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │
│ │ │ │
│ POST: Hi, │ │ │
│ I'm 22, and I have been with my │ │ │
│ girlfriend for 5 years now. We │ │ │
│ recently moved together. We've │ │ │
│ always loved each other │ │ │
│ intensely. │ │ │
│ │ │ │
│ Problem, I recently started to │ │ │
│ have feelings for an other │ │ │
│ person (a friend). This person │ │ │
│ has had a boyfriend for now 3 │ │ │
│ years, and has absolutely no │ │ │
│ ideas. Those feelings were so │ │ │
│ strong, it was hard to hide │ │ │
│ them. After 2 months of me │ │ │
│ being distant and really sad, │ │ │
│ my girlfriend forced me to say │ │ │
│ what was bothering me. I'm not │ │ │
│ a good liar, and now she knows. │ │ │
│ │ │ │
│ We decided to give us a week │ │ │
│ alone, I went to my parents. │ │ │
│ │ │ │
│ Now, I'm completely lost. I │ │ │
│ keep on thinking about this │ │ │
│ person, and I hate that. I │ │ │
│ would like for those feelings │ │ │
│ to go away, to leave me alone. │ │ │
│ But I can't. │ │ │
│ │ │ │
│ What do I do? It's been 3 │ │ │
│ months now, and I'm just │ │ │
│ desperate. │ │ │
│ │ │ │
│ TL;DR: │ │ │
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │
│ │ TV. I blasted Gangnam Style on │ │
│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │
│ with a loud TV. │ up as high as it could │ │
│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │
│ POST: She was in her living │ │ │
│ room, watching TV. This was at │ │ │
│ about 8:30 in the morning, and │ │ │
│ she was exercising. She turned │ │ │
│ the TV up extra loud to hear it │ │ │
│ over her excercycle, and woke │ │ │
│ me up. I went in there asking │ │ │
│ for her to turn it down. She │ │ │
│ said she didn't have to; I │ │ │
│ explained that I always used │ │ │
│ headphones so she didn't have │ │ │
│ to deal with my noise and that │ │ │
│ she should give me a little │ │ │
│ more respect, given that I paid │ │ │
│ rent at the time. │ │ │
│ │ │ │
│ She disagreed. I went back to │ │ │
│ my room, rather pissed off at │ │ │
│ the lack of equality. I had no │ │ │
│ lock on my door; but I had a │ │ │
│ dresser right next to it, so I │ │ │
│ pulled one of the drawers out │ │ │
│ enough so that it caused the │ │ │
│ door to not be openable. Then, │ │ │
│ I turned my speakers up really │ │ │
│ loud and blasted Gangnam Style │ │ │
│ on repeat, with the bass │ │ │
│ cranked up as high as it could │ │ │
│ go. │ │ │
│ │ │ │
│ If you hate Gangnam Style for │ │ │
│ being overplayed, you will see │ │ │
│ why I chose that particular │ │ │
│ song. I personally don't mind │ │ │
│ it. But here's the thing about │ │ │
│ my bass; it vibrates the walls, │ │ │
│ making one hell of a lot of │ │ │
│ noise. Needless to say, my mom │ │ │
│ was not pleased and shut off │ │ │
│ the internet. But it was oh so │ │ │
│ worth it. │ │ │
│ │ │ │
│ TL;DR: │ │ │
└─────────────────────────────────┴─────────────────────────────────┴──────────┘
```

## Implementation details

This PPO implementation is based on [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).

## Benchmark experiments

To validate that the PPO implementation works, we ran an experiment on the 1B model. Here is the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).

```bash
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
    examples/scripts/ppo/ppo_tldr.py \
    --output_dir models/minimal/ppo_tldr \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 4 \
    --total_episodes 1000000 \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
    --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
    --local_rollout_forward_batch_size 16 \
    --missing_eos_penalty 1.0 \
    --stop_token eos
```

Checkpoints and experiment tracking are available at:

- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/ppo_tldr)
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35)

To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).

```bash
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 64.70%
```

The PPO checkpoint reaches a 64.7% win rate vs. the 33.0% win rate of the SFT checkpoint. This is a good sign that the PPO training is working as intended.

Metrics:

![]()

```bash
# pip install openrlbenchmark==0.2.1a5
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/loss/value_avg&metrics=train/val/clipfrac_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
    "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
    --env-ids models/minimal/ppo_tldr \
    --pc.ncols 4 \
    --pc.ncols-legend 1 \
    --pc.xlabel "Episode" \
    --output-filename benchmark/trl/pr-1540/ppo \
    --scan-history
```

## PPOTrainer

[[autodoc]] PPOTrainer

## PPOConfig

[[autodoc]] PPOConfig
@ -1,125 +0,0 @@

# PRM Trainer

[](https://huggingface.co/models?other=prm,trl)

<Tip warning={true}>

PRM Trainer is an experimental API which is subject to change at any time.

</Tip>

## Overview

Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins.

The abstract from the paper is the following:

> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use process-based supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.

This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).

## Quick start

This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here:

<iframe
  src="https://huggingface.co/datasets/trl-lib/math_shepherd/embed/viewer/default/train?row=0"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>

Below is the script to train the model:

```python
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")

training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_prm.py
```

Distributed across 8 GPUs, the training takes approximately 1 hour.

To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script.

```python
from datasets import load_dataset
from transformers import pipeline

pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd")
dataset = load_dataset("trl-lib/math_shepherd")
example = {
    "prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
    "completions": [
        "Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
        "Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
        "Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
    ],
    "labels": [True, False, False],
}

separator = "\n"  # It's important to use the same separator as the one used during training

for idx in range(1, len(example["completions"]) + 1):
    steps = example["completions"][0:idx]
    text = separator.join((example["prompt"], *steps)) + separator  # Add a separator between the prompt and each step
    pred_entity = pipe(text)[-1]["entity"]
    pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity]
    label = example["labels"][idx - 1]
    print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}")
```

```text
Step 1 Predicted: True Label: True
Step 2 Predicted: False Label: False
Step 3 Predicted: False Label: False
```

It's a win!

## Expected dataset type

PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision) dataset.
The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.

The [`PRMTrainer`] only supports the [standard](dataset_formats#standard) dataset format.
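
For illustration, a minimal in-memory dataset with the expected columns might look like this (the sample content is made up):

```python
from datasets import Dataset

# One stepwise-supervision example: a prompt, its reasoning steps, and one label per step.
toy_dataset = Dataset.from_dict({
    "prompt": ["Is 13 a prime number?"],
    "completions": [[
        "Step 1: 13 is only divisible by 1 and 13.",
        "Step 2: Therefore 13 is a prime number. The answer is: yes",
    ]],
    "labels": [[True, True]],
})
```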

## Example script

We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py).

To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command:

```bash
accelerate launch examples/scripts/prm.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/math_shepherd \
    --num_train_epochs 1 \
    --logging_steps 25 \
    --output_dir Qwen2-0.5B-Reward-Math-Sheperd
```

## PRMTrainer

[[autodoc]] PRMTrainer

## PRMConfig

[[autodoc]] PRMConfig
@ -1,88 +0,0 @@

# Quickstart

## How does it work?

Fine-tuning a language model via PPO consists of roughly three steps:

1. **Rollout**: The language model generates a response or continuation based on a query, which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. The optimization will aim at maximizing this value.
3. **Optimization**: This is the most complex part. In the optimization step, the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.

The full process is illustrated in the following figure:
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_overview.png"/>

## Minimal example

The following code illustrates the steps above.

```python
# 0. imports
import torch
from transformers import GPT2Tokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer


# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 2. initialize trainer
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)

# 3. encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)

# 4. generate model response
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])

# 5. define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]

# 6. train model with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```

In general, you would run steps 3-6 in a loop over many diverse queries; a short sketch of such a loop follows below. You can find more realistic examples in the examples section.
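
As a minimal sketch (reusing the objects defined in the example above and assuming `queries` is an iterable of prompt strings; the constant reward is a placeholder for a real reward signal):

```python
for query_txt in queries:
    # 3. encode the query
    query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)
    # 4. generate a response
    response_tensor = ppo_trainer.generate([query_tensor[0]], return_prompt=False, **generation_kwargs)
    # 5. compute a reward (placeholder: replace with human feedback or another model's score)
    reward = [torch.tensor(1.0, device=model.pretrained_model.device)]
    # 6. run a PPO optimization step
    train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```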

## How to use a trained model

After training an `AutoModelForCausalLMWithValueHead`, you can directly use the model in `transformers`.

```python
# .. Let's assume we have a trained model using `PPOTrainer` and `AutoModelForCausalLMWithValueHead`

# push the model on the Hub
model.push_to_hub("my-fine-tuned-model-ppo")

# or save it locally
model.save_pretrained("my-fine-tuned-model-ppo")

# load the model from the Hub
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("my-fine-tuned-model-ppo")
```

You can also load your model with `AutoModelForCausalLMWithValueHead` if you want to use the value head, for example to continue training.

```python
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained("my-fine-tuned-model-ppo")
```
@ -1,133 +0,0 @@

# Reducing Memory Usage

<Tip warning={true}>

Section under construction. Feel free to contribute!

</Tip>

## Truncation

Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt completion" width="600"/>
</div>

To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.

<hfoptions id="dpo">
<hfoption id="DPO">

DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_prompt_completion.png" alt="Truncation prompt completion" width="600"/>
</div>

To set the truncation parameters, use the following code snippet:

```python
from trl import DPOConfig

training_args = DPOConfig(..., max_prompt_length=..., max_length=...)
```

You can also use the `max_completion_length` parameter to truncate the completion, though this is less common since the goal is typically to preserve the completion's full length whenever possible.

```python
from trl import DPOConfig

training_args = DPOConfig(..., max_completion_length=...)
```

</hfoption>
<hfoption id="SFT">

SFT truncation is applied to the input sequence via the `max_length` parameter.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png" alt="Truncation input ids" width="600"/>
</div>

To set the truncation parameter, use the following code snippet:

```python
from trl import SFTConfig

training_args = SFTConfig(..., max_length=...)
```

</hfoption>
</hfoptions>

## Packing

<Tip>

This technique applies only to SFT.

</Tip>

[Truncation](#truncation) has several drawbacks:
1. **Loss of information**: Key data at the end of a sequence may be discarded.
2. **Choosing truncation length**: Too short loses data; too long undermines efficiency.

Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing.png" alt="Packing" width="600"/>
</div>

Packing eliminates padding, preserves all sequence information, and allows for flexible sequence lengths, making it a more efficient alternative to truncation. To enable packing, use `packing=True` in the [`SFTConfig`]:

```python
from trl import SFTConfig

training_args = SFTConfig(..., packing=True, max_length=512)
```

<Tip warning={true}>

Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230).

</Tip>

## Disabling model gathering for generation in online methods

When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).

If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:

<hfoptions id="ds3_gather_for_generation">
<hfoption id="Online DPO">

```python
from trl import OnlineDPOConfig

training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="PPO">

```python
from trl import PPOConfig

training_args = PPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="RLOO">

```python
from trl import RLOOConfig

training_args = RLOOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
</hfoptions>

This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds.