Compare commits

...

21 Commits

SHA1 Message Date
ea664ba0df fix hash, disable dropout and other changes 2025-10-15 15:30:44 +00:00
4b92f6473d progress 2025-10-11 22:43:56 +00:00
d2aa8f3019 move to exp 2025-10-11 03:00:41 +00:00
aefc01bb95 disable dropout 2025-10-11 02:22:55 +00:00
f023e65151 fix precompute 2025-10-11 01:20:51 +00:00
36513d0a70 fix default 2025-10-11 00:03:38 +00:00
fcf62d1640 hinge loss 2025-10-10 19:13:06 +00:00
a68ea0f668 Merge branch 'main' into refactor-dpo 2025-10-10 11:16:42 -05:00
d2f5227a16 precompute 2025-10-09 15:43:56 +00:00
6da159efac ref log p 2025-10-08 03:42:05 +00:00
a7aab5aa73 progress 2025-10-08 02:28:38 +00:00
a6941ea06a Merge branch 'main' into refactor-dpo 2025-10-07 16:27:10 -06:00
ea21e98b03 wip 2025-08-17 03:57:11 +00:00
ddd8022b91 drop is_processed 2025-08-17 00:57:53 +00:00
48b242f173 remove formatting_func 2025-08-17 00:56:02 +00:00
4d37a04108 fix 2025-08-15 22:24:09 +00:00
2b911b6f02 remove functionalities from sft 2025-08-15 22:21:31 +00:00
c589fdaf8b Revert "remove functionalities from sft" 2025-08-15 22:12:48 +00:00
This reverts commit 0e3c3e456bafcec3855026df62e26a42990a2545.
0e3c3e456b remove functionalities from sft 2025-08-15 22:02:42 +00:00
a528fc17fb update DPO config 2025-08-15 21:48:01 +00:00
e1c9477659 copy sft 2025-08-15 21:40:14 +00:00
8 changed files with 1150 additions and 18 deletions

View File

@@ -955,6 +955,20 @@ class TestTruncateExamples(TrlTestCase):
dataset = truncate_dataset(dataset, max_length)
assert dataset.to_dict() == expected_output
def test_with_specified_columns(self):
examples = {
"prompt_ids": [[1, 2, 3], [6, 7], [12]],
"completion_ids": [[4, 5], [8, 9, 10, 11], [13, 14]],
}
dataset = Dataset.from_dict(examples)
max_length = 2
expected_output = {
"prompt_ids": [[1, 2], [6, 7], [12]],
"completion_ids": [[4, 5], [8, 9, 10, 11], [13, 14]],
}
dataset = truncate_dataset(dataset, max_length, columns=["prompt_ids"])
assert dataset.to_dict() == expected_output
class TestMaybeConvertToChatML(TrlTestCase):
def test_with_conversations_key(self):

View File

@@ -714,7 +714,10 @@ def pack_dataset(
def truncate_dataset(
dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None
dataset: DatasetType,
max_length: int,
columns: Union[str, list[str]] = "all",
map_kwargs: Optional[dict[str, Any]] = None,
) -> DatasetType:
r"""
Truncate sequences in a dataset to a specified `max_length`.
@@ -724,6 +727,8 @@ def truncate_dataset(
Dataset to truncate.
max_length (`int`):
Maximum sequence length to truncate to.
columns (`str` or `list[str]`, *optional*, defaults to `"all"`):
Which columns to truncate. If `"all"` (default), all columns are truncated.
map_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the dataset's map method when truncating examples.
@@ -749,32 +754,30 @@ def truncate_dataset(
map_kwargs = {}
if isinstance(dataset, Dataset):
# Fast truncation with pyarrow
def truncate(examples):
def truncate(examples, columns):
truncated_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
column = pc.list_slice(column, 0, max_length)
if columns == "all" or column._name in columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
column = pc.list_slice(column, 0, max_length)
truncated_columns.append(column)
return pa.Table.from_arrays(truncated_columns, names=examples.column_names)
dataset = dataset.with_format("arrow")
dataset = dataset.map(truncate, batched=True, **map_kwargs)
dataset = dataset.map(truncate, batched=True, **map_kwargs, fn_kwargs={"columns": columns})
dataset = dataset.with_format(None)
else:
def truncate(examples):
def truncate(examples, columns):
truncated_examples = {}
for key, column in examples.items():
if column and isinstance(column[0], list):
column = [val[:max_length] for val in column]
if columns == "all" or key in columns:
if column and isinstance(column[0], list):
column = [val[:max_length] for val in column]
truncated_examples[key] = column
return truncated_examples
dataset = dataset.map(
truncate,
batched=True,
**map_kwargs,
)
dataset = dataset.map(truncate, batched=True, **map_kwargs, fn_kwargs={"columns": columns})
return dataset
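For reference, a minimal sketch of the new `columns` argument (import path assumed to be `trl.data_utils`, matching the trainer imports further down); only the listed columns are sliced, the rest pass through unchanged:

```python
from datasets import Dataset
from trl.data_utils import truncate_dataset  # import path assumed

dataset = Dataset.from_dict(
    {
        "prompt_ids": [[1, 2, 3], [6, 7]],
        "completion_ids": [[4, 5], [8, 9, 10, 11]],
    }
)
# Truncate only the prompt column to 2 tokens; completion_ids is left untouched.
dataset = truncate_dataset(dataset, max_length=2, columns=["prompt_ids"])
print(dataset.to_dict())
# {'prompt_ids': [[1, 2], [6, 7]], 'completion_ids': [[4, 5], [8, 9, 10, 11]]}
```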

View File

@@ -0,0 +1,16 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dpo_config import DPOConfig
from .dpo_trainer import DPOTrainer

View File

@@ -0,0 +1,212 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from transformers import TrainingArguments
@dataclass
class DPOConfig(TrainingArguments):
r"""
Configuration class for the [`DPOTrainer`].
This class includes only the parameters that are specific to DPO training. For a full list of training arguments,
please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
differ from those in [`~transformers.TrainingArguments`].
Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
> Parameters that control the model and reference model
model_init_kwargs (`dict[str, Any]`, *optional*):
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
argument of the [`DPOTrainer`] is provided as a string.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model and reference model.
> Parameters that control the data preprocessing
dataset_num_proc (`int`, *optional*):
Number of processes to use for processing the dataset.
pad_token (`str`, *optional*):
Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
it falls back to `processing_class.eos_token`.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
Maximum length of the prompt part of the sequence. If `None`, no truncation is applied.
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the completion part of the sequence. If `None`, no truncation is applied.
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
If `None`, no truncation is applied.
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
`"keep_start"`.
padding_free (`bool`, *optional*, defaults to `False`):
Whether to perform forward passes without padding by flattening all sequences in the batch into a single
continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
supported with FlashAttention 2 or 3, which can efficiently handle the flattened batch structure.
pad_to_multiple_of (`int`, *optional*):
If set, the sequences will be padded to a multiple of this value.
precompute_ref_log_probs (`bool`, *optional*, defaults to `True`):
Whether to precompute the reference model log probabilities for the entire training dataset before
training. This saves memory during training, as the reference model does not need to be kept in
memory.
> Parameters that control the training
loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`):
Type of loss to use. Possible values are:
- `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
- `"hinge"`: hinge loss on the normalized likelihood from the
[SLiC](https://huggingface.co/papers/2305.10425) paper.
beta (`float`, *optional*, defaults to `0.1`):
Parameter controlling the deviation from the reference model. Higher β means less deviation from the
reference model.
activation_offloading (`bool`, *optional*, defaults to `False`):
Whether to offload the activations to the CPU.
"""
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
# Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
default=1e-6,
metadata={"help": "The initial learning rate for AdamW."},
)
logging_steps: float = field(
default=10,
metadata={
"help": "Log every X updates steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
"will be interpreted as ratio of total training steps."
},
)
gradient_checkpointing: bool = field(
default=True,
metadata={
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
"architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if "
"`fp16` is not set."
},
)
# Parameters that control the model
model_init_kwargs: Optional[dict[str, Any]] = field(
default=None,
metadata={
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
"the `DPOTrainer` is provided as a string."
},
)
disable_dropout: bool = field(
default=True,
metadata={"help": "Whether to disable dropout in the model and reference model."},
)
# Parameters that control the data preprocessing
dataset_num_proc: Optional[int] = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
pad_token: Optional[str] = field(
default=None,
metadata={
"help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
"is also `None`, it falls back to `processing_class.eos_token`."
},
)
max_prompt_length: Optional[int] = field(
default=512,
metadata={"help": "Maximum length of the prompt part of the sequence. If `None`, no truncation is applied."},
)
max_completion_length: Optional[int] = field(
default=None,
metadata={
"help": "Maximum length of the completion part of the sequence. If `None`, no truncation is applied."
},
)
max_length: Optional[int] = field(
default=1024,
metadata={
"help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from "
"the right. If `None`, no truncation is applied."
},
)
truncation_mode: str = field(
default="keep_end",
metadata={
"help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` "
"and `'keep_start'`.",
"choices": ["keep_end", "keep_start"],
},
)
padding_free: bool = field(
default=False,
metadata={
"help": "Whether to perform forward passes without padding by flattening all sequences in the batch into "
"a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this "
"is only supported with FlashAttention 2 or 3, which can efficiently handle the flattened batch "
"structure."
},
)
pad_to_multiple_of: Optional[int] = field(
default=None,
metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
)
precompute_ref_log_probs: bool = field(
default=True,
metadata={
"help": "Whether to precompute the reference model log probabilities for the entire training dataset "
"before training. This saves memory during training, as the reference model does not need to be "
"kept in memory."
},
)
# Parameters that control the training
loss_type: list[str] = field(
default_factory=lambda: ["sigmoid"],
metadata={
"help": "Type of loss to use. Possible values are: `'sigmoid'`, `'hinge'`.",
},
)
beta: float = field(
default=0.1,
metadata={
"help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from "
"the reference model."
},
)
activation_offloading: bool = field(
default=False,
metadata={"help": "Whether to offload the activations to the CPU."},
)
def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
# Normalize loss_type to string format for internal use
if hasattr(self.loss_type, "__len__") and len(self.loss_type) == 1:
self.loss_type = self.loss_type[0]
super().__post_init__()
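A minimal usage sketch of the config above; the import path is hypothetical for this refactor and the values are illustrative, not recommendations:

```python
# Hypothetical import path; the refactored trainer lives under an experimental module in this PR.
from trl.experimental.new_dpo import DPOConfig

config = DPOConfig(
    output_dir="Qwen2-0.5B-DPO",
    beta=0.1,                       # deviation from the reference model
    loss_type=["sigmoid"],          # or ["hinge"]; multiple losses are summed in compute_loss
    max_prompt_length=512,
    max_completion_length=None,
    max_length=1024,
    truncation_mode="keep_end",
    precompute_ref_log_probs=True,  # reference logps are computed once before training starts
)
print(config.loss_type)  # "sigmoid" -- a single-element list is normalized to a string in __post_init__
```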

View File

@@ -0,0 +1,871 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState, logging
from accelerate.utils import is_peft_model
from datasets import Dataset, IterableDataset
from datasets.fingerprint import Hasher
from transformers import (
AutoConfig,
AutoProcessor,
BaseImageProcessor,
DataCollator,
FeatureExtractionMixin,
PreTrainedModel,
PreTrainedTokenizerBase,
ProcessorMixin,
TrainerCallback,
)
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
from ...data_utils import extract_prompt, is_conversational, prepare_multimodal_messages, truncate_dataset
from ...models import get_act_offloading_ctx_manager, prepare_deepspeed, prepare_fsdp, prepare_peft_model
from ...trainer.base_trainer import BaseTrainer
from ...trainer.utils import (
disable_dropout_in_model,
entropy_from_logits,
flush_left,
flush_right,
hash_module,
pad,
remove_none_values,
selective_log_softmax,
)
from .dpo_config import DPOConfig
if is_peft_available():
from peft import PeftConfig, PeftModel
logger = logging.get_logger(__name__)
FLASH_ATTENTION_VARIANTS = {
"flash_attention_2",
"flash_attention_3",
"kernels-community/flash-attn",
"kernels-community/vllm-flash-attn3",
"kernels-community/flash-attn3",
}
def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]:
return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names
@dataclass
class DataCollatorForPreference(DataCollatorMixin):
"""
Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch.
This collator expects each example in the input list to be a dictionary containing the keys `"prompt_ids"`,
`"chosen_ids"` and `"rejected_ids"`. The collator returns a dictionary containing the following keys:
- `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch
corresponds to the prompt + chosen sequences and the second half to the prompt + rejected sequences.
- `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
- `"completion_mask"`: Tensor indicating the positions of the completion tokens, padded to the maximum length of
the batch.
Optionally, the examples can contain `"ref_chosen_logps"` and `"ref_rejected_logps"` keys, in which case the
returned dictionary will also contain these keys with the corresponding tensors.
Args:
pad_token_id (`int`):
Token ID to use for padding.
pad_to_multiple_of (`int`, *optional*):
If set, the sequences will be padded to a multiple of this value.
return_tensors (`str`, *optional*, defaults to `"pt"`):
Type of Tensor to return. Only `"pt"` is currently supported.
Examples:
```python
>>> from trl.trainer.dpo_trainer import DataCollatorForPreference
>>> collator = DataCollatorForPreference(pad_token_id=0)
>>> examples = [{"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]}]
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3, 4, 5],
[ 1, 2, 3, 6, 0]]),
'attention_mask': tensor([[1, 1, 1, 1, 1],
[1, 1, 1, 1, 0]]),
'completion_mask': tensor([[0, 0, 0, 1, 1],
[0, 0, 0, 1, 0]])}
```
"""
pad_token_id: int
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
prompt_chosen_ids = [example["prompt_ids"] + example["chosen_ids"] for example in examples]
prompt_rejected_ids = [example["prompt_ids"] + example["rejected_ids"] for example in examples]
chosen_attention_mask = [[1] * len(example["prompt_ids"] + example["chosen_ids"]) for example in examples]
rejected_attention_mask = [[1] * len(example["prompt_ids"] + example["rejected_ids"]) for example in examples]
chosen_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["chosen_ids"]) for example in examples]
rejected_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["rejected_ids"]) for example in examples]
input_ids = prompt_chosen_ids + prompt_rejected_ids
attention_mask = chosen_attention_mask + rejected_attention_mask
completion_mask = chosen_mask + rejected_mask
# Convert to tensor
input_ids = [torch.tensor(ids) for ids in input_ids]
attention_mask = [torch.tensor(m, dtype=torch.long) for m in attention_mask]
completion_mask = [torch.tensor(m, dtype=torch.long) for m in completion_mask]
if "ref_chosen_logps" in examples[0]:
ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples])
if "ref_rejected_logps" in examples[0]:
ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples])
# Pad
output = {}
output["input_ids"] = pad(
input_ids,
padding_value=self.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
output["attention_mask"] = pad(
attention_mask,
padding_value=0,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
output["completion_mask"] = pad(
completion_mask,
padding_value=0,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
if "ref_chosen_logps" in examples[0]:
output["ref_chosen_logps"] = ref_chosen_logps
if "ref_rejected_logps" in examples[0]:
output["ref_rejected_logps"] = ref_rejected_logps
return output
class DPOTrainer(BaseTrainer):
"""
Trainer for Direct Preference Optimization (DPO) method.
This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
Example:
```python
from datasets import load_dataset
from trl import DPOTrainer
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = DPOTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
trainer.train()
```
Args:
model (`Union[str, PreTrainedModel]`):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
path to a *directory* containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model
config) with the keyword arguments in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object.
If you're training a model with an MoE architecture and want to include the load balancing/auxiliary loss
as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`.
args ([`DPOConfig`], *optional*):
Configuration for this trainer. If `None`, a default configuration is used.
data_collator ([`~transformers.DataCollator`], *optional*):
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
Will default to [`~trainer.dpo_trainer.DataCollatorForPreference`] if the model is a language model.
Vision-language datasets are not yet supported by a default collator.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. DPO requires a [preference](dataset_formats#preference) type dataset, with either
an implicit or an explicit prompt. The format of the samples can be either:
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set.
If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default.
compute_loss_func (`Callable`, *optional*):
A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss
function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618)
used by [`Trainer`].
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
The function that will be used to compute metrics at evaluation. Must take a
[`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing
[`DPOConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a boolean
`compute_result` argument. This will be triggered after the last eval batch to signal that the function
needs to calculate and return the global summary statistics rather than accumulating the batch-level
statistics.
callbacks (list of [`~transformers.TrainerCallback`], *optional*):
List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
in [here](https://huggingface.co/docs/transformers/main_classes/callback).
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
method.
optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
`args`. Incompatible with the `optimizers` argument.
Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
initializing the Trainer.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
A function that preprocesses the logits right before caching them at each evaluation step. Must take two
tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
by this function will be reflected in the predictions received by `compute_metrics`.
Note that the labels (second parameter) will be `None` if the dataset does not have them.
peft_config ([`~peft.PeftConfig`], *optional*):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
"""
_tag_names = ["trl", "dpo"]
_name = "DPO"
_paper = {
"title": "Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
"id": "2305.18290",
# docstyle-ignore
"citation": textwrap.dedent("""\
@inproceedings{rafailov2023direct,
title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
year = 2023,
booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
}"""),
}
def __init__(
self,
model: Union[str, PreTrainedModel],
args: Optional[DPOConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
compute_loss_func: Optional[Callable] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional["PeftConfig"] = None,
):
# Args
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
args = DPOConfig(f"{model_name}-DPO")
# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
dtype = model_init_kwargs.get("dtype")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass # dtype is already a torch.dtype or "auto" or None
elif isinstance(dtype, str): # it's a str, but not "auto"
dtype = getattr(torch, dtype)
model_init_kwargs["dtype"] = dtype
else:
raise ValueError(
"Invalid `dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing "
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
# Load the config to resolve the model architecture, then instantiate the model
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **model_init_kwargs)
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
logger.warning(
"You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
"The `model_init_kwargs` will be ignored."
)
if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
model = prepare_peft_model(model, peft_config, args)
# Disable dropout in the model
if args.disable_dropout:
disable_dropout_in_model(model)
# Processing class
if processing_class is None:
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
# Handle pad token for processors or tokenizers
if isinstance(processing_class, ProcessorMixin):
tokenizer = processing_class.tokenizer
self._is_vlm = True
elif isinstance(processing_class, PreTrainedTokenizerBase):
tokenizer = processing_class
self._is_vlm = False
else:
raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")
if self._is_vlm and args.padding_free:
raise ValueError(
"Padding-free training is yet not supported for vision-language models. Please set "
"`padding_free=False` in the `DPOConfig`."
)
# Data collator
self.padding_free = args.padding_free
use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS
if self.padding_free:
raise NotImplementedError("Padding-free training is not yet implemented.")
if data_collator is not None:
raise ValueError("Passing a custom data collator is not supported when using padding-free.")
if not use_flash_attention:
logger.warning(
"Padding-free training is enabled, but the attention implementation is not set to a supported "
"flash attention variant. Padding-free training flattens batches into a single sequence, and only "
"the following implementations are known to reliably support this: "
f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to "
"unexpected behavior. To ensure compatibility, set `attn_implementation` in the model "
"configuration to one of these supported options or verify that your attention mechanism can "
"handle flattened sequences."
)
if args.per_device_train_batch_size == 1:
logger.warning(
"You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "
"of 1 negates the benefits of padding-free training. Please consider increasing the batch size "
"to at least 2."
)
dataset_sample = next(iter(train_dataset))
self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
if self._is_vision_dataset and not self._is_vlm:
raise ValueError(
"The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
"model does not seem to be a vision-language model. Please check your model and dataset."
)
if data_collator is None and not self._is_vision_dataset:
# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
if pad_token_id is None:
raise ValueError(
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)
data_collator = DataCollatorForPreference(
pad_token_id=pad_token_id,
pad_to_multiple_of=args.pad_to_multiple_of,
)
elif data_collator is None and self._is_vision_dataset:
raise NotImplementedError("VLM training is not yet implemented.")
# Training arguments
self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type]
self.beta = args.beta
# Dataset
# Skip dataset preparation if it's a VLM, where preprocessing (e.g., image-to-pixel conversion) is too costly
# and done on the fly instead.
skip_prepare_dataset = self._is_vision_dataset
if not skip_prepare_dataset:
train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
eval_dataset = {
key: self._prepare_dataset(dataset, processing_class, args, key)
for key, dataset in eval_dataset.items()
}
else:
eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
# Initialize the metrics
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
self._total_train_tokens = 0
# Initialize the Trainer. Parent class will handle:
# - DeepSpeed configuration (through create_accelerator_and_postprocess)
# - FSDP setup
# - Distributed training setup
# - Optimizer and scheduler creation
super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
compute_loss_func=compute_loss_func,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
# Reference model
self.beta = args.beta
if self.beta == 0.0:
# If beta is 0.0, the reference model is not needed
self.ref_model = None
elif is_peft_model(model):
# If PEFT is used, the reference model is not needed since the adapter can be disabled
# to revert to the initial model.
self.ref_model = None
else:
# For deepspeed, fsdp or non-distributed models, create a reference model from scratch
config = AutoConfig.from_pretrained(model_id)
architecture = getattr(transformers, config.architectures[0])
self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
# Disable dropout in the models
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
# self.model_accepts_loss_kwargs to False to enable scaling.
self.model_accepts_loss_kwargs = False
# Add tags for model
self.model.add_model_tags(self._tag_names)
if self.ref_model is not None:
if self.is_deepspeed_enabled:
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif self.is_fsdp_enabled:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
# Initialize activation offloading context
if self.args.activation_offloading:
self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
else:
self.maybe_activation_offload_context = contextlib.nullcontext()
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
# sampled data.
if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform`
dataset = dataset.with_transform(remove_none_values)
# Build the kwargs for the `map` function
map_kwargs = {}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc
with PartialState().main_process_first():
# Extract the prompt if needed
first_example = next(iter(dataset))
if "prompt" not in first_example:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt from {dataset_name} dataset"
dataset = dataset.map(extract_prompt, **map_kwargs)
# Apply the chat template if needed
first_example = next(iter(dataset))
if not is_conversational(first_example):
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"
def add_eos(example, eos_token):
if not example["chosen"].endswith(eos_token):
example["chosen"] = example["chosen"] + eos_token
if not example["rejected"].endswith(eos_token):
example["rejected"] = example["rejected"] + eos_token
return example
dataset = dataset.map(add_eos, fn_kwargs={"eos_token": processing_class.eos_token}, **map_kwargs)
# Tokenize the dataset
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
def tokenize_fn(example, processing_class):
output = {}
if is_conversational(example):
if self._is_vlm:
prepare_multimodal_messages(example["prompt"], num_images=0)
prepare_multimodal_messages(example["chosen"], num_images=0)
prepare_multimodal_messages(example["rejected"], num_images=0)
prompt_ids = processing_class.apply_chat_template(
example["prompt"],
tokenize=True,
add_generation_prompt=True,
tools=example.get("tools"),
**example.get("chat_template_kwargs", {}),
)
prompt_chosen_processed = processing_class.apply_chat_template(
example["prompt"] + example["chosen"],
return_dict=True,
tokenize=True,
tools=example.get("tools"),
**example.get("chat_template_kwargs", {}),
)
prompt_rejected_processed = processing_class.apply_chat_template(
example["prompt"] + example["rejected"],
return_dict=True,
tokenize=True,
tools=example.get("tools"),
**example.get("chat_template_kwargs", {}),
)
# Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
# even for single examples, while for LLMs it returns lists of ints.
prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
prompt_chosen_processed = {
k: v[0] if isinstance(v[0], list) else v for k, v in prompt_chosen_processed.items()
}
prompt_rejected_processed = {
k: v[0] if isinstance(v[0], list) else v for k, v in prompt_rejected_processed.items()
}
prompt_chosen_ids = prompt_chosen_processed["input_ids"]
prompt_rejected_ids = prompt_rejected_processed["input_ids"]
else:
prompt_ids = processing_class(text=example["prompt"])["input_ids"]
prompt_chosen_ids = processing_class(text=example["prompt"] + example["chosen"])["input_ids"]
prompt_rejected_ids = processing_class(text=example["prompt"] + example["rejected"])["input_ids"]
# Check if the tokenized prompt starts with the tokenized prompt+completion
if not prompt_chosen_ids[: len(prompt_ids)] == prompt_ids:
logger.warning(
"Mismatch between tokenized prompt and the start of tokenized prompt+chosen. "
"This may be due to unexpected tokenizer behavior, whitespace issues, or special "
"token handling. Verify that the tokenizer is processing text consistently."
)
if not prompt_rejected_ids[: len(prompt_ids)] == prompt_ids:
logger.warning(
"Mismatch between tokenized prompt and the start of tokenized prompt+rejected. "
"This may be due to unexpected tokenizer behavior, whitespace issues, or special "
"token handling. Verify that the tokenizer is processing text consistently."
)
output["prompt_ids"] = prompt_ids
output["chosen_ids"] = prompt_chosen_ids[len(prompt_ids) :]
output["rejected_ids"] = prompt_rejected_ids[len(prompt_ids) :]
return output
dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs)
# Truncate
if args.max_prompt_length is not None:
raise NotImplementedError("Prompt truncation is not yet implemented.")
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating prompt in {dataset_name} dataset"
dataset = truncate_dataset(
dataset, args.max_prompt_length, columns=["prompt_ids"], map_kwargs=map_kwargs
)
if args.max_completion_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating completions in {dataset_name} dataset"
dataset = truncate_dataset(
dataset, args.max_completion_length, columns=["chosen_ids", "rejected_ids"], map_kwargs=map_kwargs
)
# For Liger kernel, ensure only the essential columns
if args.use_liger_kernel:
collator_expected_keys = {"input_ids", "completion_mask"}
column_names = get_dataset_column_names(dataset)
dataset = dataset.select_columns(collator_expected_keys.intersection(column_names))
return dataset
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
# and "attention_mask").
if self._signature_columns is None:
if self._is_vision_dataset:
self._signature_columns = ["prompt", "chosen", "rejected"]
else:
self._signature_columns = [
"prompt_ids",
"chosen_ids",
"rejected_ids",
"ref_chosen_logps",
"ref_rejected_logps",
]
def train(self, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None, **kwargs):
if self.args.precompute_ref_log_probs:
self.train_dataset = self._precompute_ref_logps(
self.train_dataset, self.args.per_device_train_batch_size, "train"
)
if self.eval_dataset is not None:
if isinstance(self.eval_dataset, dict):
self.eval_dataset = {
key: self._precompute_ref_logps(dataset, self.args.per_device_eval_batch_size, key)
for key, dataset in self.eval_dataset.items()
}
else:
self.eval_dataset = self._precompute_ref_logps(
self.eval_dataset, self.args.per_device_eval_batch_size, "eval"
)
return super().train(resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
def _precompute_ref_logps(
self, dataset: Union[Dataset, IterableDataset], batch_size: int, dataset_name: str
) -> Union[Dataset, IterableDataset]:
def compute_ref_logps(examples, collator, max_length, truncation_mode):
examples = [dict(zip(examples.keys(), v)) for v in zip(*examples.values())] # dict[list] to list[dict]
inputs = collator(examples)
input_ids = inputs["input_ids"].to(self.model.device)
attention_mask = inputs["attention_mask"].to(self.model.device)
completion_mask = inputs["completion_mask"].to(self.model.device)
# Truncate inputs
if max_length is not None:
if truncation_mode == "keep_start":
input_ids = input_ids[:, :max_length]
attention_mask = attention_mask[:, :max_length]
completion_mask = completion_mask[:, :max_length]
elif truncation_mode == "keep_end":
attention_mask, input_ids, completion_mask = flush_right(
attention_mask, input_ids, completion_mask
)
input_ids = input_ids[:, -max_length:]
attention_mask = attention_mask[:, -max_length:]
completion_mask = completion_mask[:, -max_length:]
attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
else:
raise ValueError(
f"Unsupported truncation mode: {truncation_mode}, expected 'keep_start' or 'keep_end'"
)
outputs = self.model(input_ids, attention_mask=attention_mask, use_cache=False)
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()
per_token_logps = selective_log_softmax(shift_logits, shift_labels)
per_token_logps[shift_completion_mask == 0] = 0.0 # mask out non-completion tokens
logps = per_token_logps.sum(dim=1) # sum over sequence length
chosen_logps, rejected_logps = logps.chunk(2, dim=0) # batch is [chosen, rejected]
return {"ref_chosen_logps": chosen_logps.tolist(), "ref_rejected_logps": rejected_logps.tolist()}
# Normally, `map` creates a fingerprint based on the transform function and its arguments. However, the model
# produces a different fingerprint on each run, which prevents the cache from being used. To fix this, we
# manually compute a stable fingerprint for the model instead.
fn_kwargs = {
"collator": self.data_collator,
"max_length": self.args.max_length,
"truncation_mode": self.args.truncation_mode,
}
model_hash = hash_module(self.model)
dataset = dataset.map(
compute_ref_logps,
batched=True,
batch_size=batch_size,
fn_kwargs=fn_kwargs,
desc=f"Computing reference logps for {dataset_name} dataset",
new_fingerprint=Hasher.hash((dataset._fingerprint, fn_kwargs, model_hash)),
)
return dataset
def compute_loss(
self,
model: nn.Module,
inputs: dict[str, Union[torch.Tensor, Any]],
return_outputs: bool = False,
num_items_in_batch: Optional[torch.Tensor] = None,
):
"""
Compute training loss and additionally compute token accuracies
"""
mode = "train" if self.model.training else "eval"
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
completion_mask = inputs["completion_mask"]
# Truncate inputs
if self.args.max_length is not None:
if self.args.truncation_mode == "keep_start":
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
completion_mask = completion_mask[:, : self.args.max_length]
elif self.args.truncation_mode == "keep_end":
attention_mask, input_ids, completion_mask = flush_right(attention_mask, input_ids, completion_mask)
input_ids = input_ids[:, -self.args.max_length :]
attention_mask = attention_mask[:, -self.args.max_length :]
completion_mask = completion_mask[:, -self.args.max_length :]
attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
else:
raise ValueError(
f"Unsupported truncation mode: {self.args.truncation_mode}, expected 'keep_start' or 'keep_end'"
)
outputs = model(input_ids, attention_mask=attention_mask, use_cache=False)
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()
per_token_logps = selective_log_softmax(shift_logits, shift_labels)
per_token_logps[shift_completion_mask == 0] = 0.0 # mask out non-completion tokens
logps = per_token_logps.sum(dim=1) # sum over sequence length
chosen_logps, rejected_logps = logps.chunk(2, dim=0) # batch is [chosen, rejected]
ref_chosen_logps, ref_rejected_logps = inputs["ref_chosen_logps"], inputs["ref_rejected_logps"]
# Get the log ratios for the chosen and rejected responses
chosen_logratios = chosen_logps - ref_chosen_logps
rejected_logratios = rejected_logps - ref_rejected_logps
loss = 0
for loss_type in self.loss_type:
if loss_type == "sigmoid":
per_sequence_loss = -F.logsigmoid(self.beta * chosen_logratios - self.beta * rejected_logratios)
elif loss_type == "hinge":
per_sequence_loss = torch.relu(1 - (self.beta * chosen_logratios - self.beta * rejected_logratios))
loss += per_sequence_loss.mean()
# Log the metrics
# Entropy
per_token_entropy = entropy_from_logits(shift_logits.detach())
entropy = per_token_entropy[shift_completion_mask.bool()].mean()
entropy = self.accelerator.gather_for_metrics(entropy).mean().item()
self._metrics[mode]["entropy"].append(entropy)
# Number of tokens
if mode == "train":
num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
self._total_train_tokens += num_tokens_in_batch
self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
# Average logits for chosen and rejected completions
chosen_logits, rejected_logits = shift_logits.detach().chunk(2, dim=0)
chosen_mask, rejected_mask = shift_completion_mask.chunk(2, dim=0)
total_chosen_logits = chosen_logits[chosen_mask.bool()].mean(-1)
total_chosen_tokens = chosen_mask.sum()
total_rejected_logits = rejected_logits[rejected_mask.bool()].mean(-1)
total_rejected_tokens = rejected_mask.sum()
total_chosen_logits = self.accelerator.gather_for_metrics(total_chosen_logits).sum().item()
total_chosen_tokens = self.accelerator.gather_for_metrics(total_chosen_tokens).sum().item()
total_rejected_logits = self.accelerator.gather_for_metrics(total_rejected_logits).sum().item()
total_rejected_tokens = self.accelerator.gather_for_metrics(total_rejected_tokens).sum().item()
avg_chosen_logits = total_chosen_logits / total_chosen_tokens if total_chosen_tokens > 0 else 0.0
avg_rejected_logits = total_rejected_logits / total_rejected_tokens if total_rejected_tokens > 0 else 0.0
self._metrics[mode]["logits/chosen"].append(avg_chosen_logits)
self._metrics[mode]["logits/rejected"].append(avg_rejected_logits)
# Token accuracy for the chosen completions
predictions = chosen_logits.argmax(dim=-1)
chosen_mask = shift_completion_mask[: len(shift_completion_mask) // 2].bool()
chosen_labels = shift_labels[: len(shift_labels) // 2]
correct_predictions = (predictions == chosen_labels) & chosen_mask
total_tokens = chosen_mask.sum()
correct_tokens = correct_predictions.sum()
correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
total_tokens = self.accelerator.gather_for_metrics(total_tokens)
total_sum = total_tokens.sum()
accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
self._metrics[mode]["mean_token_accuracy"].append(accuracy)
# Rewards for chosen and rejected completions
chosen_rewards = self.beta * (chosen_logps.detach() - ref_chosen_logps)
rejected_rewards = self.beta * (rejected_logps.detach() - ref_rejected_logps)
agg_chosen_rewards = self.accelerator.gather(chosen_rewards)
agg_rejected_rewards = self.accelerator.gather(rejected_rewards)
self._metrics[mode]["rewards/chosen"].append(agg_chosen_rewards.mean().item())
self._metrics[mode]["rewards/rejected"].append(agg_rejected_rewards.mean().item())
# Reward accuracy
reward_accuracies = (chosen_rewards > rejected_rewards).float()
agg_reward_accuracies = self.accelerator.gather(reward_accuracies)
self._metrics[mode]["rewards/accuracies"].append(agg_reward_accuracies.mean().item())
# Reward margins
margins = chosen_rewards - rejected_rewards
agg_margins = self.accelerator.gather(margins)
self._metrics[mode]["rewards/margins"].append(agg_margins.mean().item())
# Average log probabilities for chosen and rejected completions
self._metrics[mode]["logps/chosen"].append(self.accelerator.gather(chosen_logps).mean().item())
self._metrics[mode]["logps/rejected"].append(self.accelerator.gather(rejected_logps).mean().item())
return (loss, outputs) if return_outputs else loss
# Override training step to add activation offloading context.
def training_step(self, *args, **kwargs):
with self.maybe_activation_offload_context:
return super().training_step(*args, **kwargs)
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
mode = "train" if self.model.training else "eval"
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
if mode == "eval":
metrics = {f"eval_{key}": val for key, val in metrics.items()}
logs.update(metrics)
super().log(logs, start_time)
self._metrics[mode].clear()
# Ensure the model card is saved along with the checkpoint
def _save_checkpoint(self, model, trial):
if self.args.hub_model_id is None:
model_name = Path(self.args.output_dir).name
else:
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name)
super()._save_checkpoint(model, trial)
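For intuition, a standalone sketch of the per-sequence losses summed in `compute_loss` above, using toy log-probabilities and `beta=0.1`; this mirrors the sigmoid/hinge branches but is not the trainer code itself:

```python
import torch
import torch.nn.functional as F

beta = 0.1
chosen_logps = torch.tensor([-12.0, -20.0])       # log p(chosen) under the policy
rejected_logps = torch.tensor([-15.0, -18.0])     # log p(rejected) under the policy
ref_chosen_logps = torch.tensor([-13.0, -19.0])   # same quantities under the reference model
ref_rejected_logps = torch.tensor([-14.0, -19.0])

chosen_logratios = chosen_logps - ref_chosen_logps
rejected_logratios = rejected_logps - ref_rejected_logps
margin = beta * (chosen_logratios - rejected_logratios)

sigmoid_loss = -F.logsigmoid(margin).mean()  # original DPO loss
hinge_loss = torch.relu(1 - margin).mean()   # SLiC-style hinge loss
print(sigmoid_loss.item(), hinge_loss.item())
```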

View File

@@ -51,7 +51,7 @@ from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_liger_kernel_available, is_peft_available
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_extract_prompt
from ..models import create_reference_model, prepare_deepspeed
from ..models.utils import prepare_fsdp
from .base_trainer import BaseTrainer
@@ -649,6 +649,8 @@ class DPOTrainer(BaseTrainer):
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)
is_chat = is_conversational(next(iter(dataset)))
# Apply the chat template if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
@@ -669,6 +671,7 @@ class DPOTrainer(BaseTrainer):
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
"is_chat": is_chat,
},
**map_kwargs,
)
@@ -682,6 +685,7 @@ class DPOTrainer(BaseTrainer):
max_prompt_length: Optional[int] = None,
max_completion_length: Optional[int] = None,
add_special_tokens: bool = True,
is_chat: bool = False,
) -> dict[str, list[int]]:
"""
Tokenize a row of the dataset.
@@ -728,8 +732,9 @@ class DPOTrainer(BaseTrainer):
prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
if tokenizer.eos_token_id is not None:
prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
if not is_chat:
chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
# Truncate prompt and completion sequences
if max_prompt_length is not None:
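A tiny illustration of the `is_chat` guard added above (toy token IDs; the point is that chat templates generally emit their own end-of-turn/EOS token, so appending `eos_token_id` again would duplicate it):

```python
eos_token_id = 2
chosen_input_ids = [101, 7592, 2]  # toy IDs produced by a chat template; EOS already present

is_chat = True
if not is_chat:  # only plain-text data gets an explicit EOS appended
    chosen_input_ids = chosen_input_ids + [eos_token_id]

print(chosen_input_ids)  # [101, 7592, 2] -- no duplicated EOS
```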

View File

@@ -1057,7 +1057,7 @@ class SFTTrainer(BaseTrainer):
elif args.max_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
dataset = truncate_dataset(dataset, args.max_length, map_kwargs=map_kwargs)
# For Liger kernel, ensure only the essential columns
if args.use_liger_kernel:
collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"}
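The keyword change above matters because the new `columns` parameter now sits between `max_length` and `map_kwargs`, so the old positional call would silently bind the kwargs dict to `columns`. A toy stand-in showing the binding (the real function lives in `trl.data_utils`):

```python
from typing import Any, Optional, Union

def truncate_dataset(dataset, max_length, columns: Union[str, list[str]] = "all",
                     map_kwargs: Optional[dict[str, Any]] = None):
    # Stub that just reports how the arguments were bound.
    return {"columns": columns, "map_kwargs": map_kwargs}

map_kwargs = {"desc": "Truncating train dataset"}
print(truncate_dataset(None, 1024, map_kwargs))             # old positional call: columns == the kwargs dict
print(truncate_dataset(None, 1024, map_kwargs=map_kwargs))  # keyword call keeps the intended meaning
```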

View File

@@ -19,6 +19,7 @@ import os
import random
import socket
import warnings
import zlib
from collections.abc import Mapping, Sequence, Sized
from dataclasses import dataclass, field
from importlib.metadata import version
@@ -29,7 +30,6 @@ import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.utils.data
import transformers
from accelerate import Accelerator, PartialState, logging
from accelerate.state import AcceleratorState
@@ -1990,3 +1990,14 @@ def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
architecture = getattr(transformers, config.architectures[0])
model = architecture.from_pretrained(model_id, **kwargs)
return model
def hash_module(module: torch.nn.Module) -> str:
"""Return a fast, deterministic fingerprint of a module's weights as an 8-character adler32 hex digest."""
h = zlib.adler32(b"")
# Iterate the state dict in sorted key order so the digest does not depend on insertion order
for _, tensor in sorted(module.state_dict().items()):
tensor = tensor.cpu()
h = zlib.adler32(str(tensor.dtype).encode(), h)
# NumPy cannot represent bf16/fp8, so upcast before serializing the raw bytes
if tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2):
tensor = tensor.to(torch.float32)
h = zlib.adler32(tensor.numpy().tobytes(), h)
return f"{h:08x}"
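A quick sketch of the new helper in action (assuming it is importable from the trainer utils module this diff modifies); the digest is deterministic for identical weights, which is what the dataset fingerprinting in `_precompute_ref_logps` relies on:

```python
import torch
import torch.nn as nn

from trl.trainer.utils import hash_module  # assumed import location for the helper above

torch.manual_seed(0)
module = nn.Linear(4, 2)
print(hash_module(module))                                   # 8-hex-char adler32 digest, stable across runs
print(hash_module(module) == hash_module(nn.Linear(4, 2)))   # False: different weights, different digest
```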