Merge branch 'main' into multi-turn

This commit is contained in:
Quentin Gallouédec
2025-09-25 21:17:51 -06:00
committed by GitHub
24 changed files with 1641 additions and 1075 deletions

View File

@ -101,10 +101,44 @@ To leverage GSPO-token, the user will need to provide the per-token advantage \
</Tip>
## Usage
### GRPO With Replay Buffer
This experimental trainer, trains a model with GRPO but replaces groups (and corresponding completions) that have 0 standard deviation with groups with high rewards and standard deviation that've been used to train a model in prior batches.
#### Usage
```python
from trl.experimental.new_trainer import NewTrainer
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferTrainer
from datasets import load_dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
# Guarantee that some rewards have 0 std
def custom_reward_func(completions, **kwargs):
if torch.rand(1).item() < 0.25:
return [0] * len(completions) # simulate some None rewards
else:
return torch.rand(len(completions)).tolist()
training_args = GRPOWithReplayBufferConfig(
output_dir=self.tmp_dir,
learning_rate=1e-4,
per_device_train_batch_size=4,
num_generations=4,
max_completion_length=8,
replay_buffer_size=8,
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=[custom_reward_func],
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
```
To silence the runtime notice:

View File

@ -199,7 +199,9 @@ pip install trl[vllm]
We support two ways of using vLLM during training: **server mode** and **colocate mode**.
<Tip>
By default, Truncated Importance Sampling is activated for vLLM generation to address the generation-training mismatch that occurs when using different frameworks. This can be turned off by setting `vllm_importance_sampling_correction=False`. For more information, see [Truncated Importance Sampling](paper_index#truncated-importance-sampling)
</Tip>
#### 🔌 Option 1: Server mode

View File

@ -259,8 +259,6 @@ plt.tight_layout()
plt.show()
```
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online_dpo_scaling.png)
The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended.
## OnlineDPOTrainer

View File

@ -181,7 +181,9 @@ To train on completion only, use a [prompt-completion](dataset_formats#prompt-co
![train_on_completion](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_completion.png)
<Tip>
Training on completion only is compatible with training on assistant messages only. In this case, use a [conversational](dataset_formats#conversational) [prompt-completion](dataset_formats#prompt-completion) dataset and set `assistant_only_loss=True` in the [`SFTConfig`].
</Tip>
### Train adapters with PEFT

View File

@ -29,6 +29,11 @@ from transformers.testing_utils import require_liger_kernel, require_peft, requi
from transformers.utils import is_peft_available
from trl import GRPOConfig, GRPOTrainer
from trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig
from trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer import (
GRPOWithReplayBufferTrainer,
ReplayBuffer,
)
from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer
from .testing_utils import TrlTestCase, require_vllm
@ -1709,6 +1714,273 @@ class GRPOTrainerTester(TrlTestCase):
@pytest.mark.low_priority
class TestReplayBuffer(unittest.TestCase):
def setUp(self):
self.replay_buffer = ReplayBuffer(max_size=5)
def test_add(self):
# Add elements to the replay buffer
scores = [0.5, 0.8, 0.3, 0.9, 0.7]
data = [
{"id": 1},
{"id": 2},
{"id": 3},
{"id": 4},
{"id": 5},
]
self.replay_buffer.add(scores, data)
# Check if the buffer contains the correct number of elements
self.assertEqual(len(self.replay_buffer.heap), 5)
# Check if the buffer maintains the min-heap property
heap_scores = [item[0] for item in self.replay_buffer.heap]
self.assertEqual(heap_scores[0], min(heap_scores))
self.assertEqual(heap_scores[0], 0.3)
def test_add_more_than_maxlen(self):
# Add elements to the replay buffer
scores = [0.5, 0.8, 0.3, 0.9, 0.7, 0.6, 0.4]
data = [
{"id": 1},
{"id": 2},
{"id": 3},
{"id": 4},
{"id": 5},
{"id": 6},
{"id": 7},
]
self.replay_buffer.add(scores, data)
# Check if the buffer contains the correct number of elements
self.assertEqual(len(self.replay_buffer.heap), 5)
# Check if the buffer maintains the min-heap property
heap_scores = [item[0] for item in self.replay_buffer.heap]
self.assertEqual(heap_scores[0], min(heap_scores))
self.assertEqual(heap_scores[0], 0.5) # 0.3 and 0.4 should be removed
def test_sample(self):
# Add elements to the replay buffer
scores = [0.5, 0.8, 0.3, 0.9, 0.7]
data = [
{"id": 1},
{"id": 2},
{"id": 3},
{"id": 4},
{"id": 5},
]
self.replay_buffer.add(scores, data)
# Sample elements from the buffer
sampled = self.replay_buffer.sample(num_samples=3)
# Check if the sampled elements are from the buffer
self.assertEqual(len(sampled), 3)
for item in sampled:
self.assertIn(item, [entry[1] for entry in self.replay_buffer.heap])
@pytest.mark.low_priority
class TestUpdateWithReplayBuffer(unittest.TestCase):
def setUp(self):
config = GRPOWithReplayBufferConfig(
replay_buffer_size=5,
)
self.trainer = GRPOWithReplayBufferTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=config,
train_dataset=None,
)
self.trainer.replay_buffer = ReplayBuffer(max_size=5)
self.trainer.num_generations = 2
def _prepopulate_buffer(self, with_pixels=False, with_logprobs=False):
scores = [0.1, 0.9]
data = [
{
"prompt_ids": torch.tensor([[100, 101], [102, 103]]),
"prompt_mask": torch.ones(2, 2, dtype=torch.long),
"completion_ids": torch.tensor([[5, 6], [7, 8]]),
"completion_mask": torch.ones(2, 2, dtype=torch.long),
"advantages": torch.tensor([[0.5, 0.6]]),
**({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}),
**({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}),
},
{
"prompt_ids": torch.tensor([[104, 105], [106, 107]]),
"prompt_mask": torch.ones(2, 2, dtype=torch.long),
"completion_ids": torch.tensor([[13, 14], [15, 16]]),
"completion_mask": torch.ones(2, 2, dtype=torch.long),
"advantages": torch.tensor([[0.8, 0.85]]),
**({"pixel_values": torch.randn(2, 3, 224, 224)} if with_pixels else {}),
**({"old_per_token_logps": torch.randn(2, 2)} if with_logprobs else {}),
},
]
self.trainer.replay_buffer.add(scores, data)
def _make_inputs(self, group_advantages, with_pixels=False, with_logprobs=False):
inputs = {
"group_advantages": group_advantages,
"prompt_ids": torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]),
"prompt_mask": torch.ones(4, 2, dtype=torch.long),
"completion_ids": torch.tensor([[9, 10], [11, 12], [13, 14], [15, 16]]),
"completion_mask": torch.ones(4, 2, dtype=torch.long),
"prompt_inputs": {"pixel_values": torch.randn(4, 3, 224, 224)} if with_pixels else {},
"old_per_token_logps": torch.randn(4, 2) if with_logprobs else None,
}
inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
return inputs
def test_update_with_replay_buffer_no_variance(self):
self._prepopulate_buffer(with_pixels=True, with_logprobs=True)
group_advantages = torch.tensor([[0.5, 0.5], [0.8, 0.8]]) # no variance
inputs = self._make_inputs(group_advantages, with_pixels=True, with_logprobs=True)
original_prompt_ids = inputs["prompt_ids"].clone()
outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
self.assertIsNotNone(outputs)
self.assertIn("pixel_values", outputs)
self.assertIn("old_per_token_logps", outputs)
self.assertEqual(len(self.trainer.replay_buffer.heap), 2)
for pid in outputs["prompt_ids"]:
self.assertNotIn(pid.tolist(), original_prompt_ids.tolist())
def test_update_with_replay_buffer_with_variance(self):
self._prepopulate_buffer()
group_advantages = torch.tensor([[0.6, 0.4], [0.7, 1.2]]) # has variance
inputs = self._make_inputs(group_advantages)
sampled = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
self.assertEqual(len(self.trainer.replay_buffer.heap), 4) # grew
self.assertIsNone(sampled)
def test_update_with_mixed_variance(self):
self._prepopulate_buffer()
group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]]) # one no-variance, one variance
inputs = self._make_inputs(group_advantages)
original_prompt_ids = inputs["prompt_ids"].clone().view(-1, self.trainer.num_generations, 2).tolist()
outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
self.assertEqual(len(self.trainer.replay_buffer.heap), 3) # grew by 1
output_prompt_ids = outputs["prompt_ids"].view(-1, self.trainer.num_generations, 2).tolist()
buffer_ids = [item[1]["prompt_ids"].tolist() for item in self.trainer.replay_buffer.heap]
found_from_buffer = any(pid in buffer_ids for pid in output_prompt_ids)
found_from_original = any(pid in original_prompt_ids for pid in output_prompt_ids)
self.assertTrue(found_from_buffer)
self.assertTrue(found_from_original)
self.assertNotIn([[1, 2], [3, 4]], output_prompt_ids) # excluded no-variance group
def test_update_with_inputs_different_seq_len(self):
"""
Test with inputs where the sequence lengths are different from the prepopulated buffer.
"""
self._prepopulate_buffer()
pad_token_id = self.trainer.tokenizer.pad_token_id
group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]]) # one no-variance, one variance
inputs = {
"group_advantages": group_advantages,
"prompt_ids": torch.tensor(
[
[1, 2, pad_token_id],
[1, 2, pad_token_id],
[3, 4, 5],
[3, 4, 5],
]
),
"prompt_mask": torch.tensor([[1, 1, 0], [1, 1, 0], [1, 1, 1], [1, 1, 1]], dtype=torch.long),
"completion_ids": torch.tensor(
[
[1009, 1010, pad_token_id],
[1011, 1012, 1013],
[1013, 1014, pad_token_id],
[1015, 1016, 1017],
]
),
"completion_mask": torch.tensor([[1, 1, 0], [1, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.long),
"prompt_inputs": {},
}
inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
outputs_after_sampling = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4)
# Seq length of current batch should be preserved
self.assertEqual(outputs_after_sampling["prompt_ids"].shape[-1], 3)
self.assertEqual(len(self.trainer.replay_buffer.heap), 3)
output_prompt_ids = outputs_after_sampling["prompt_ids"].view(-1, self.trainer.num_generations, 3).tolist()
buffered_prompt_completion_ids = [
(item[1]["prompt_ids"].tolist(), item[1]["completion_ids"].tolist())
for item in self.trainer.replay_buffer.heap
]
buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids)
# Check for new entry with seq len 3 in buffer
self.assertIn([[3, 4, 5], [3, 4, 5]], buffered_prompt_ids) # excluded no-variance group
self.assertIn(
[[1013, 1014, pad_token_id], [1015, 1016, 1017]], buffered_completion_ids
) # excluded no-variance group
# Check that sampled outputs contain one group with prompt_ids starting with a pad token
self.assertTrue(
[
[pad_token_id, 101, 102],
[pad_token_id, 102, 103],
]
in output_prompt_ids
or [
[pad_token_id, 104, 105],
[pad_token_id, 106, 107],
]
in output_prompt_ids
)
@pytest.mark.low_priority
class TestGRPOWithReplayBufferTrainer(TrlTestCase):
def test_training_with_replay_buffer(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
# Guarantee that some rewards have 0 std
def custom_reward_func(completions, **kwargs):
if torch.rand(1).item() < 0.25:
return [0] * len(completions) # simulate some None rewards
else:
return torch.rand(len(completions)).tolist()
training_args = GRPOWithReplayBufferConfig(
output_dir=self.tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=4, # reduce the batch size to reduce memory usage
num_generations=4, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
replay_buffer_size=8,
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=[custom_reward_func],
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
class GSPOTokenTrainerTester(TrlTestCase):
def test_training(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

View File

@ -0,0 +1,16 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig
from .grpo_with_replay_buffer_trainer import GRPOWithReplayBufferTrainer, ReplayBuffer

View File

@ -0,0 +1,34 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from trl.trainer.grpo_config import GRPOConfig
@dataclass
class GRPOWithReplayBufferConfig(GRPOConfig):
"""
New Parameters:
replay_buffer_size (`int`, *optional*, defaults to `0`):
A cache that stores the rollouts with the highest advantage scores and variance per group. If a new
group has 0 variance, it is replaced with a group sampled from the replay buffer.
"""
replay_buffer_size: int = field(
default=64,
metadata={
"help": "A cache that stores the rollouts with the highest advantage scores and variance per group. If a new group has 0 variance, it is replaced with a group sampled from the replay buffer."
},
)

View File

@ -0,0 +1,988 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import heapq
import re
from contextlib import nullcontext
from typing import Any, Optional, Union
import torch
from accelerate.utils import broadcast_object_list, gather_object
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers.utils import is_flash_attn_2_available
from trl.data_utils import is_conversational, maybe_apply_chat_template, prepare_multimodal_messages
from trl.extras.profiling import profiling_context
from trl.import_utils import is_vllm_available
from trl.models import unwrap_model_for_generation
from trl.trainer.grpo_trainer import GRPOTrainer
from trl.trainer.utils import (
nanmax,
nanmin,
nanstd,
pad,
truncate_with_protected_tokens,
)
from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig
if is_vllm_available():
from vllm import SamplingParams
from vllm.sampling_params import GuidedDecodingParams
class ReplayBuffer:
"""
A simple replay buffer to store and sample previously seen rollouts.
"""
def __init__(self, max_size: int):
self.max_size = max_size
self.heap = [] # Min-heap of (score, data) tuples
def add(self, scores: list[float], data: list[dict]):
for score, datum in zip(scores, data):
if len(self.heap) < self.max_size:
heapq.heappush(self.heap, (score, datum))
else:
# Only add if score is better than worst (minimum) item
if score > self.heap[0][0]:
heapq.heapreplace(self.heap, (score, datum))
def sample(self, num_samples: int) -> list[dict[str, torch.Tensor]]:
if not self.heap:
return None
# Sample by normalized scores
scores = torch.tensor([item[0] for item in self.heap], dtype=torch.float32)
probabilities = scores / scores.sum()
replacement = False
if num_samples > len(self.heap):
replacement = True
chosen_indices = torch.multinomial(probabilities, num_samples, replacement=replacement).tolist()
return [self.heap[i][1] for i in chosen_indices]
class GRPOWithReplayBufferTrainer(GRPOTrainer):
def __init__(self, args: Optional[GRPOWithReplayBufferConfig] = None, **kwargs):
super().__init__(args=args, **kwargs)
self.replay_buffer = ReplayBuffer(args.replay_buffer_size) if args.replay_buffer_size > 0 else None
def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
mode = "train" if self.model.training else "eval"
prompts = [x["prompt"] for x in inputs]
# We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for
# later use in the reward computation. If images are present, we insert {"type": "image"} as required by the
# VLM chat template.
original_prompts = copy.deepcopy(prompts)
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}]
kwargs = {}
has_images = "image" in inputs[0]
image_split_sizes = None
if has_images:
images = [example.get("image") for example in inputs]
kwargs = {"images": [[img] for img in images]}
for prompt in prompts:
if isinstance(prompt, list): # i.e., when using conversational data
prepare_multimodal_messages(prompt, num_images=1)
if hasattr(self.processing_class, "_get_num_multimodal_tokens"):
image_sizes = [(image.height, image.width) for image in images]
multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes)
image_split_sizes = multimodal_extra_data.num_image_patches
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
prompt_inputs = self.processing_class(
text=prompts_text,
return_tensors="pt",
padding=True,
padding_side="left",
add_special_tokens=False,
**kwargs,
)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
if "image_grid_thw" in prompt_inputs and image_split_sizes is None:
# Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens
image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist()
if self.max_prompt_length is not None:
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
# Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
# because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
protected = [token for token in protected if token is not None]
prompt_ids, prompt_mask = truncate_with_protected_tokens(
prompt_ids, prompt_mask, self.max_prompt_length, protected
)
prompts_text = self.processing_class.batch_decode(
prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text]
# The chat template sometimes inserts a single image token into the prompt text. However, when this text is
# later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
# image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
# collapse them back into a single token string to match the original chat template in case it originally
# applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
# (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
# the vision_start_token_id (e.g. <start_of_image>).
if self.image_token is not None:
escaped_img_token = re.escape(self.image_token)
# Search for the image token in the chat template
if re.search(escaped_img_token, self.processing_class.chat_template):
prompts_text = [
re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
]
else:
# If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
if self.vision_end_token_id is not None:
escaped_eoi_token = re.escape(
self.processing_class.tokenizer.decode([self.vision_end_token_id])
)
prompts_text = [
re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
]
else:
# If vision_end_token_id is None, just remove the image tokens
prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
# Generate completions using either vLLM or regular generation
if self.use_vllm:
if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
# wake up colocated vLLM instances if needed
torch.cuda.empty_cache() # required to avoid OOM in some cases
self.llm.wake_up()
# First, update the vLLM weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
if self.vllm_mode == "server":
all_prompts_text = gather_object(prompts_text)
if has_images:
all_images = gather_object(images)
if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
if has_images:
ordered_set_of_images = all_images[:: self.num_generations]
else:
ordered_set_of_images = None
with profiling_context(self, "vLLM.generate"):
output = self.vllm_client.generate(
prompts=ordered_set_of_prompts,
images=ordered_set_of_images,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
max_tokens=self.max_completion_length,
guided_decoding_regex=self.guided_decoding_regex,
generation_kwargs=self.args.generation_kwargs,
)
payload = (output["completion_ids"], output["logprobs"])
else:
payload = None
# Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
obj_list = [payload]
broadcast_object_list(obj_list, from_process=0)
completion_ids, all_logprobs = obj_list[0]
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
completion_ids = completion_ids[process_slice]
all_logprobs = all_logprobs[process_slice]
# Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
elif self.vllm_mode == "colocate":
if self.guided_decoding_regex:
guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex)
else:
guided_decoding = None
generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
"max_tokens": self.max_completion_length,
"guided_decoding": guided_decoding,
"logprobs": 0, # only return the logprob of the generated token
}
if self.args.generation_kwargs is not None:
generation_kwargs.update(self.args.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
if self.vllm_tensor_parallel_size > 1:
# Gather prompts from all ranks in the TP group and flatten.
# Each rank starts with its own prompts; after gathering, all ranks see the full group set.
orig_size = len(prompts_text)
gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
if has_images:
gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)]
torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group)
all_images = [img for sublist in gathered_images for img in sublist]
else:
all_images = None
else:
all_prompts_text = prompts_text
all_images = images if has_images else None
if has_images and all_images:
vllm_inputs = []
for prompt, image in zip(all_prompts_text, all_images):
if image is not None:
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
else:
vllm_inputs.append(prompt)
else:
vllm_inputs = all_prompts_text
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
all_logprobs = [
[next(iter(lp.values())).logprob for lp in output.logprobs]
for outputs in all_outputs
for output in outputs.outputs
]
if self.vllm_tensor_parallel_size > 1:
# Slice completions for this rank within its TP group.
# Each rank generates all outputs — we keep only our share.
local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
completion_ids = completion_ids[tp_slice]
all_logprobs = all_logprobs[tp_slice]
if self.args.vllm_enable_sleep_mode:
self.llm.sleep(level=1)
# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
sampling_per_token_logps = [
torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs
]
sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0)
elif self.use_transformers_paged:
# Re-process inputs for paged generation if needed
# Note: images are already validated and preprocessed above
paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs)
previous_attn = self.model_wrapped.config._attn_implementation
if is_flash_attn_2_available():
self.model_wrapped.config._attn_implementation = "paged_attention"
else:
self.model_wrapped.config._attn_implementation = "sdpa_paged"
with (
profiling_context(self, "transformers.generate_batch"),
unwrap_model_for_generation(
self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model,
torch.no_grad(),
FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
):
# Cast to the appropriate dtype based on training configuration
if self.args.bf16:
unwrapped_model.to(torch.bfloat16)
elif self.args.fp16:
unwrapped_model.to(torch.float16)
with torch.inference_mode():
all_outputs = unwrapped_model.generate_batch(
paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
)
completion_ids = [output.generated_tokens for output in all_outputs.values()]
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
# Restore the original attention implementation, training mode
self.model_wrapped.config._attn_implementation = previous_attn
else:
# Regular generation path
with (
profiling_context(self, "transformers.generate"),
unwrap_model_for_generation(
self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
) as unwrapped_model,
torch.no_grad(),
FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
):
prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
prompt_completion_ids = unwrapped_model.generate(
**prompt_inputs, generation_config=self.generation_config, disable_compile=True
)
# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]
# Mask everything after the first EOS token
is_eos = completion_ids == self.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
# to re-tokenize completions if the reward is computed from tokens.
completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
# Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
completion_lengths = completion_mask.sum(1)
agg_completion_lengths = self.accelerator.gather(completion_lengths)
num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.mask_truncated_completions:
truncated_completions = ~is_eos.any(dim=1)
completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
with torch.no_grad():
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
# samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
# for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
# old_per_token_logps to None.
# When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
# distribution mismatch between vLLM and the training model can be large and harm the training.
generate_every = self.args.steps_per_generation * self.num_iterations # generation frequency
if self.args.gradient_accumulation_steps % generate_every != 0 or (
self.use_vllm and self.vllm_importance_sampling_correction
):
old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
image_split_sizes=image_split_sizes,
)
else:
old_per_token_logps = None
# Compute the importance sampling ratio when using vLLM, to correct for potential distribution mismatch
if self.use_vllm and self.vllm_importance_sampling_correction:
importance_sampling_ratio = torch.exp(old_per_token_logps - sampling_per_token_logps)
importance_sampling_ratio = torch.clamp(
importance_sampling_ratio, max=self.vllm_importance_sampling_cap
)
# Compute the per-token log probabilities for the reference model
if self.beta != 0.0:
if self.ref_model is not None:
ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.ref_model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
image_split_sizes=image_split_sizes,
)
else:
with self.accelerator.unwrap_model(self.model).disable_adapter():
ref_per_token_logps, _ = self._get_per_token_logps_and_entropies(
self.model,
prompt_completion_ids,
attention_mask,
logits_to_keep,
batch_size=batch_size,
pixel_values=prompt_inputs.get("pixel_values"),
image_grid_thw=prompt_inputs.get("image_grid_thw"),
pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"),
image_sizes=prompt_inputs.get("image_sizes"),
image_split_sizes=image_split_sizes,
)
else:
ref_per_token_logps = None
# Decode the generated completions
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
completions = completions_text
# Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
# important because rewards will be normalized per group, and completions are distributed. We will later slice
# rewards_per_func to extract each process's subset.
rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list)
# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
# Compute grouped-wise rewards
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = rewards - mean_grouped_rewards
std_rewards = None
if self.scale_rewards in ["group", "none"]:
# If self.scale_rewards = "none", we'll still log group level std
std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
elif self.scale_rewards == "batch":
# Compute global std
std_rewards = rewards.std().expand_as(rewards)
else:
raise ValueError(
f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
)
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
if self.scale_rewards != "none":
advantages = advantages / (std_rewards + 1e-4)
# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
)
all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
advantages = advantages[process_slice]
if std_rewards is None:
std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
std_rewards = std_rewards[process_slice] if std_rewards is not None else None
# Log the metrics
if mode == "train":
self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]
# Log completion lengths, mean, min, max
self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item())
self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())
# Identify sequences that terminated with EOS and log their lengths
agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found
term_completion_lengths = torch.zeros(1, device=device)
self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
self._metrics[mode][f"rewards/{reward_func_name}/mean"].append(mean_rewards)
std_func_rewards = nanstd(rewards_per_func[:, i]).item()
self._metrics[mode][f"rewards/{reward_func_name}/std"].append(std_func_rewards)
self._metrics[mode]["reward"].append(mean_grouped_rewards.mean().item())
self._metrics[mode]["reward_std"].append(std_rewards.mean().item())
self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item())
# Log prompt and completion texts
self._logs["prompt"].extend(gather_object(prompts_text))
self._logs["completion"].extend(gather_object(completions_text))
for i, name in enumerate(self.reward_func_names):
self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
self._logs["advantages"].extend(all_process_advantages.tolist())
if has_images:
self._logs["image"].extend(gather_object(images))
if self.use_vllm and self.vllm_importance_sampling_correction:
delta = torch.abs(old_per_token_logps - sampling_per_token_logps)
delta = delta[completion_mask.bool()]
mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device)
self._metrics[mode]["sampling/sampling_logp_difference/mean"].append(
self.accelerator.gather(mean_delta).mean().item()
)
self._metrics[mode]["sampling/sampling_logp_difference/max"].append(
self.accelerator.gather(max_delta).max().item()
)
flat_is_ratio = importance_sampling_ratio[completion_mask.bool()]
min_importance_sampling_ratio = (
torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
)
mean_importance_sampling_ratio = (
torch.mean(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
)
max_importance_sampling_ratio = (
torch.max(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device)
)
self._metrics[mode]["sampling/importance_sampling_ratio/min"].append(
nanmin(self.accelerator.gather(min_importance_sampling_ratio)).item()
)
self._metrics[mode]["sampling/importance_sampling_ratio/mean"].append(
self.accelerator.gather(mean_importance_sampling_ratio).nanmean().item()
)
self._metrics[mode]["sampling/importance_sampling_ratio/max"].append(
nanmax(self.accelerator.gather(max_importance_sampling_ratio)).item()
)
outputs_after_sampling_buffer = self.update_with_replay_buffer(
advantages,
std_rewards,
prompt_ids,
prompt_mask,
completion_ids,
completion_mask,
prompt_inputs,
num_items_in_batch,
old_per_token_logps,
ref_per_token_logps,
importance_sampling_ratio if self.use_vllm and self.vllm_importance_sampling_correction else None,
)
if outputs_after_sampling_buffer is not None:
return outputs_after_sampling_buffer
else:
output = {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"advantages": advantages,
"num_items_in_batch": num_items_in_batch,
}
if old_per_token_logps is not None:
output["old_per_token_logps"] = old_per_token_logps
if self.use_vllm and self.vllm_importance_sampling_correction:
output["importance_sampling_ratio"] = importance_sampling_ratio
if ref_per_token_logps is not None:
output["ref_per_token_logps"] = ref_per_token_logps
optional_vision_fields = [
"pixel_values",
"image_grid_thw",
"pixel_attention_mask",
"image_sizes",
]
for field in optional_vision_fields:
if field in prompt_inputs:
output[field] = prompt_inputs[field]
return output
def slice_group_data(
self, data: torch.Tensor, mask: torch.Tensor, group_idx: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Slices the input data and mask tensors for a specific group index. Also trims the sequence length to the
maximum length in the group based on the mask.
Args:
data: Tensor of shape (num_groups * num_generations, seq_length)
mask: Tensor of shape (num_groups * num_generations, seq_length)
group_idx: Index of the group to slice
Returns:
Tuple of (sliced_data, sliced_mask) for the specified group, with sequence length trimmed to the maximum
length in the group.
"""
start_idx = group_idx * self.num_generations
end_idx = (group_idx + 1) * self.num_generations
group_data = data[start_idx:end_idx]
group_mask = mask[start_idx:end_idx]
group_max_len = group_mask.sum(dim=1).max().item()
return group_data[:, :group_max_len], group_mask[:, :group_max_len]
def update_replay_buffer(
self,
groups_with_variance: torch.Tensor,
group_advantages: torch.Tensor,
group_std_rewards: torch.Tensor,
prompt_ids: torch.Tensor,
prompt_mask: torch.Tensor,
completion_ids: torch.Tensor,
completion_mask: torch.Tensor,
prompt_inputs: dict,
optional_vision_fields: list[str] = None,
old_per_token_logps: Optional[torch.Tensor] = None,
ref_per_token_logps: Optional[torch.Tensor] = None,
importance_sampling_ratio: Optional[float] = None,
) -> None:
"""
Update the replay buffer with groups that have reward variance (std > 0).
Args:
groups_with_variance: Boolean tensor indicating which groups have reward variance
group_advantages: Tensor of shape (num_groups, num_generations) containing advantage values
std_rewards: Tensor of shape (num_groups, num_generations) containing std of rewards per group
prompt_ids: Tensor containing prompt token IDs
prompt_mask: Tensor containing prompt attention masks
completion_ids: Tensor containing completion token IDs
completion_mask: Tensor containing completion attention masks
prompt_inputs: Dictionary containing additional prompt inputs (vision data, etc.)
optional_vision_fields: List of optional vision-related fields to include if present in prompt_inputs
old_per_token_logps: Optional tensor of old per-token log probabilities
ref_per_token_logps: Optional tensor of reference per-token log probabilities
importance_sampling_ratio: Optional importance sampling correction ratio
"""
# Prepare buffered outputs for groups with variance
buffered_outputs = []
for _, group_idx in enumerate(groups_with_variance.nonzero(as_tuple=True)[0].unique().tolist()):
group_prompt_ids, group_prompt_mask = self.slice_group_data(prompt_ids, prompt_mask, group_idx)
group_completion_ids, group_completion_mask = self.slice_group_data(
completion_ids, completion_mask, group_idx
)
# Store unpadded data in the buffer
buffered_output = {
"prompt_ids": group_prompt_ids,
"completion_ids": group_completion_ids,
"advantages": group_advantages[group_idx].tolist(),
"prompt_mask": group_prompt_mask,
"completion_mask": group_completion_mask,
}
# Add optional fields if they exist
optional_fields = {
"old_per_token_logps": old_per_token_logps if old_per_token_logps is not None else None,
"ref_per_token_logps": ref_per_token_logps if ref_per_token_logps is not None else None,
}
for field_name, field_data in optional_fields.items():
if field_data is not None:
buffered_output[field_name] = self.slice_group_data(field_data, completion_mask, group_idx)[0]
# Add importance sampling if needed
if self.use_vllm and self.vllm_importance_sampling_correction:
buffered_output["importance_sampling_ratio"] = importance_sampling_ratio
if optional_vision_fields:
# Add vision-related fields if they exist
for field_name in optional_vision_fields:
if field_name in prompt_inputs:
buffered_output[field_name] = self.slice_group_data(
prompt_inputs[field_name], prompt_mask, group_idx
)[0]
buffered_outputs.append(buffered_output)
if groups_with_variance.any():
# Calculate replay buffer scores for groups with variance
replay_buffer_scores = (group_advantages.abs() * group_std_rewards).sum(dim=-1)[groups_with_variance]
# Add all groups to replay buffer at once (batch operation)
self.replay_buffer.add(replay_buffer_scores.tolist(), buffered_outputs)
def sample_from_replay_buffer(
self, num_samples: int, optional_vision_fields: list[str] = None, optional_tensor_fields: list[str] = None
) -> list[dict]:
"""
Sample groups from the replay buffer.
Args:
num_samples: Number of samples to draw from the replay buffer
optional_vision_fields: List of optional vision-related fields to include if present in sampled data
optional_tensor_fields: List of optional tensor fields to include if present in sampled data
Returns:
List of sampled data dictionaries from the replay buffer
"""
sampled = self.replay_buffer.sample(num_samples=num_samples)
# Extract and concatenate sampled data
sampled_data = {
"prompt_ids": [],
"prompt_mask": [],
"completion_ids": [],
"completion_mask": [],
"advantages": [],
}
all_optional_fields = (optional_tensor_fields or []) + (optional_vision_fields or [])
# Initialize containers for optional fields if they exist in sampled data
for field in all_optional_fields:
if sampled and field in sampled[0]:
sampled_data[field] = []
# Extract data from each sampled item
for item in sampled:
# Handle core fields
for key in ["prompt_ids", "prompt_mask", "completion_ids", "completion_mask"]:
sampled_data[key].append(item[key])
# Handle advantages (list, not tensor)
sampled_data["advantages"].append(item["advantages"])
# Handle optional fields
for field in all_optional_fields:
if field in item:
sampled_data[field].append(item[field])
return sampled_data
def update_with_replay_buffer(
self,
group_advantages: torch.Tensor,
group_std_rewards: torch.Tensor,
prompt_ids: torch.Tensor,
prompt_mask: torch.Tensor,
completion_ids: torch.Tensor,
completion_mask: torch.Tensor,
prompt_inputs: dict,
num_items_in_batch: int,
old_per_token_logps: Optional[torch.Tensor] = None,
ref_per_token_logps: Optional[torch.Tensor] = None,
importance_sampling_ratio: Optional[float] = None,
) -> None:
"""
Update current batch data with samples from replay buffer.
Groups with reward variance (std > 0) are added to the replay buffer and then replaced with samples from the
buffer to improve training stability.
Args:
group_advantages: Tensor of shape (num_groups, num_generations) containing advantage values
std_rewards: Tensor of shape (num_groups, num_generations) containing std of rewards per group
prompt_ids: Tensor containing prompt token IDs
prompt_mask: Tensor containing prompt attention masks
completion_ids: Tensor containing completion token IDs
completion_mask: Tensor containing completion attention masks
prompt_inputs: Dictionary containing additional prompt inputs (vision data, etc.)
num_items_in_batch: Number of items in the current batch
old_per_token_logps: Optional tensor of old per-token log probabilities
ref_per_token_logps: Optional tensor of reference per-token log probabilities
importance_sampling_ratio: Optional importance sampling correction ratio
"""
if self.replay_buffer.max_size <= 0:
return
# Groups to consider for adding to the replay buffer
groups_with_variance = group_std_rewards.max(dim=0).values > 0
# Groups to replace from the replay buffer
groups_without_variance = ~groups_with_variance
# Track which optional fields are present in sampled data
optional_tensor_fields = ["old_per_token_logps", "ref_per_token_logps"]
vision_fields = ["pixel_values", "image_grid_thw", "pixel_attention_mask", "image_sizes"]
self.update_replay_buffer(
groups_with_variance,
group_advantages,
group_std_rewards,
prompt_ids,
prompt_mask,
completion_ids,
completion_mask,
prompt_inputs,
vision_fields,
old_per_token_logps,
ref_per_token_logps,
importance_sampling_ratio,
)
# Sample from replay buffer to replace groups with variance
num_groups_to_replace = groups_without_variance.sum().item()
if not num_groups_to_replace:
return
sampled_data = self.sample_from_replay_buffer(
num_samples=num_groups_to_replace,
optional_vision_fields=vision_fields,
optional_tensor_fields=optional_tensor_fields,
)
# Pad sampled data if they are shorter than the current batch sequences
# Or pad the current batch if sampled are longer
current_batch_prompt_seq_len = prompt_ids.size(1)
current_batch_completion_seq_len = completion_ids.size(1)
groups_to_replace_idxs = groups_with_variance.logical_not().nonzero(as_tuple=True)[0].unique().tolist()
# Determine target (max) sequence lengths once
sampled_prompt_lengths = [t.size(1) for t in sampled_data["prompt_ids"]]
sampled_completion_lengths = [t.size(1) for t in sampled_data["completion_ids"]]
target_prompt_len = max([current_batch_prompt_seq_len] + sampled_prompt_lengths)
target_completion_len = max([current_batch_completion_seq_len] + sampled_completion_lengths)
# If any sampled prompt is longer, pad the whole batch prompt tensors once (left padding)
if target_prompt_len > current_batch_prompt_seq_len:
prompt_ids = pad(
list(prompt_ids.unbind(0)),
padding_value=self.pad_token_id,
pad_to_multiple_of=target_prompt_len,
padding_side="left",
)
prompt_mask = pad(
list(prompt_mask.unbind(0)), padding_value=0, pad_to_multiple_of=target_prompt_len, padding_side="left"
)
# If any sampled completion is longer, pad the whole batch completion tensors once (right padding)
if target_completion_len > current_batch_completion_seq_len:
completion_ids = pad(
list(completion_ids.unbind(0)),
padding_value=self.pad_token_id,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
completion_mask = pad(
list(completion_mask.unbind(0)),
padding_value=0,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
if old_per_token_logps is not None:
old_per_token_logps = pad(
list(old_per_token_logps.unbind(0)),
padding_value=0.0,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
if ref_per_token_logps is not None:
ref_per_token_logps = pad(
list(ref_per_token_logps.unbind(0)),
padding_value=0.0,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
# Replace per-group data, padding only sampled groups that are shorter than the target
for i, group_idx in enumerate(groups_to_replace_idxs):
start_idx = group_idx * self.num_generations
end_idx = (group_idx + 1) * self.num_generations
idx_range = slice(start_idx, end_idx)
# Pad sampled prompt to target length if needed
if sampled_data["prompt_ids"][i].size(1) < target_prompt_len:
sampled_data["prompt_ids"][i] = pad(
sampled_data["prompt_ids"][i],
padding_value=self.pad_token_id,
pad_to_multiple_of=target_prompt_len,
padding_side="left",
)
sampled_data["prompt_mask"][i] = pad(
sampled_data["prompt_mask"][i],
padding_value=0,
pad_to_multiple_of=target_prompt_len,
padding_side="left",
)
# Pad sampled completion to target length if needed
if sampled_data["completion_ids"][i].size(1) < target_completion_len:
sampled_data["completion_ids"][i] = pad(
sampled_data["completion_ids"][i],
padding_value=self.pad_token_id,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
sampled_data["completion_mask"][i] = pad(
sampled_data["completion_mask"][i],
padding_value=0,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
if "old_per_token_logps" in sampled_data:
sampled_data["old_per_token_logps"][i] = pad(
sampled_data["old_per_token_logps"][i],
padding_value=0.0,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
if "ref_per_token_logps" in sampled_data:
sampled_data["ref_per_token_logps"][i] = pad(
sampled_data["ref_per_token_logps"][i],
padding_value=0.0,
pad_to_multiple_of=target_completion_len,
padding_side="right",
)
# Assign (replace) group slice
prompt_ids[idx_range] = sampled_data["prompt_ids"][i]
prompt_mask[idx_range] = sampled_data["prompt_mask"][i]
completion_ids[idx_range] = sampled_data["completion_ids"][i]
completion_mask[idx_range] = sampled_data["completion_mask"][i]
group_advantages[group_idx] = sampled_data["advantages"][i]
if "old_per_token_logps" in sampled_data:
old_per_token_logps[idx_range] = sampled_data["old_per_token_logps"][i]
if "ref_per_token_logps" in sampled_data:
ref_per_token_logps[idx_range] = sampled_data["ref_per_token_logps"][i]
for field in vision_fields:
if field in sampled_data and field in prompt_inputs:
prompt_inputs[field][idx_range] = sampled_data[field][i]
# Prepare final outputs after sampling and replacement
outputs_after_sampling_buffer = {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"advantages": group_advantages,
}
# Replace optional tensor fields if they exist
for field in optional_tensor_fields:
if field in sampled_data:
outputs_after_sampling_buffer[field] = (
old_per_token_logps if field == "old_per_token_logps" else ref_per_token_logps
)
# Replace vision fields if they exist
for field in vision_fields:
if field in sampled_data and field in prompt_inputs:
outputs_after_sampling_buffer[field] = prompt_inputs[field]
outputs_after_sampling_buffer["num_items_in_batch"] = num_items_in_batch
if self.use_vllm and self.vllm_importance_sampling_correction:
outputs_after_sampling_buffer["importance_sampling_ratio"] = importance_sampling_ratio
return outputs_after_sampling_buffer

View File

@ -0,0 +1,84 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional, Union
from transformers import Trainer, is_wandb_available
from .utils import generate_model_card, get_comet_experiment_url
if is_wandb_available():
import wandb
class BaseTrainer(Trainer):
_tag_names = []
_name = "Base"
_paper = {}
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Optional[Union[str, list[str]]] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# Normalize tags
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
tags = list(tags)
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name=self._name,
trainer_citation=self._paper.get("citation"),
paper_title=self._paper.get("title"),
paper_id=self._paper.get("id"),
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -40,7 +40,6 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainingArguments,
is_comet_available,
is_sklearn_available,
@ -53,13 +52,12 @@ from transformers.utils import is_peft_available
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
from ..import_utils import is_joblib_available
from ..models import create_reference_model, prepare_deepspeed
from .base_trainer import BaseTrainer
from .bco_config import BCOConfig
from .utils import (
DPODataCollatorWithPadding,
RunningMoments,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
@ -279,7 +277,7 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
return batch
class BCOTrainer(Trainer):
class BCOTrainer(BaseTrainer):
r"""
Initialize BCOTrainer from [BCO](https://huggingface.co/papers/2404.04656) paper.
@ -326,6 +324,19 @@ class BCOTrainer(Trainer):
"""
_tag_names = ["trl", "bco"]
_name = "BCO"
_paper = {
"title": "Binary Classifier Optimization for Large Language Model Alignment",
"id": "2404.04656",
# docstyle-ignore
"citation": textwrap.dedent("""\
@article{jung2024binary,
title = {{Binary Classifier Optimization for Large Language Model Alignment}},
author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
year = 2024,
eprint = {arXiv:2404.04656}
}"""),
}
def __init__(
self,
@ -1497,69 +1508,3 @@ class BCOTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@article{jung2024binary,
title = {{Binary Classifier Optimization for Large Language Model Alignment}},
author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On},
year = 2024,
eprint = {arXiv:2404.04656}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="BCO",
trainer_citation=citation,
paper_title="Binary Classifier Optimization for Large Language Model Alignment",
paper_id="2404.04656",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -13,7 +13,6 @@
# limitations under the License.
import inspect
import os
import random
import textwrap
from collections import defaultdict
@ -38,7 +37,6 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
is_comet_available,
is_wandb_available,
)
@ -47,14 +45,13 @@ from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .base_trainer import BaseTrainer
from .cpo_config import CPOConfig
from .utils import (
DPODataCollatorWithPadding,
add_bos_token_if_needed,
add_eos_token_if_needed,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
@ -73,7 +70,7 @@ if is_wandb_available():
logger = logging.get_logger(__name__)
class CPOTrainer(Trainer):
class CPOTrainer(BaseTrainer):
r"""
Initialize CPOTrainer.
@ -112,6 +109,21 @@ class CPOTrainer(Trainer):
"""
_tag_names = ["trl", "cpo"]
_name = "CPO"
_paper = {
"title": "Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
"id": "2401.08417",
# docstyle-ignore
"citation": textwrap.dedent("""\
@inproceedings{xu2024contrastive,
title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
year = 2024,
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
publisher = {OpenReview.net},
url = {https://openreview.net/forum?id=51iwkioZpn}
}"""),
}
def __init__(
self,
@ -1069,70 +1081,3 @@ class CPOTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@inproceedings{xu2024contrastive,
title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim},
year = 2024,
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
publisher = {OpenReview.net},
url = {https://openreview.net/forum?id=51iwkioZpn}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="CPO",
trainer_citation=citation,
paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation",
paper_id="2401.08417",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -13,7 +13,6 @@
# limitations under the License.
import inspect
import os
import random
import textwrap
import warnings
@ -40,7 +39,6 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
)
from transformers.data.data_collator import DataCollatorMixin
from transformers.integrations import (
@ -56,6 +54,7 @@ from transformers.utils import is_liger_kernel_available, is_peft_available
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import create_reference_model, prepare_deepspeed
from ..models.utils import prepare_fsdp
from .base_trainer import BaseTrainer
from .callbacks import SyncRefModelCallback
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
from .utils import (
@ -66,8 +65,6 @@ from .utils import (
empty_cache,
flush_left,
flush_right,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
pad,
pad_to_length,
@ -184,7 +181,7 @@ class DataCollatorForPreference(DataCollatorMixin):
return output
class DPOTrainer(Trainer):
class DPOTrainer(BaseTrainer):
"""
Trainer for Direct Preference Optimization (DPO) method.
@ -250,6 +247,21 @@ class DPOTrainer(Trainer):
"""
_tag_names = ["trl", "dpo"]
_name = "DPO"
_paper = {
"title": "Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
"id": "2305.18290",
# docstyle-ignore
"citation": textwrap.dedent("""\
@inproceedings{rafailov2023direct,
title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
year = 2023,
booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
}"""),
}
def __init__(
self,
@ -1955,73 +1967,3 @@ class DPOTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent(
"""\
@inproceedings{rafailov2023direct,
title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
year = 2023,
booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
}"""
)
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="DPO",
trainer_citation=citation,
paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
paper_id="2305.18290",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import textwrap
from typing import Any, Callable, Optional, Union
@ -30,7 +29,6 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
is_wandb_available,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
@ -44,17 +42,12 @@ from .utils import (
DataCollatorForChatML,
disable_dropout_in_model,
empty_cache,
generate_model_card,
get_comet_experiment_url,
)
if is_peft_available():
from peft import PeftConfig
if is_wandb_available():
import wandb
if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearJSDLoss
@ -100,6 +93,21 @@ class GKDTrainer(SFTTrainer):
"""
_tag_names = ["trl", "gkd"]
_name = "GKD"
_paper = {
"title": "On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
"id": "2306.13649",
# docstyle-ignore
"citation": textwrap.dedent("""\
@inproceedings{agarwal2024on-policy,
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
year = 2024,
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
publisher = {OpenReview.net},
url = {https://openreview.net/forum?id=3zKtaqxLhW},
}"""),
}
def __init__(
self,
@ -424,71 +432,3 @@ class GKDTrainer(SFTTrainer):
loss = super().training_step(model, inputs, num_items_in_batch)
return loss
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@inproceedings{agarwal2024on-policy,
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
year = 2024,
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
publisher = {OpenReview.net},
url = {https://openreview.net/forum?id=3zKtaqxLhW},
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="GKD",
trainer_citation=citation,
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
paper_id="2306.13649",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -42,7 +42,6 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
is_wandb_available,
)
@ -55,6 +54,7 @@ from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
from ..models.utils import _ForwardRedirection
from .base_trainer import BaseTrainer
from .callbacks import SyncRefModelCallback
from .grpo_config import GRPOConfig
from .utils import (
@ -219,6 +219,20 @@ class GRPOTrainer(Trainer):
"""
_tag_names = ["trl", "grpo"]
_name = "GRPO"
_paper = {
"title": "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
"id": "2402.03300",
# docstyle-ignore
"citation": textwrap.dedent("""\
@article{shao2024deepseekmath,
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
year = 2024,
eprint = {arXiv:2402.03300},
}
"""),
}
def __init__(
self,
@ -1912,72 +1926,3 @@ class GRPOTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent(
"""\
@article{shao2024deepseekmath,
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
year = 2024,
eprint = {arXiv:2402.03300},
}
"""
)
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="GRPO",
trainer_citation=citation,
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
paper_id="2402.03300",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -13,7 +13,6 @@
# limitations under the License.
import inspect
import os
import random
import textwrap
from collections import defaultdict
@ -40,7 +39,6 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
TrainingArguments,
is_comet_available,
@ -52,12 +50,11 @@ from transformers.utils import is_peft_available
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
from ..import_utils import is_liger_kernel_available
from ..models import create_reference_model, prepare_deepspeed
from .base_trainer import BaseTrainer
from .kto_config import KTOConfig
from .utils import (
DPODataCollatorWithPadding,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
@ -275,7 +272,7 @@ def _process_tokens(example: dict[str, Any], model: "PreTrainedModel" = None, **
return batch
class KTOTrainer(Trainer):
class KTOTrainer(BaseTrainer):
r"""
Initialize KTOTrainer.
@ -322,6 +319,19 @@ class KTOTrainer(Trainer):
"""
_tag_names = ["trl", "kto"]
_name = "KTO"
_paper = {
"title": "KTO: Model Alignment as Prospect Theoretic Optimization",
"id": "2402.01306",
# docstyle-ignore
"citation": textwrap.dedent("""\
@article{ethayarajh2024kto,
title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
year = 2024,
eprint = {arXiv:2402.01306},
}"""),
}
def __init__(
self,
@ -1677,69 +1687,3 @@ class KTOTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@article{ethayarajh2024kto,
title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela},
year = 2024,
eprint = {arXiv:2402.01306},
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="KTO",
trainer_citation=citation,
paper_title="KTO: Model Alignment as Prospect Theoretic Optimization",
paper_id="2402.01306",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
from typing import Any, Callable, Optional, Union
@ -28,7 +27,6 @@ from transformers import (
PreTrainedTokenizerBase,
ProcessorMixin,
TrainerCallback,
is_wandb_available,
)
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import OptimizerNames
@ -43,8 +41,6 @@ from .online_dpo_trainer import OnlineDPOTrainer
from .utils import (
SIMPLE_CHAT_TEMPLATE,
empty_cache,
generate_model_card,
get_comet_experiment_url,
get_reward,
selective_log_softmax,
truncate_right,
@ -55,10 +51,6 @@ if is_apex_available():
from apex import amp
if is_wandb_available():
import wandb
if is_peft_available():
from peft import PeftModel
@ -111,6 +103,21 @@ class NashMDTrainer(OnlineDPOTrainer):
"""
_tag_names = ["trl", "nash-md"]
_name = "Nash-MD"
_paper = {
"title": "Nash Learning from Human Feedback",
"id": "2312.00886",
# docstyle-ignore
"citation": textwrap.dedent("""\
@inproceedings{munos2024nash,
title = {{Nash Learning from Human Feedback}},
author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
year = 2024,
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
publisher = {OpenReview.net},
url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
}"""),
}
def __init__(
self,
@ -496,71 +503,3 @@ class NashMDTrainer(OnlineDPOTrainer):
self.accelerator.backward(loss, **kwargs)
return loss.detach() / self.args.gradient_accumulation_steps
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@inproceedings{munos2024nash,
title = {{Nash Learning from Human Feedback}},
author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
year = 2024,
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
publisher = {OpenReview.net},
url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="Nash-MD",
trainer_citation=citation,
paper_title="Nash Learning from Human Feedback",
paper_id="2312.00886",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -44,7 +44,6 @@ from transformers import (
Trainer,
TrainerCallback,
is_apex_available,
is_wandb_available,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from transformers.trainer_utils import EvalPrediction, seed_worker
@ -61,6 +60,7 @@ from ..extras.vllm_client import VLLMClient
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_peft_model
from ..models.utils import unwrap_model_for_generation
from .base_trainer import BaseTrainer
from .judges import BasePairwiseJudge
from .online_dpo_config import OnlineDPOConfig
from .utils import (
@ -69,8 +69,6 @@ from .utils import (
disable_dropout_in_model,
empty_cache,
ensure_master_addr_port,
generate_model_card,
get_comet_experiment_url,
pad,
prepare_deepspeed,
truncate_right,
@ -97,8 +95,6 @@ if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
if is_wandb_available():
import wandb
logger = logging.get_logger(__name__)
@ -107,7 +103,7 @@ logger = logging.get_logger(__name__)
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
class OnlineDPOTrainer(Trainer):
class OnlineDPOTrainer(BaseTrainer):
r"""
Initialize OnlineDPOTrainer.
@ -179,6 +175,19 @@ class OnlineDPOTrainer(Trainer):
"""
_tag_names = ["trl", "online-dpo"]
_name = "Online DPO"
_paper = {
"title": "Direct Language Model Alignment from Online AI Feedback",
"id": "2402.04792",
# docstyle-ignore
"citation": textwrap.dedent("""\
@article{guo2024direct,
title = {{Direct Language Model Alignment from Online AI Feedback}},
author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
year = 2024,
eprint = {arXiv:2402.04792}
}"""),
}
def __init__(
self,
@ -1507,68 +1516,3 @@ class OnlineDPOTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@article{guo2024direct,
title = {{Direct Language Model Alignment from Online AI Feedback}},
author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel},
year = 2024,
eprint = {arXiv:2402.04792}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="Online DPO",
trainer_citation=citation,
paper_title="Direct Language Model Alignment from Online AI Feedback",
paper_id="2402.04792",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -13,7 +13,6 @@
# limitations under the License.
import inspect
import os
import random
import textwrap
from collections import defaultdict
@ -38,7 +37,6 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
is_comet_available,
is_torch_xla_available,
is_wandb_available,
@ -48,14 +46,13 @@ from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from .base_trainer import BaseTrainer
from .orpo_config import ORPOConfig
from .utils import (
DPODataCollatorWithPadding,
add_bos_token_if_needed,
add_eos_token_if_needed,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
pad_to_length,
peft_module_casting_to_bf16,
@ -77,7 +74,7 @@ if is_torch_xla_available():
logger = logging.get_logger(__name__)
class ORPOTrainer(Trainer):
class ORPOTrainer(BaseTrainer):
r"""
Initialize ORPOTrainer.
@ -116,6 +113,19 @@ class ORPOTrainer(Trainer):
"""
_tag_names = ["trl", "orpo"]
_name = "ORPO"
_paper = {
"title": "ORPO: Monolithic Preference Optimization without Reference Model",
"id": "2403.07691",
# docstyle-ignore
"citation": textwrap.dedent("""\
@article{hong2024orpo,
title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
author = {Jiwoo Hong and Noah Lee and James Thorne},
year = 2024,
eprint = {arXiv:2403.07691}
}"""),
}
def __init__(
self,
@ -1031,69 +1041,3 @@ class ORPOTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@article{hong2024orpo,
title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
author = {Jiwoo Hong and Noah Lee and James Thorne},
year = 2024,
eprint = {arXiv:2403.07691}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="ORPO",
trainer_citation=citation,
paper_title="ORPO: Monolithic Preference Optimization without Reference Model",
paper_id="2403.07691",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -37,10 +37,8 @@ from transformers import (
GenerationConfig,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
TrainerControl,
is_wandb_available,
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
@ -50,6 +48,7 @@ from transformers.utils import is_peft_available, is_rich_available
from ..core import masked_mean, masked_whiten
from ..models import create_reference_model
from ..models.utils import unwrap_model_for_generation
from .base_trainer import BaseTrainer
from .ppo_config import PPOConfig
from .utils import (
OnlineTrainerState,
@ -59,8 +58,6 @@ from .utils import (
exact_div,
first_true_indices,
forward,
generate_model_card,
get_comet_experiment_url,
get_reward,
log_table_to_comet_experiment,
peft_module_casting_to_bf16,
@ -74,9 +71,6 @@ from .utils import (
if is_peft_available():
from peft import PeftConfig, PeftModel, get_peft_model
if is_wandb_available():
import wandb
INVALID_LOGPROB = 1.0
@ -97,7 +91,7 @@ class PolicyAndValueWrapper(nn.Module):
return self.policy(**kwargs), logits
class PPOTrainer(Trainer):
class PPOTrainer(BaseTrainer):
"""Trainer for Proximal Policy Optimization (PPO).
For details on PPO, see the paper: [Proximal Policy Optimization
@ -135,6 +129,19 @@ class PPOTrainer(Trainer):
"""
_tag_names = ["trl", "ppo"]
_name = "PPO"
_paper = {
"title": "Fine-Tuning Language Models from Human Preferences",
"id": "1909.08593",
# docstyle-ignore
"citation": textwrap.dedent("""\
@article{mziegler2019fine-tuning,
title = {{Fine-Tuning Language Models from Human Preferences}},
author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
year = 2019,
eprint = {arXiv:1909.08593}
}"""),
}
def __init__(
self,
@ -793,69 +800,3 @@ class PPOTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@article{mziegler2019fine-tuning,
title = {{Fine-Tuning Language Models from Human Preferences}},
author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
year = 2019,
eprint = {arXiv:1909.08593}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="PPO",
trainer_citation=citation,
paper_title="Fine-Tuning Language Models from Human Preferences",
paper_id="1909.08593",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
from itertools import chain
from pathlib import Path
@ -30,26 +29,22 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
is_wandb_available,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
from ..models import prepare_peft_model
from .base_trainer import BaseTrainer
from .prm_config import PRMConfig
from .utils import compute_accuracy, disable_dropout_in_model, generate_model_card
from .utils import compute_accuracy, disable_dropout_in_model
if is_peft_available():
from peft import PeftModel
if is_wandb_available():
import wandb
class PRMTrainer(Trainer):
class PRMTrainer(BaseTrainer):
"""
Initialize PRMTrainer.
@ -88,6 +83,19 @@ class PRMTrainer(Trainer):
"""
_tag_names = ["trl", "prm"]
_name = "PRM"
_paper = {
"title": "Solving math word problems with process-and outcome-based feedback",
"id": "2211.14275",
# docstyle-ignore
"citation": textwrap.dedent("""\
@article{uesato2022solving,
title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
year = 2022,
journal = {arXiv preprint arXiv:2211.14275}
}"""),
}
def __init__(
self,
@ -288,67 +296,3 @@ class PRMTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@article{uesato2022solving,
title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
year = 2022,
journal = {arXiv preprint arXiv:2211.14275}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
trainer_name="PRM",
trainer_citation=citation,
paper_title="Solving math word problems with process-and outcome-based feedback",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from dataclasses import FrozenInstanceError, replace
from pathlib import Path
@ -31,8 +30,6 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
is_wandb_available,
)
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import nested_detach
@ -41,14 +38,13 @@ from transformers.utils import is_peft_available, is_rich_available
from ..data_utils import maybe_apply_chat_template
from ..models import prepare_peft_model
from .base_trainer import BaseTrainer
from .reward_config import RewardConfig
from .utils import (
RewardDataCollatorWithPadding,
compute_accuracy,
decode_and_strip_padding,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
print_rich_table,
)
@ -57,9 +53,6 @@ from .utils import (
if is_peft_available():
from peft import PeftModel
if is_wandb_available():
import wandb
logger = logging.get_logger(__name__)
@ -83,7 +76,7 @@ def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase")
return new_examples
class RewardTrainer(Trainer):
class RewardTrainer(BaseTrainer):
"""
Trainer for custom reward.
@ -123,6 +116,7 @@ class RewardTrainer(Trainer):
"""
_tag_names = ["trl", "reward-trainer"]
_name = "Reward"
def __init__(
self,
@ -357,57 +351,3 @@ class RewardTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="Reward",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -42,7 +42,6 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainerCallback,
is_wandb_available,
)
@ -54,6 +53,7 @@ from ..extras.profiling import profiling_context, profiling_decorator
from ..extras.vllm_client import VLLMClient
from ..import_utils import is_vllm_available
from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
from .base_trainer import BaseTrainer
from .callbacks import SyncRefModelCallback
from .rloo_config import RLOOConfig
from .utils import (
@ -61,8 +61,6 @@ from .utils import (
disable_dropout_in_model,
ensure_master_addr_port,
entropy_from_logits,
generate_model_card,
get_comet_experiment_url,
identity,
nanmax,
nanmin,
@ -96,7 +94,7 @@ logger = logging.get_logger(__name__)
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
class RLOOTrainer(Trainer):
class RLOOTrainer(BaseTrainer):
"""
Trainer for the Reinforce Leave One Out (RLOO) method. This algorithm was initially proposed in the paper [Back to
Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs]
@ -240,6 +238,22 @@ class RLOOTrainer(Trainer):
"""
_tag_names = ["trl", "rloo"]
_name = "RLOO"
_paper = {
"title": "Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
"id": "2402.14740",
# docstyle-ignore
"citation": textwrap.dedent("""\
@inproceedings{ahmadian2024back,
title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
year = 2024,
booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
pages = {12248--12267},
publisher = {Association for Computational Linguistics},
editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
}"""),
}
def __init__(
self,
@ -1635,75 +1649,3 @@ class RLOOTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent(
"""\
@inproceedings{ahmadian2024back,
title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker},
year = 2024,
booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024},
pages = {12248--12267},
publisher = {Association for Computational Linguistics},
editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar},
}
"""
)
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="RLOO",
trainer_citation=citation,
paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs",
paper_id="2402.14740",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -32,9 +32,7 @@ from transformers import (
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
Trainer,
TrainingArguments,
is_wandb_available,
)
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_callback import TrainerCallback
@ -51,13 +49,12 @@ from ..data_utils import (
truncate_dataset,
)
from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
from .base_trainer import BaseTrainer
from .sft_config import SFTConfig
from .utils import (
create_model_from_path,
entropy_from_logits,
flush_left,
generate_model_card,
get_comet_experiment_url,
pad,
selective_log_softmax,
)
@ -66,8 +63,6 @@ from .utils import (
if is_peft_available():
from peft import PeftConfig, PeftModel
if is_wandb_available():
import wandb
logger = logging.get_logger(__name__)
@ -498,7 +493,7 @@ def dft_loss(outputs, labels, num_items_in_batch=None):
return loss
class SFTTrainer(Trainer):
class SFTTrainer(BaseTrainer):
"""
Trainer for Supervised Fine-Tuning (SFT) method.
@ -590,6 +585,7 @@ class SFTTrainer(Trainer):
"""
_tag_names = ["trl", "sft"]
_name = "SFT"
def __init__(
self,
@ -1222,57 +1218,3 @@ class SFTTrainer(Trainer):
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="SFT",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import textwrap
from typing import Any, Callable, Optional, Union
@ -29,7 +28,6 @@ from transformers import (
ProcessorMixin,
TrainerCallback,
is_apex_available,
is_wandb_available,
)
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import OptimizerNames
@ -42,8 +40,6 @@ from .online_dpo_trainer import OnlineDPOTrainer
from .utils import (
SIMPLE_CHAT_TEMPLATE,
empty_cache,
generate_model_card,
get_comet_experiment_url,
get_reward,
selective_log_softmax,
truncate_right,
@ -55,10 +51,6 @@ if is_apex_available():
from apex import amp
if is_wandb_available():
import wandb
if is_peft_available():
from peft import PeftModel
@ -113,6 +105,19 @@ class XPOTrainer(OnlineDPOTrainer):
"""
_tag_names = ["trl", "xpo"]
_name = "XPO"
_paper = {
"title": "Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
"id": "2405.21046",
# docstyle-ignore
"citation": textwrap.dedent("""\
@article{jung2024binary,
title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
year = 2024,
eprint = {arXiv:2405.21046}
}"""),
}
def __init__(
self,
@ -544,69 +549,3 @@ class XPOTrainer(OnlineDPOTrainer):
self.accelerator.backward(loss, **kwargs)
return loss.detach() / self.args.gradient_accumulation_steps
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.
Args:
model_name (`str`, *optional*):
Name of the model.
dataset_name (`str`, *optional*):
Name of the dataset used for training.
tags (`str`, `list[str]`, *optional*):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None
# normalize `tags` to a mutable set
if tags is None:
tags = set()
elif isinstance(tags, str):
tags = {tags}
else:
tags = set(tags)
if hasattr(self.model.config, "unsloth_version"):
tags.add("unsloth")
if "JOB_ID" in os.environ:
tags.add("hf_jobs")
tags.update(self._tag_names)
# docstyle-ignore
citation = textwrap.dedent("""\
@article{jung2024binary,
title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},
author = {Tengyang Xie and Dylan J. Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin},
year = 2024,
eprint = {arXiv:2405.21046}
}""")
model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=list(tags),
wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="XPO",
trainer_citation=citation,
paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF",
paper_id="2405.21046",
)
model_card.save(os.path.join(self.args.output_dir, "README.md"))