mirror of https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00

Compare commits: aa25c2697c ... prm-traine (94 commits)
@@ -44,6 +44,8 @@
     title: PPO
   - local: reward_trainer
     title: Reward
+  - local: stepwise_reward_trainer
+    title: Stepwise Reward
   - local: rloo_trainer
     title: RLOO
   - local: sft_trainer
@@ -254,21 +254,22 @@ stepwise_example = {

Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.

| Trainer                   | Expected dataset type                                                                                   |
| ------------------------- | ------------------------------------------------------------------------------------------------------- |
| [`BCOTrainer`]            | [Unpaired preference](#unpaired-preference)                                                             |
| [`CPOTrainer`]            | [Preference (explicit prompt recommended)](#preference)                                                 |
| [`DPOTrainer`]            | [Preference (explicit prompt recommended)](#preference)                                                 |
| [`GKDTrainer`]            | [Prompt-completion](#prompt-completion)                                                                 |
| [`IterativeSFTTrainer`]   | [Unpaired preference](#unpaired-preference)                                                             |
| [`KTOTrainer`]            | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference)  |
| [`NashMDTrainer`]         | [Prompt-only](#prompt-only)                                                                             |
| [`OnlineDPOTrainer`]      | [Prompt-only](#prompt-only)                                                                             |
| [`ORPOTrainer`]           | [Preference (explicit prompt recommended)](#preference)                                                 |
| [`PPOTrainer`]            | Tokenized language modeling                                                                             |
| [`RewardTrainer`]         | [Preference (implicit prompt recommended)](#preference)                                                 |
| [`SFTTrainer`]            | [Language modeling](#language-modeling)                                                                 |
| [`StepwiseRewardTrainer`] | [Stepwise supervision](#stepwise-supervision)                                                           |
| [`XPOTrainer`]            | [Prompt-only](#prompt-only)                                                                             |

<Tip>
docs/source/stepwise_reward_trainer.mdx (new file, 85 lines)
@@ -0,0 +1,85 @@
# Stepwise Reward Modeling

## Overview

Stepwise or process reward models were proposed in [Solving math word problems with process- and outcome-based feedback](https://arxiv.org/pdf/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving and Irina Higgins.

The abstract from the paper is the following:

> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use process-based supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.

This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun) and [Quentin Gallouédec](https://huggingface.co/qgallouedec).

## Usage tips

The [`StepwiseRewardTrainer`] is a wrapper around the [`Trainer`] class. It needs two parameters to be set via the [`StepwiseRewardConfig`], namely:

* `max_length`: controls the maximum length of the sequences, where a sequence is composed of the prompt and the concatenation of each completion step.
* `step_separator`: the separator inserted between the steps of the reasoning process. By default, it is set to `"\n"`.

The basic API is as follows:
```python
from datasets import Dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer
from trl import StepwiseRewardConfig, StepwiseRewardTrainer


NUM_DUMMY_SAMPLES = 100

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B-Instruct", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

train_dataset = Dataset.from_dict(
    {
        "prompt": ["Which number has a larger absolute value, -13.1 or 7.0?"] * NUM_DUMMY_SAMPLES,
        "completions": [
            [
                "The absolute value of -13.1 is 13.1.",
                "The absolute value of 7.0 is 7.0.",
                "7.0 is larger than 13.1.",
                "Hence, in absolute value, 7.0 is larger than -13.1.",
            ]
        ]
        * NUM_DUMMY_SAMPLES,
        "labels": [[True, True, False, False]] * NUM_DUMMY_SAMPLES,
    }
)
eval_dataset = Dataset.from_dict(
    {
        "prompt": ["Is 19 divisible by 6?"] * NUM_DUMMY_SAMPLES,
        "completions": [
            [
                "Dividing 19 by 6 gives a remainder of 1.",
                "A number is divisible by another number if the division results in no remainder.",
                "Hence, 19 is not divisible by 6.",
            ]
        ]
        * NUM_DUMMY_SAMPLES,
        "labels": [[True, True, True]] * NUM_DUMMY_SAMPLES,
    }
)

training_args = StepwiseRewardConfig(
    output_dir="stepwise-reward-model",
    per_device_train_batch_size=1,
    max_length=512,
    step_separator="\n",
)
trainer = StepwiseRewardTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()
```
## Expected dataset format

The dataset should be formatted as a [stepwise supervision](dataset_formats#stepwise-supervision) dataset, which implies that it should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.

The [`StepwiseRewardTrainer`] only supports the [standard](dataset_formats#standard) dataset format.

You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids`, `attention_mask` and `labels`.
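For illustration only, a pretokenized sample could look like the following sketch. The token IDs here are arbitrary placeholders; the grounded part is the label layout, where `labels` is `-100` for positions that should not contribute to the loss and `0`/`1` at the token that closes each step (its separator token):

```python
from datasets import Dataset

# Hypothetical pretokenized dataset: IDs are placeholders, labels mark step ends.
pretokenized_dataset = Dataset.from_dict(
    {
        "input_ids": [[101, 1996, 7953, 2003, 102, 2023, 2003, 2028, 3357, 198]],
        "attention_mask": [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
        "labels": [[-100, -100, -100, -100, -100, -100, -100, -100, -100, 1]],
    }
)
```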
## StepwiseRewardTrainer

[[autodoc]] StepwiseRewardTrainer

## StepwiseRewardConfig

[[autodoc]] StepwiseRewardConfig
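A trained stepwise reward model can then be used to score each step of a candidate solution. The following is a rough, unofficial sketch: the checkpoint path is the `output_dir` used in the training example above (a placeholder), and it assumes the step separator encodes to a single token that does not otherwise occur in the prompt or the steps.

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

# Placeholder checkpoint: the output_dir from the training example above.
model = AutoModelForTokenClassification.from_pretrained("stepwise-reward-model")
tokenizer = AutoTokenizer.from_pretrained("stepwise-reward-model")

prompt = "Is 19 divisible by 6?"
steps = [
    "Dividing 19 by 6 gives a remainder of 1.",
    "A number is divisible by another number if the division results in no remainder.",
    "Hence, 19 is not divisible by 6.",
]

separator = "\n"  # must match the `step_separator` used at training time
# Approximation of the per-step tokenization done in `tokenize_row`:
# prompt followed by each step, with the separator appended after every step.
text = prompt + "".join(step + separator for step in steps)
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, seq_len, 2)

# Score each step at its trailing separator token, mirroring where the labels
# are placed during training. Assumes the separator maps to a single token id.
sep_id = tokenizer.encode(separator, add_special_tokens=False)[-1]
sep_positions = (inputs["input_ids"][0] == sep_id).nonzero(as_tuple=True)[0]
step_probs = logits[0, sep_positions].softmax(dim=-1)[:, 1]
print(step_probs.tolist())  # one "correct step" probability per step
```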
examples/scripts/stepwise_reward_modeling.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Full training:
python examples/scripts/stepwise_reward_modeling.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/prm800k \
    --output_dir Qwen2-0.5B-Reward \
    --per_device_train_batch_size 8 \
    --num_train_epochs 1 \
    --gradient_checkpointing True \
    --learning_rate 1.0e-5 \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 50

LoRA:
python examples/scripts/stepwise_reward_modeling.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/prm800k \
    --output_dir Qwen2-0.5B-Reward-LoRA \
    --per_device_train_batch_size 8 \
    --num_train_epochs 1 \
    --gradient_checkpointing True \
    --learning_rate 1.0e-4 \
    --logging_steps 25 \
    --eval_strategy steps \
    --eval_steps 50 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16
"""

import warnings

import torch
from datasets import load_dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer, HfArgumentParser

from trl import (
    ModelConfig,
    ScriptArguments,
    StepwiseRewardConfig,
    StepwiseRewardTrainer,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    setup_chat_format,
)


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, StepwiseRewardConfig, ModelConfig))
    script_args, training_args, model_config = parser.parse_args_into_dataclasses()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

    ################
    # Model & Tokenizer
    ################
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        use_cache=False if training_args.gradient_checkpointing else True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
    )
    model = AutoModelForTokenClassification.from_pretrained(
        model_config.model_name_or_path, num_labels=2, trust_remote_code=model_config.trust_remote_code, **model_kwargs
    )
    # Align padding tokens between tokenizer and model
    model.config.pad_token_id = tokenizer.pad_token_id

    # If post-training a base model, use ChatML as the default template
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer)

    if model_config.use_peft and model_config.lora_task_type != "TOKEN_CLS":
        warnings.warn(
            "You are using a `task_type` that is different than `TOKEN_CLS` for PEFT. This will lead to silent bugs."
            " Make sure to pass --lora_task_type TOKEN_CLS when using this script with PEFT."
        )

    ##############
    # Load dataset
    ##############
    dataset = load_dataset(script_args.dataset_name)

    ##########
    # Training
    ##########
    trainer = StepwiseRewardTrainer(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split],
        peft_config=get_peft_config(model_config),
    )
    trainer.train()

    ############################
    # Save model and push to Hub
    ############################
    trainer.save_model(training_args.output_dir)
    metrics = trainer.evaluate()
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
tests/test_stepwise_reward_trainer.py (new file, 189 lines)
@@ -0,0 +1,189 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForTokenClassification, AutoTokenizer, EvalPrediction
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import StepwiseRewardConfig, StepwiseRewardTrainer
from trl.trainer import compute_accuracy


if is_peft_available():
    from peft import LoraConfig, TaskType


class StepwiseRewardTrainerTester(unittest.TestCase):
    def setUp(self):
        self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
        self.model = AutoModelForTokenClassification.from_pretrained(self.model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def test_token_level_accuracy(self):
        dummy_eval_predictions = EvalPrediction(
            torch.FloatTensor([[[0.1, 0.9], [0.1, 0.9]], [[0.1, 0.9], [0.9, 0.1]]]),
            torch.LongTensor([[-100, 1], [-100, 1]]),
        )
        accuracy = compute_accuracy(dummy_eval_predictions)
        self.assertEqual(accuracy["accuracy"], 0.5)

    @parameterized.expand([True, False])
    def test_preprocessing(self, train_on_last_step_only):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train")
            training_args = StepwiseRewardConfig(
                output_dir=tmp_dir, report_to="none", max_length=512, train_on_last_step_only=train_on_last_step_only
            )
            trainer = StepwiseRewardTrainer(
                model=self.model,
                args=training_args,
                processing_class=self.tokenizer,
                train_dataset=dummy_dataset,
            )
            dummy_dataset = dummy_dataset.map(
                StepwiseRewardTrainer.tokenize_row,
                fn_kwargs={
                    "tokenizer": self.tokenizer,
                    "step_separator": "\n",
                    "max_completion_length": None,
                    "train_on_last_step_only": train_on_last_step_only,
                },
                remove_columns=dummy_dataset.features,
            )
            self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])

    @parameterized.expand([True, False])
    def test_train_full(self, train_on_last_step_only):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train")
            training_args = StepwiseRewardConfig(
                output_dir=tmp_dir,
                max_steps=3,
                report_to="none",
                max_length=512,
                train_on_last_step_only=train_on_last_step_only,
            )
            trainer = StepwiseRewardTrainer(
                model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
            )
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check that the parameters have changed (ignore 0 biases)
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if param.sum() != 0:
                    self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))

    @parameterized.expand([True, False])
    def test_train_full_pretokenized(self, train_on_last_step_only):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train")
            dummy_dataset = dummy_dataset.map(
                StepwiseRewardTrainer.tokenize_row,
                fn_kwargs={
                    "tokenizer": self.tokenizer,
                    "step_separator": "\n",
                    "max_completion_length": None,
                    "train_on_last_step_only": train_on_last_step_only,
                },
                remove_columns=dummy_dataset.features,
            )

            training_args = StepwiseRewardConfig(
                output_dir=tmp_dir,
                max_steps=3,
                report_to="none",
                max_length=512,
                train_on_last_step_only=train_on_last_step_only,
            )
            trainer = StepwiseRewardTrainer(
                model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
            )
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            # Check that the parameters have changed (ignore 0 biases)
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if param.sum() != 0:
                    self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))

    @require_peft
    def test_train_lora(self):
        peft_config = LoraConfig(
            task_type=TaskType.TOKEN_CLS,
            inference_mode=False,
            r=8,
            lora_alpha=32,
            lora_dropout=0.1,
        )
        with tempfile.TemporaryDirectory() as tmp_dir:
            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train")
            training_args = StepwiseRewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none", max_length=512)
            trainer = StepwiseRewardTrainer(
                model=self.model,
                args=training_args,
                processing_class=self.tokenizer,
                train_dataset=dummy_dataset,
                peft_config=peft_config,
            )
            previous_trainable_params = {}
            previous_non_trainable_params = {}

            # Trainable params include the LoRA weights and the modules saved alongside them
            # (due to a change in the way the modules to save are dealt with in PEFT).
            trainable_params_name = ["lora", "modules_to_save"]

            # Snapshot the parameters before training
            for n, param in trainer.model.named_parameters():
                if any(t in n for t in trainable_params_name):
                    previous_trainable_params[n] = param.clone()
                else:
                    previous_non_trainable_params[n] = param.clone()

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the trainable parameters have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))

            # Check that the non-trainable parameters have not changed
            for n, param in previous_non_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))

    def test_tags(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train")
            training_args = StepwiseRewardConfig(output_dir=tmp_dir, report_to="none", max_length=512)
            trainer = StepwiseRewardTrainer(
                model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
            )
            self.assertEqual(trainer.model.model_tags, trainer._tag_names)
@@ -88,6 +88,8 @@ _import_structure = {
         "RLOOTrainer",
         "SFTConfig",
         "SFTTrainer",
+        "StepwiseRewardConfig",
+        "StepwiseRewardTrainer",
         "WinRateCallback",
         "XPOConfig",
         "XPOTrainer",
@@ -178,6 +180,8 @@ if TYPE_CHECKING:
         RLOOTrainer,
         SFTConfig,
         SFTTrainer,
+        StepwiseRewardConfig,
+        StepwiseRewardTrainer,
         WinRateCallback,
         XPOConfig,
         XPOTrainer,
@@ -68,6 +68,8 @@ _import_structure = {
     "rloo_trainer": ["RLOOTrainer"],
     "sft_config": ["SFTConfig"],
     "sft_trainer": ["SFTTrainer"],
+    "stepwise_reward_config": ["StepwiseRewardConfig"],
+    "stepwise_reward_trainer": ["StepwiseRewardTrainer"],
     "utils": [
         "AdaptiveKLController",
         "ConstantLengthDataset",
@@ -136,6 +138,8 @@ if TYPE_CHECKING:
     from .rloo_trainer import RLOOTrainer
     from .sft_config import SFTConfig
     from .sft_trainer import SFTTrainer
+    from .stepwise_reward_config import StepwiseRewardConfig
+    from .stepwise_reward_trainer import StepwiseRewardTrainer
     from .utils import (
         AdaptiveKLController,
         ConstantLengthDataset,
trl/trainer/stepwise_reward_config.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional

from transformers import TrainingArguments


@dataclass
class StepwiseRewardConfig(TrainingArguments):
    r"""
    Configuration class for the [`StepwiseRewardTrainer`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        max_length (`Optional[int]`, *optional*, defaults to `None`):
            Maximum length of the sequences (prompt + completion) used for truncation.
        max_completion_length (`Optional[int]`, *optional*, defaults to `None`):
            Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
        step_separator (`str`, *optional*, defaults to `"\n"`):
            Separator used to separate each step of the reasoning process.
        train_on_last_step_only (`bool`, *optional*, defaults to `False`):
            Whether to train only on the last step.
        dataset_num_proc (`int`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset.
    """

    max_length: Optional[int] = None
    max_completion_length: Optional[int] = None
    step_separator: str = "\n"
    train_on_last_step_only: bool = False
    dataset_num_proc: Optional[int] = None
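As the docstring above notes, [`~transformers.HfArgumentParser`] can expose these fields as command-line flags. A minimal sketch (the flag values below are arbitrary examples, not defaults):

```python
from transformers import HfArgumentParser
from trl import StepwiseRewardConfig

# Turn the dataclass fields into CLI flags and parse an example argument list.
parser = HfArgumentParser(StepwiseRewardConfig)
(config,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "stepwise-reward-model", "--max_length", "512", "--train_on_last_step_only", "True"]
)
print(config.max_length, config.step_separator, config.train_on_last_step_only)
```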
trl/trainer/stepwise_reward_trainer.py (new file, 296 lines)
@@ -0,0 +1,296 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
import textwrap
import warnings
from itertools import chain
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from accelerate import PartialState
from datasets import Dataset
from transformers import (
    BaseImageProcessor,
    DataCollator,
    DataCollatorForTokenClassification,
    FeatureExtractionMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    is_wandb_available,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from .stepwise_reward_config import StepwiseRewardConfig
from .utils import compute_accuracy, generate_model_card


if is_peft_available():
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

if is_wandb_available():
    import wandb


class StepwiseRewardTrainer(Trainer):
    _tag_names = ["trl", "stepwise-reward-trainer"]

    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
        args: Optional[StepwiseRewardConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
            None,
            None,
        ),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
    ):
        """
        Initialize StepwiseRewardTrainer.

        Args:
            model (`transformers.PreTrainedModel`):
                The model to train, preferably an `AutoModelForTokenClassification`.
            args (`StepwiseRewardConfig`):
                The arguments to use for training.
            data_collator (`transformers.DataCollator`):
                The data collator to use for training. If None is specified, the default data collator
                (`DataCollatorForTokenClassification`) will be used, which pads the sequences to the maximum length
                of the sequences in the batch.
            train_dataset (`datasets.Dataset`):
                The dataset to use for training.
            eval_dataset (`datasets.Dataset`):
                The dataset to use for evaluation.
            processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
                Processing class used to process the data. If provided, will be used to automatically process the inputs
                for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
                reuse the fine-tuned model.
            model_init (`Callable[[], transformers.PreTrainedModel]`):
                The model initializer to use for training. If None is specified, the default model initializer will be used.
            compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
                The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
            callbacks (`list[transformers.TrainerCallback]`):
                The callbacks to use for training.
            optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
                The optimizer and scheduler to use for training.
            preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
                The function to use to preprocess the logits before computing the metrics.
            peft_config (`dict`, defaults to `None`):
                The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
        """
        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            if not isinstance(model, PeftModel):
                if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
                    _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                        inspect.signature(prepare_model_for_kbit_training).parameters
                    )

                    prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                    if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        warnings.warn(
                            "You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
                            "Please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
                        )
                    elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
                        prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                    model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

                model = get_peft_model(model, peft_config)

        if compute_metrics is None:
            compute_metrics = compute_accuracy

        if data_collator is None:
            if processing_class is None:
                raise ValueError(
                    "A processing_class must be specified when using the default DataCollatorForTokenClassification"
                )
            data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)

        if "input_ids" not in train_dataset.column_names:
            with PartialState().local_main_process_first():
                fn_kwargs = {
                    "tokenizer": processing_class,
                    "step_separator": args.step_separator,
                    "max_completion_length": args.max_completion_length,
                    "train_on_last_step_only": args.train_on_last_step_only,
                }
                train_dataset = train_dataset.map(
                    self.tokenize_row,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    remove_columns=train_dataset.features,
                    desc="Tokenizing train dataset",
                )

                if eval_dataset is not None:
                    eval_dataset = eval_dataset.map(
                        self.tokenize_row,
                        fn_kwargs=fn_kwargs,
                        num_proc=args.dataset_num_proc,
                        remove_columns=eval_dataset.features,
                        desc="Tokenizing eval dataset",
                    )

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

    @staticmethod
    def tokenize_row(features, tokenizer, step_separator, max_completion_length, train_on_last_step_only):
        """
        Tokenize a row of the dataset.

        Args:
            features (`dict[str, str]`):
                Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
            tokenizer (`PreTrainedTokenizerBase`):
                Tokenizer used to process the data.
            step_separator (`str`):
                Separator between steps in the completion.
            max_completion_length (`int` or `None`):
                Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
            train_on_last_step_only (`bool`):
                Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
                token of the completion.

        Returns:
            `dict[str, list[int]]`:
                Tokenized sequences with the keys `"input_ids"` and `"labels"`.

        Example:
        ```python
        >>> from transformers import AutoTokenizer
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
        >>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
        ...             "completions": ["11 is greater than 8.",
        ...                             "Hence, 9.11 > 9.8."],
        ...             "labels": [True, False]}
        >>> StepwiseRewardTrainer.tokenize_row(features, tokenizer, "\n", max_completion_length=None, train_on_last_step_only=False)
        {'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
         'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
        ```
        """
        # Tokenize the prompt and completions
        prompt_ids = tokenizer(features["prompt"])["input_ids"]
        completions_ids = [tokenizer(completion)["input_ids"] for completion in features["completions"]]
        if train_on_last_step_only:
            labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
        else:
            labels = [int(label) for label in features["labels"]]

        # Get the ID of the separator token and add it to the completions
        separator_ids = tokenizer.encode(step_separator)
        completions_ids = [completion + separator_ids for completion in completions_ids]

        # Create the label
        labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]

        # Join the completions and labels steps
        completion_ids = list(chain(*completions_ids))
        labels = list(chain(*labels))

        if max_completion_length is not None:
            completion_ids = completion_ids[:max_completion_length]
            labels = labels[:max_completion_length]

        return {"input_ids": prompt_ids + completion_ids, "labels": labels}

    def create_model_card(
        self,
        model_name: Optional[str] = None,
        dataset_name: Optional[str] = None,
        tags: Union[str, list[str], None] = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str`, *optional*, defaults to `None`):
                The name of the model.
            dataset_name (`str`, *optional*, defaults to `None`):
                The name of the dataset used for training.
            tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
                Tags to be associated with the model card.
        """
        if not self.is_world_process_zero():
            return

        if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
            base_model = self.model.config._name_or_path
        else:
            base_model = None

        tags = tags or []
        if isinstance(tags, str):
            tags = [tags]

        if hasattr(self.model.config, "unsloth_version"):
            tags.append("unsloth")

        citation = textwrap.dedent("""\
        @article{uesato2022solving,
            title        = {Solving Math Word Problems With Process- and Outcome-Based Feedback},
            author       = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
            year         = 2022,
            journal      = {arXiv preprint arXiv:2211.14275}
        }""")

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=tags,
            wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
            trainer_name="Stepwise Reward",
            trainer_citation=citation,
            paper_title="Solving math word problems with process- and outcome-based feedback",
        )

        model_card.save(os.path.join(self.args.output_dir, "README.md"))
@@ -769,13 +769,23 @@ def get_global_statistics(

def compute_accuracy(eval_pred) -> Dict[str, float]:
    predictions, labels = eval_pred
    if predictions.ndim == 3:
        # Token classification task.
        # Used to compute the accuracy in the stepwise_reward_trainer.
        predictions = np.argmax(predictions, axis=2)
        predictions = np.array(
            [p for prediction, label in zip(predictions, labels) for (p, lbl) in zip(prediction, label) if lbl != -100]
        )
        labels = np.array([lbl for label in labels for lbl in label if lbl != -100])
    else:
        # Here, predictions is rewards_chosen and rewards_rejected.
        # We want to see how much of the time rewards_chosen > rewards_rejected.
        if np.array(predictions[:, 0] == predictions[:, 1], dtype=float).sum() > 0:
            warnings.warn(
                f"There are {np.array(predictions[:, 0] == predictions[:, 1]).sum()} out of {len(predictions[:, 0])} instances where the predictions for both options are equal. As a consequence the accuracy can be misleading."
            )
        predictions = np.argmax(predictions, axis=1)

    accuracy = np.array(predictions == labels, dtype=float).mean().item()
    return {"accuracy": accuracy}
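As a standalone illustration of the new token-classification branch, the following sketch mirrors its logic in plain NumPy (it does not call the TRL helper itself): take the argmax over the class dimension, drop the `-100` (ignored) positions, and compare to the labels.

```python
import numpy as np

# Dummy token-classification outputs: 2 sequences, 2 tokens each, 2 classes.
predictions = np.array(
    [[[0.1, 0.9], [0.1, 0.9]],
     [[0.1, 0.9], [0.9, 0.1]]]
)
labels = np.array([[-100, 1], [-100, 1]])

# argmax over classes, then keep only positions whose label is not -100.
pred_classes = np.argmax(predictions, axis=2)   # [[1, 1], [1, 0]]
mask = labels != -100
accuracy = (pred_classes[mask] == labels[mask]).mean()
print(accuracy)  # 0.5, matching test_token_level_accuracy
```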