mirror of https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00

Compare commits: 21 commits, aa25c2697c...refactor-d
Commits in this range: ea664ba0df, 4b92f6473d, d2aa8f3019, aefc01bb95, f023e65151, 36513d0a70, fcf62d1640, a68ea0f668, d2f5227a16, 6da159efac, a7aab5aa73, a6941ea06a, ea21e98b03, ddd8022b91, 48b242f173, 4d37a04108, 2b911b6f02, c589fdaf8b, 0e3c3e456b, a528fc17fb, e1c9477659
@@ -955,6 +955,20 @@ class TestTruncateExamples(TrlTestCase):
         dataset = truncate_dataset(dataset, max_length)
         assert dataset.to_dict() == expected_output

+    def test_with_specified_columns(self):
+        examples = {
+            "prompt_ids": [[1, 2, 3], [6, 7], [12]],
+            "completion_ids": [[4, 5], [8, 9, 10, 11], [13, 14]],
+        }
+        dataset = Dataset.from_dict(examples)
+        max_length = 2
+        expected_output = {
+            "prompt_ids": [[1, 2], [6, 7], [12]],
+            "completion_ids": [[4, 5], [8, 9, 10, 11], [13, 14]],
+        }
+        dataset = truncate_dataset(dataset, max_length, columns=["prompt_ids"])
+        assert dataset.to_dict() == expected_output
+

 class TestMaybeConvertToChatML(TrlTestCase):
     def test_with_conversations_key(self):
@@ -714,7 +714,10 @@ def pack_dataset(


 def truncate_dataset(
-    dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None
+    dataset: DatasetType,
+    max_length: int,
+    columns: Union[str, list[str]] = "all",
+    map_kwargs: Optional[dict[str, Any]] = None,
 ) -> DatasetType:
     r"""
     Truncate sequences in a dataset to a specified `max_length`.
@@ -724,6 +727,8 @@ def truncate_dataset(
             Dataset to truncate.
         max_length (`int`):
             Maximum sequence length to truncate to.
+        columns (`str` or `list[str]`, *optional*, defaults to `"all"`):
+            Which columns to truncate. If `"all"` (default), all columns are truncated.
         map_kwargs (`dict`, *optional*):
             Additional keyword arguments to pass to the dataset's map method when truncating examples.

@@ -749,32 +754,30 @@ def truncate_dataset(
         map_kwargs = {}
     if isinstance(dataset, Dataset):
         # Fast truncation with pyarrow
-        def truncate(examples):
+        def truncate(examples, columns):
             truncated_columns = []
             for column in examples.columns:
-                if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
-                    column = pc.list_slice(column, 0, max_length)
+                if columns == "all" or column._name in columns:
+                    if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
+                        column = pc.list_slice(column, 0, max_length)
                 truncated_columns.append(column)
             return pa.Table.from_arrays(truncated_columns, names=examples.column_names)

         dataset = dataset.with_format("arrow")
-        dataset = dataset.map(truncate, batched=True, **map_kwargs)
+        dataset = dataset.map(truncate, batched=True, **map_kwargs, fn_kwargs={"columns": columns})
         dataset = dataset.with_format(None)
     else:

-        def truncate(examples):
+        def truncate(examples, columns):
             truncated_examples = {}
             for key, column in examples.items():
-                if column and isinstance(column[0], list):
-                    column = [val[:max_length] for val in column]
+                if columns == "all" or key in columns:
+                    if column and isinstance(column[0], list):
+                        column = [val[:max_length] for val in column]
                 truncated_examples[key] = column
             return truncated_examples

-        dataset = dataset.map(
-            truncate,
-            batched=True,
-            **map_kwargs,
-        )
+        dataset = dataset.map(truncate, batched=True, **map_kwargs, fn_kwargs={"columns": columns})
     return dataset
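Taken together, the test and the implementation above imply the following usage; a minimal sketch, assuming the function is importable as `trl.data_utils.truncate_dataset` (the import path used elsewhere in this change set):

```python
from datasets import Dataset

from trl.data_utils import truncate_dataset  # import path assumed from this diff

dataset = Dataset.from_dict(
    {
        "prompt_ids": [[1, 2, 3], [6, 7]],
        "completion_ids": [[4, 5], [8, 9, 10, 11]],
    }
)

# Default columns="all": every list column is truncated
truncated_all = truncate_dataset(dataset, max_length=2)
print(truncated_all.to_dict())
# {'prompt_ids': [[1, 2], [6, 7]], 'completion_ids': [[4, 5], [8, 9]]}

# Truncate only the prompt column; completions are left untouched
truncated_prompts = truncate_dataset(dataset, max_length=2, columns=["prompt_ids"])
print(truncated_prompts.to_dict())
# {'prompt_ids': [[1, 2], [6, 7]], 'completion_ids': [[4, 5], [8, 9, 10, 11]]}
```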
trl/experimental/dpo/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .dpo_config import DPOConfig
from .dpo_trainer import DPOTrainer
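The net effect of these re-exports is that both classes can be imported directly from the experimental package; a one-line sketch, assuming the package is installed from this branch:

```python
# Experimental API; subject to change without notice.
from trl.experimental.dpo import DPOConfig, DPOTrainer
```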
trl/experimental/dpo/dpo_config.py (new file, 212 lines)
@@ -0,0 +1,212 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, Optional

from transformers import TrainingArguments


@dataclass
class DPOConfig(TrainingArguments):
    r"""
    Configuration class for the [`DPOTrainer`].

    This class includes only the parameters that are specific to DPO training. For a full list of training arguments,
    please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this class may
    differ from those in [`~transformers.TrainingArguments`].

    Using [`~transformers.HfArgumentParser`] we can turn this class into
    [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
    command line.

    Parameters:
        > Parameters that control the model and reference model

        model_init_kwargs (`dict[str, Any]`, *optional*):
            Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
            argument of the [`DPOTrainer`] is provided as a string.
        disable_dropout (`bool`, *optional*, defaults to `True`):
            Whether to disable dropout in the model and reference model.

        > Parameters that control the data preprocessing

        dataset_num_proc (`int`, *optional*):
            Number of processes to use for processing the dataset.
        pad_token (`str`, *optional*):
            Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
            it falls back to `processing_class.eos_token`.
        max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
            Maximum length of the prompt part of the sequence. If `None`, no truncation is applied.
        max_completion_length (`int` or `None`, *optional*, defaults to `None`):
            Maximum length of the completion part of the sequence. If `None`, no truncation is applied.
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
            If `None`, no truncation is applied.
        truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
            Truncation mode to use when the sequence exceeds `max_length`. Possible values are `"keep_end"` and
            `"keep_start"`.
        padding_free (`bool`, *optional*, defaults to `False`):
            Whether to perform forward passes without padding by flattening all sequences in the batch into a single
            continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
            supported with FlashAttention 2 or 3, which can efficiently handle the flattened batch structure.
        pad_to_multiple_of (`int`, *optional*):
            If set, the sequences will be padded to a multiple of this value.
        precompute_ref_log_probs (`bool`, *optional*, defaults to `True`):
            Whether to precompute the reference model log probabilities for the entire training dataset before
            training. This saves memory during training, as the reference model does not need to be kept in memory.

        > Parameters that control the training

        loss_type (`str` or `list[str]`, *optional*, defaults to `"sigmoid"`):
            Type of loss to use. Possible values are:

            - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
            - `"hinge"`: hinge loss on the normalized likelihood from the
              [SLiC](https://huggingface.co/papers/2305.10425) paper.
        beta (`float`, *optional*, defaults to `0.1`):
            Parameter controlling the deviation from the reference model. Higher β means less deviation from the
            reference model.
        activation_offloading (`bool`, *optional*, defaults to `False`):
            Whether to offload the activations to the CPU.
    """

    _VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]

    # Parameters whose default values are overridden from TrainingArguments
    learning_rate: float = field(
        default=1e-6,
        metadata={"help": "The initial learning rate for AdamW."},
    )
    logging_steps: float = field(
        default=10,
        metadata={
            "help": "Log every X update steps. Should be an integer or a float in range `[0,1)`. If smaller than 1, "
            "will be interpreted as ratio of total training steps."
        },
    )
    gradient_checkpointing: bool = field(
        default=True,
        metadata={
            "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
        },
    )
    bf16: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
            "architecture or Intel XPU or using CPU (use_cpu) or Ascend NPU. If not set, it defaults to `True` if "
            "`fp16` is not set."
        },
    )

    # Parameters that control the model
    model_init_kwargs: Optional[dict[str, Any]] = field(
        default=None,
        metadata={
            "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
            "the `DPOTrainer` is provided as a string."
        },
    )
    disable_dropout: bool = field(
        default=True,
        metadata={"help": "Whether to disable dropout in the model and reference model."},
    )

    # Parameters that control the data preprocessing
    dataset_num_proc: Optional[int] = field(
        default=None,
        metadata={"help": "Number of processes to use for processing the dataset."},
    )
    pad_token: Optional[str] = field(
        default=None,
        metadata={
            "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
            "is also `None`, it falls back to `processing_class.eos_token`."
        },
    )
    max_prompt_length: Optional[int] = field(
        default=512,
        metadata={"help": "Maximum length of the prompt part of the sequence. If `None`, no truncation is applied."},
    )
    max_completion_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Maximum length of the completion part of the sequence. If `None`, no truncation is applied."
        },
    )
    max_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from "
            "the right. If `None`, no truncation is applied."
        },
    )
    truncation_mode: str = field(
        default="keep_end",
        metadata={
            "help": "Truncation mode to use when the sequence exceeds `max_length`. Possible values are `'keep_end'` "
            "and `'keep_start'`.",
            "choices": ["keep_end", "keep_start"],
        },
    )
    padding_free: bool = field(
        default=False,
        metadata={
            "help": "Whether to perform forward passes without padding by flattening all sequences in the batch into "
            "a single continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, "
            "this is only supported with FlashAttention 2 or 3, which can efficiently handle the flattened batch "
            "structure."
        },
    )
    pad_to_multiple_of: Optional[int] = field(
        default=None,
        metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
    )
    precompute_ref_log_probs: bool = field(
        default=True,
        metadata={
            "help": "Whether to precompute the reference model log probabilities for the entire training dataset "
            "before training. This saves memory during training, as the reference model does not need to be kept in "
            "memory."
        },
    )

    # Parameters that control the training
    loss_type: list[str] = field(
        default_factory=lambda: ["sigmoid"],
        metadata={
            "help": "Type of loss to use. Possible values are: `'sigmoid'`, `'hinge'`.",
        },
    )
    beta: float = field(
        default=0.1,
        metadata={
            "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation "
            "from the reference model."
        },
    )
    activation_offloading: bool = field(
        default=False,
        metadata={"help": "Whether to offload the activations to the CPU."},
    )

    def __post_init__(self):
        self.bf16 = not self.fp16 if self.bf16 is None else self.bf16

        # Normalize loss_type to string format for internal use
        if isinstance(self.loss_type, list) and len(self.loss_type) == 1:
            self.loss_type = self.loss_type[0]
        super().__post_init__()
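As the class docstring notes, the config doubles as an argparse source via `HfArgumentParser`; a minimal sketch (the script name and CLI flags are illustrative only):

```python
from transformers import HfArgumentParser

from trl.experimental.dpo import DPOConfig

parser = HfArgumentParser(DPOConfig)
# e.g. python train_dpo.py --output_dir out --beta 0.05 --loss_type hinge
(config,) = parser.parse_args_into_dataclasses()
# __post_init__ normalizes a single-element loss_type list to a plain string
print(config.beta, config.loss_type, config.learning_rate)  # expected: 0.05 hinge 1e-06
```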
trl/experimental/dpo/dpo_trainer.py (new file, 871 lines)
@@ -0,0 +1,871 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import textwrap
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState, logging
from accelerate.utils import is_peft_model
from datasets import Dataset, IterableDataset
from datasets.fingerprint import Hasher
from transformers import (
    AutoConfig,
    AutoProcessor,
    BaseImageProcessor,
    DataCollator,
    FeatureExtractionMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    TrainerCallback,
)
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available

from ...data_utils import extract_prompt, is_conversational, prepare_multimodal_messages, truncate_dataset
from ...models import get_act_offloading_ctx_manager, prepare_deepspeed, prepare_fsdp, prepare_peft_model
from ...trainer.base_trainer import BaseTrainer
from ...trainer.utils import (
    disable_dropout_in_model,
    entropy_from_logits,
    flush_left,
    flush_right,
    hash_module,
    pad,
    remove_none_values,
    selective_log_softmax,
)
from .dpo_config import DPOConfig


if is_peft_available():
    from peft import PeftConfig, PeftModel


logger = logging.get_logger(__name__)


FLASH_ATTENTION_VARIANTS = {
    "flash_attention_2",
    "flash_attention_3",
    "kernels-community/flash-attn",
    "kernels-community/vllm-flash-attn3",
    "kernels-community/flash-attn3",
}


def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]:
    return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names

@dataclass
class DataCollatorForPreference(DataCollatorMixin):
    """
    Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch.

    This collator expects each example in the input list to be a dictionary containing the keys `"prompt_ids"`,
    `"chosen_ids"` and `"rejected_ids"`. The collator returns a dictionary containing the following keys:

    - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch
      corresponds to the prompt + chosen completions and the second half to the prompt + rejected completions.
    - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
    - `"completion_mask"`: Tensor indicating the positions of the completion tokens, padded to the maximum length of
      the batch.

    Optionally, the examples can contain `"ref_chosen_logps"` and `"ref_rejected_logps"` keys, in which case the
    returned dictionary will also contain these keys with the corresponding tensors.

    Args:
        pad_token_id (`int`):
            Token ID to use for padding.
        pad_to_multiple_of (`int`, *optional*):
            If set, the sequences will be padded to a multiple of this value.
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            Type of Tensor to return. Only `"pt"` is currently supported.

    Examples:
    ```python
    >>> from trl.experimental.dpo.dpo_trainer import DataCollatorForPreference

    >>> collator = DataCollatorForPreference(pad_token_id=0)
    >>> examples = [{"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]}]
    >>> collator(examples)
    {'input_ids': tensor([[1, 2, 3, 4, 5],
                          [1, 2, 3, 6, 0]]),
     'attention_mask': tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 1, 0]]),
     'completion_mask': tensor([[0, 0, 0, 1, 1],
                                [0, 0, 0, 1, 0]])}
    ```
    """

    pad_token_id: int
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
        prompt_chosen_ids = [example["prompt_ids"] + example["chosen_ids"] for example in examples]
        prompt_rejected_ids = [example["prompt_ids"] + example["rejected_ids"] for example in examples]
        chosen_attention_mask = [[1] * len(example["prompt_ids"] + example["chosen_ids"]) for example in examples]
        rejected_attention_mask = [[1] * len(example["prompt_ids"] + example["rejected_ids"]) for example in examples]
        chosen_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["chosen_ids"]) for example in examples]
        rejected_mask = [[0] * len(example["prompt_ids"]) + [1] * len(example["rejected_ids"]) for example in examples]
        input_ids = prompt_chosen_ids + prompt_rejected_ids
        attention_mask = chosen_attention_mask + rejected_attention_mask
        completion_mask = chosen_mask + rejected_mask

        # Convert to tensor
        input_ids = [torch.tensor(ids) for ids in input_ids]
        attention_mask = [torch.tensor(m, dtype=torch.long) for m in attention_mask]
        completion_mask = [torch.tensor(m, dtype=torch.long) for m in completion_mask]
        if "ref_chosen_logps" in examples[0]:
            ref_chosen_logps = torch.tensor([example["ref_chosen_logps"] for example in examples])
        if "ref_rejected_logps" in examples[0]:
            ref_rejected_logps = torch.tensor([example["ref_rejected_logps"] for example in examples])

        # Pad
        output = {}
        output["input_ids"] = pad(
            input_ids,
            padding_value=self.pad_token_id,
            padding_side="right",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        output["attention_mask"] = pad(
            attention_mask,
            padding_value=0,
            padding_side="right",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        output["completion_mask"] = pad(
            completion_mask,
            padding_value=0,
            padding_side="right",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        if "ref_chosen_logps" in examples[0]:
            output["ref_chosen_logps"] = ref_chosen_logps
        if "ref_rejected_logps" in examples[0]:
            output["ref_rejected_logps"] = ref_rejected_logps
        return output

class DPOTrainer(BaseTrainer):
    """
    Trainer for the Direct Preference Optimization (DPO) method.

    This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.

    Example:

    ```python
    from datasets import load_dataset
    from trl import DPOTrainer

    dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

    trainer = DPOTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
    trainer.train()
    ```

    Args:
        model (`Union[str, PreTrainedModel]`):
            Model to be trained. Can be either:

            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
              path to a *directory* containing model weights saved using
              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
              using `<ModelArchitecture>.from_pretrained` (where `<ModelArchitecture>` is derived from the model
              config) with the keyword arguments in `args.model_init_kwargs`.
            - A [`~transformers.PreTrainedModel`] object.

            If you're training a model with an MoE architecture and want to include the load balancing/auxiliary loss
            as a part of the final loss, remember to set the `output_router_logits` config of the model to `True`.
        args ([`DPOConfig`], *optional*):
            Configuration for this trainer. If `None`, a default configuration is used.
        data_collator ([`~transformers.DataCollator`], *optional*):
            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
            Will default to [`~trainer.dpo_trainer.DataCollatorForPreference`] if the model is a language model and
            [`~trainer.dpo_trainer.DataCollatorForVisionLanguageModeling`] if the model is a vision-language model.
        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
            Dataset to use for training. DPO supports both [language modeling](#language-modeling) type and
            [prompt-completion](#prompt-completion) type. The format of the samples can be either:

            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
            Processing class used to process the data. If `None`, the processing class is loaded from the model's name
            with [`~transformers.AutoProcessor.from_pretrained`]. A padding token, `tokenizer.pad_token`, must be set.
            If the processing class has not set a padding token, `tokenizer.eos_token` will be used as the default.
        compute_loss_func (`Callable`, *optional*):
            A function that accepts the raw model outputs, labels, and the number of items in the entire accumulated
            batch (batch_size * gradient_accumulation_steps) and returns the loss. For example, see the default [loss
            function](https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/trainer.py#L3618)
            used by [`Trainer`].
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function that will be used to compute metrics at evaluation. Must take a
            [`~transformers.EvalPrediction`] and return a dictionary mapping metric names to metric values. When
            passing [`DPOConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
            boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
            function needs to calculate and return the global summary statistics rather than accumulating the
            batch-level statistics.
        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
            [here](https://huggingface.co/docs/transformers/main_classes/callback).

            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
            method.
        optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
            model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
        optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
            `args`. Incompatible with the `optimizers` argument.

            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
            initializing the Trainer.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
            by this function will be reflected in the predictions received by `compute_metrics`.

            Note that the labels (second parameter) will be `None` if the dataset does not have them.
        peft_config ([`~peft.PeftConfig`], *optional*):
            PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
    """

    _tag_names = ["trl", "dpo"]
    _name = "DPO"
    _paper = {
        "title": "Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
        "id": "2305.18290",
        # docstyle-ignore
        "citation": textwrap.dedent("""\
            @inproceedings{rafailov2023direct,
                title        = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
                author       = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn},
                year         = 2023,
                booktitle    = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023},
                url          = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html},
                editor       = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine},
            }"""),
    }

    def __init__(
        self,
        model: Union[str, PreTrainedModel],
        args: Optional[DPOConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
        compute_loss_func: Optional[Callable] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
    ):
        # Args
        if args is None:
            model_name = model if isinstance(model, str) else model.config._name_or_path
            model_name = model_name.split("/")[-1]
            args = DPOConfig(f"{model_name}-DPO")

        # Models
        # Trained model
        model_init_kwargs = args.model_init_kwargs or {}
        if isinstance(model, str):
            model_id = model
            dtype = model_init_kwargs.get("dtype")
            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
                pass  # dtype is already a torch.dtype or "auto" or None
            elif isinstance(dtype, str):  # it's a str, but not "auto"
                dtype = getattr(torch, dtype)
                model_init_kwargs["dtype"] = dtype
            else:
                raise ValueError(
                    "Invalid `dtype` passed to `DPOConfig`. Expected either 'auto' or a string representing "
                    f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
                )
            # Disable caching if gradient checkpointing is enabled (not supported)
            config = AutoConfig.from_pretrained(model_id)
            architecture = getattr(transformers, config.architectures[0])
            model = architecture.from_pretrained(model_id, **model_init_kwargs)
        else:
            model_id = model.config._name_or_path
            if args.model_init_kwargs is not None:
                logger.warning(
                    "You passed `model_init_kwargs` to the `DPOConfig`, but your model is already instantiated. "
                    "The `model_init_kwargs` will be ignored."
                )

        if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
            model = prepare_peft_model(model, peft_config, args)

        # Disable dropout in the model
        if args.disable_dropout:
            disable_dropout_in_model(model)

        # Processing class
        if processing_class is None:
            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)

        # Handle pad token for processors or tokenizers
        if isinstance(processing_class, ProcessorMixin):
            tokenizer = processing_class.tokenizer
            self._is_vlm = True
        elif isinstance(processing_class, PreTrainedTokenizerBase):
            tokenizer = processing_class
            self._is_vlm = False
        else:
            raise TypeError("The `processing_class` must be either a `PreTrainedTokenizerBase` or a `ProcessorMixin`")

        if self._is_vlm and args.padding_free:
            raise ValueError(
                "Padding-free training is not yet supported for vision-language models. Please set "
                "`padding_free=False` in the `DPOConfig`."
            )

        # Data collator
        self.padding_free = args.padding_free
        use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS
        if self.padding_free:
            raise NotImplementedError("Padding-free training is not yet implemented.")
            if data_collator is not None:
                raise ValueError("Passing a custom data collator is not supported when using padding-free.")
            if not use_flash_attention:
                logger.warning(
                    "Padding-free training is enabled, but the attention implementation is not set to a supported "
                    "flash attention variant. Padding-free training flattens batches into a single sequence, and "
                    "only the following implementations are known to reliably support this: "
                    f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to "
                    "unexpected behavior. To ensure compatibility, set `attn_implementation` in the model "
                    "configuration to one of these supported options or verify that your attention mechanism can "
                    "handle flattened sequences."
                )
            if args.per_device_train_batch_size == 1:
                logger.warning(
                    "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch "
                    "size of 1 annihilates the benefits of padding-free training. Please consider increasing the "
                    "batch size to at least 2."
                )
        dataset_sample = next(iter(train_dataset))
        self._is_vision_dataset = "image" in dataset_sample or "images" in dataset_sample
        if self._is_vision_dataset and not self._is_vlm:
            raise ValueError(
                "The dataset appears to be vision-related (contains 'image' or 'images' keys), but the provided "
                "model does not seem to be a vision-language model. Please check your model and dataset."
            )

        if data_collator is None and not self._is_vision_dataset:
            # Get the pad token: if not provided, use the one from the processing class or the eos token
            # if the processing class does not have a pad token.
            pad_token = args.pad_token or tokenizer.pad_token or tokenizer.eos_token
            pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
            if pad_token_id is None:
                raise ValueError(
                    f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
                    "in the vocabulary before using it as a padding token."
                )
            data_collator = DataCollatorForPreference(
                pad_token_id=pad_token_id,
                pad_to_multiple_of=args.pad_to_multiple_of,
            )
        elif data_collator is None and self._is_vision_dataset:
            raise NotImplementedError("VLM training is not yet implemented.")

        # Training arguments
        self.loss_type = args.loss_type if isinstance(args.loss_type, list) else [args.loss_type]
        self.beta = args.beta

        # Dataset
        # Skip dataset preparation if it's a VLM, where preprocessing (e.g., image-to-pixel conversion) is too costly
        # and done on the fly instead.
        skip_prepare_dataset = self._is_vision_dataset
        if not skip_prepare_dataset:
            train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
            if eval_dataset is not None:
                if isinstance(eval_dataset, dict):
                    eval_dataset = {
                        key: self._prepare_dataset(dataset, processing_class, args, key)
                        for key, dataset in eval_dataset.items()
                    }
                else:
                    eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")

        # Initialize the metrics
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self._total_train_tokens = 0

        # Initialize the Trainer. Parent class will handle:
        # - DeepSpeed configuration (through create_accelerator_and_postprocess)
        # - FSDP setup
        # - Distributed training setup
        # - Optimizer and scheduler creation

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            compute_loss_func=compute_loss_func,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Reference model
        if self.beta == 0.0:
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_peft_model(model):
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None
        else:
            # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
            config = AutoConfig.from_pretrained(model_id)
            architecture = getattr(transformers, config.architectures[0])
            self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)

        # Disable dropout in the reference model (dropout in the trained model is disabled above)
        if args.disable_dropout and self.ref_model is not None:
            disable_dropout_in_model(self.ref_model)

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags for model
        self.model.add_model_tags(self._tag_names)

        if self.ref_model is not None:
            if self.is_deepspeed_enabled:
                self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
            elif self.is_fsdp_enabled:
                self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

        # Initialize activation offloading context
        if self.args.activation_offloading:
            self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
        else:
            self.maybe_activation_offload_context = contextlib.nullcontext()
    def _prepare_dataset(
        self,
        dataset: Union[Dataset, IterableDataset],
        processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
        args: DPOConfig,
        dataset_name: str,
    ) -> Union[Dataset, IterableDataset]:
        # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
        # sampled data.
        if isinstance(dataset, Dataset):  # IterableDataset does not support `with_transform`
            dataset = dataset.with_transform(remove_none_values)

        # Build the kwargs for the `map` function
        map_kwargs = {}
        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
            map_kwargs["num_proc"] = args.dataset_num_proc

        with PartialState().main_process_first():
            # Extract the prompt if needed
            first_example = next(iter(dataset))
            if "prompt" not in first_example:
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Extracting prompt from {dataset_name} dataset"
                dataset = dataset.map(extract_prompt, **map_kwargs)

            # Add an EOS token if the dataset is not conversational
            first_example = next(iter(dataset))
            if not is_conversational(first_example):
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"

                def add_eos(example, eos_token):
                    if not example["chosen"].endswith(eos_token):
                        example["chosen"] = example["chosen"] + eos_token
                    if not example["rejected"].endswith(eos_token):
                        example["rejected"] = example["rejected"] + eos_token
                    return example

                dataset = dataset.map(add_eos, fn_kwargs={"eos_token": processing_class.eos_token}, **map_kwargs)

            # Tokenize the dataset
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

            def tokenize_fn(example, processing_class):
                output = {}
                if is_conversational(example):
                    if self._is_vlm:
                        prepare_multimodal_messages(example["prompt"], num_images=0)
                        prepare_multimodal_messages(example["completion"], num_images=0)
                    prompt_ids = processing_class.apply_chat_template(
                        example["prompt"],
                        tokenize=True,
                        add_generation_prompt=True,
                        tools=example.get("tools"),
                        **example.get("chat_template_kwargs", {}),
                    )
                    prompt_chosen_processed = processing_class.apply_chat_template(
                        example["prompt"] + example["chosen"],
                        return_dict=True,
                        tokenize=True,
                        tools=example.get("tools"),
                        **example.get("chat_template_kwargs", {}),
                    )
                    prompt_rejected_processed = processing_class.apply_chat_template(
                        example["prompt"] + example["rejected"],
                        return_dict=True,
                        tokenize=True,
                        tools=example.get("tools"),
                        **example.get("chat_template_kwargs", {}),
                    )
                    # Fix transformers inconsistency: for VLMs, apply_chat_template returns lists of lists
                    # even for single examples, while for LLMs it returns lists of ints.
                    prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids
                    prompt_chosen_processed = {
                        k: v[0] if isinstance(v[0], list) else v for k, v in prompt_chosen_processed.items()
                    }
                    prompt_rejected_processed = {
                        k: v[0] if isinstance(v[0], list) else v for k, v in prompt_rejected_processed.items()
                    }
                    prompt_chosen_ids = prompt_chosen_processed["input_ids"]
                    prompt_rejected_ids = prompt_rejected_processed["input_ids"]
                else:
                    prompt_ids = processing_class(text=example["prompt"])["input_ids"]
                    prompt_chosen_ids = processing_class(text=example["prompt"] + example["chosen"])["input_ids"]
                    prompt_rejected_ids = processing_class(text=example["prompt"] + example["rejected"])["input_ids"]

                # Check that the tokenized prompt+completion starts with the tokenized prompt
                if not prompt_chosen_ids[: len(prompt_ids)] == prompt_ids:
                    logger.warning(
                        "Mismatch between tokenized prompt and the start of tokenized prompt+chosen. "
                        "This may be due to unexpected tokenizer behavior, whitespace issues, or special "
                        "token handling. Verify that the tokenizer is processing text consistently."
                    )
                if not prompt_rejected_ids[: len(prompt_ids)] == prompt_ids:
                    logger.warning(
                        "Mismatch between tokenized prompt and the start of tokenized prompt+rejected. "
                        "This may be due to unexpected tokenizer behavior, whitespace issues, or special "
                        "token handling. Verify that the tokenizer is processing text consistently."
                    )

                output["prompt_ids"] = prompt_ids
                output["chosen_ids"] = prompt_chosen_ids[len(prompt_ids) :]
                output["rejected_ids"] = prompt_rejected_ids[len(prompt_ids) :]
                return output

            dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs)

            # Truncate
            if args.max_prompt_length is not None:
                raise NotImplementedError("Prompt truncation is not yet implemented.")
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Truncating prompt in {dataset_name} dataset"
                dataset = truncate_dataset(
                    dataset, args.max_prompt_length, columns=["prompt_ids"], map_kwargs=map_kwargs
                )
            if args.max_completion_length is not None:
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Truncating completions in {dataset_name} dataset"
                dataset = truncate_dataset(
                    dataset, args.max_completion_length, columns=["chosen_ids", "rejected_ids"], map_kwargs=map_kwargs
                )
            # For Liger kernel, ensure only the essential columns
            if args.use_liger_kernel:
                collator_expected_keys = {"input_ids", "completion_mask"}
                column_names = get_dataset_column_names(dataset)
                dataset = dataset.select_columns(collator_expected_keys.intersection(column_names))

        return dataset
    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
        # and "attention_mask").
        if self._signature_columns is None:
            if self._is_vision_dataset:
                self._signature_columns = ["prompt", "chosen", "rejected"]
            else:
                self._signature_columns = [
                    "prompt_ids",
                    "chosen_ids",
                    "rejected_ids",
                    "ref_chosen_logps",
                    "ref_rejected_logps",
                ]

    def train(self, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None, **kwargs):
        if self.args.precompute_ref_log_probs:
            self.train_dataset = self._precompute_ref_logps(
                self.train_dataset, self.args.per_device_train_batch_size, "train"
            )
            if self.eval_dataset is not None:
                if isinstance(self.eval_dataset, dict):
                    self.eval_dataset = {
                        key: self._precompute_ref_logps(dataset, self.args.per_device_eval_batch_size, key)
                        for key, dataset in self.eval_dataset.items()
                    }
                else:
                    self.eval_dataset = self._precompute_ref_logps(
                        self.eval_dataset, self.args.per_device_eval_batch_size, "eval"
                    )
        return super().train(resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
    def _precompute_ref_logps(
        self, dataset: Union[Dataset, IterableDataset], batch_size: int, dataset_name: str
    ) -> Union[Dataset, IterableDataset]:
        def compute_ref_logps(examples, collator, max_length, truncation_mode):
            examples = [dict(zip(examples.keys(), v)) for v in zip(*examples.values())]  # dict[list] to list[dict]
            inputs = collator(examples)
            input_ids = inputs["input_ids"].to(self.model.device)
            attention_mask = inputs["attention_mask"].to(self.model.device)
            completion_mask = inputs["completion_mask"].to(self.model.device)

            # Truncate inputs
            if max_length is not None:
                if truncation_mode == "keep_start":
                    input_ids = input_ids[:, :max_length]
                    attention_mask = attention_mask[:, :max_length]
                    completion_mask = completion_mask[:, :max_length]
                elif truncation_mode == "keep_end":
                    attention_mask, input_ids, completion_mask = flush_right(
                        attention_mask, input_ids, completion_mask
                    )
                    input_ids = input_ids[:, -max_length:]
                    attention_mask = attention_mask[:, -max_length:]
                    completion_mask = completion_mask[:, -max_length:]
                    attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
                else:
                    raise ValueError(
                        f"Unsupported truncation mode: {truncation_mode}, expected 'keep_start' or 'keep_end'"
                    )

            outputs = self.model(input_ids, attention_mask=attention_mask, use_cache=False)
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = input_ids[..., 1:].contiguous()
            shift_completion_mask = completion_mask[..., 1:].contiguous()
            per_token_logps = selective_log_softmax(shift_logits, shift_labels)
            per_token_logps[shift_completion_mask == 0] = 0.0  # mask out non-completion tokens
            logps = per_token_logps.sum(dim=1)  # sum over sequence length
            chosen_logps, rejected_logps = logps.chunk(2, dim=0)  # batch is [chosen, rejected]
            return {"ref_chosen_logps": chosen_logps.tolist(), "ref_rejected_logps": rejected_logps.tolist()}

        # Normally, `map` creates a fingerprint based on the transform function and its arguments. However, the model
        # produces a different fingerprint on each run, which prevents the cache from being used. To fix this, we
        # manually compute a stable fingerprint for the model instead.
        fn_kwargs = {
            "collator": self.data_collator,
            "max_length": self.args.max_length,
            "truncation_mode": self.args.truncation_mode,
        }
        model_hash = hash_module(self.model)
        dataset = dataset.map(
            compute_ref_logps,
            batched=True,
            batch_size=batch_size,
            fn_kwargs=fn_kwargs,
            desc=f"Computing reference logps for {dataset_name} dataset",
            new_fingerprint=Hasher.hash((dataset._fingerprint, fn_kwargs, model_hash)),
        )
        return dataset
    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict[str, Union[torch.Tensor, Any]],
        return_outputs: bool = False,
        num_items_in_batch: Optional[torch.Tensor] = None,
    ):
        """
        Compute the training loss and additionally compute token accuracies.
        """
        mode = "train" if self.model.training else "eval"

        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        completion_mask = inputs["completion_mask"]

        # Truncate inputs
        if self.args.max_length is not None:
            if self.args.truncation_mode == "keep_start":
                input_ids = input_ids[:, : self.args.max_length]
                attention_mask = attention_mask[:, : self.args.max_length]
                completion_mask = completion_mask[:, : self.args.max_length]
            elif self.args.truncation_mode == "keep_end":
                attention_mask, input_ids, completion_mask = flush_right(attention_mask, input_ids, completion_mask)
                input_ids = input_ids[:, -self.args.max_length :]
                attention_mask = attention_mask[:, -self.args.max_length :]
                completion_mask = completion_mask[:, -self.args.max_length :]
                attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
            else:
                raise ValueError(
                    f"Unsupported truncation mode: {self.args.truncation_mode}, expected 'keep_start' or 'keep_end'"
                )

        outputs = model(input_ids, attention_mask=attention_mask, use_cache=False)
        shift_logits = outputs.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., 1:].contiguous()
        shift_completion_mask = completion_mask[..., 1:].contiguous()
        per_token_logps = selective_log_softmax(shift_logits, shift_labels)
        per_token_logps[shift_completion_mask == 0] = 0.0  # mask out non-completion tokens
        logps = per_token_logps.sum(dim=1)  # sum over sequence length
        chosen_logps, rejected_logps = logps.chunk(2, dim=0)  # batch is [chosen, rejected]
        ref_chosen_logps, ref_rejected_logps = inputs["ref_chosen_logps"], inputs["ref_rejected_logps"]

        # Get the log ratios for the chosen and rejected responses
        chosen_logratios = chosen_logps - ref_chosen_logps
        rejected_logratios = rejected_logps - ref_rejected_logps

        loss = 0

        for loss_type in self.loss_type:
            if loss_type == "sigmoid":
                per_sequence_loss = -F.logsigmoid(self.beta * chosen_logratios - self.beta * rejected_logratios)
            elif loss_type == "hinge":
                per_sequence_loss = torch.relu(1 - (self.beta * chosen_logratios - self.beta * rejected_logratios))
            loss += per_sequence_loss.mean()

        # Log the metrics
        # Entropy
        per_token_entropy = entropy_from_logits(shift_logits.detach())
        entropy = per_token_entropy[shift_completion_mask.bool()].mean()
        entropy = self.accelerator.gather_for_metrics(entropy).mean().item()
        self._metrics[mode]["entropy"].append(entropy)

        # Number of tokens
        if mode == "train":
            num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
            self._total_train_tokens += num_tokens_in_batch
            self._metrics[mode]["num_tokens"] = [self._total_train_tokens]

        # Average logits for chosen and rejected completions
        chosen_logits, rejected_logits = shift_logits.detach().chunk(2, dim=0)
        chosen_mask, rejected_mask = shift_completion_mask.chunk(2, dim=0)
        total_chosen_logits = chosen_logits[chosen_mask.bool()].mean(-1)
        total_chosen_tokens = chosen_mask.sum()
        total_rejected_logits = rejected_logits[rejected_mask.bool()].mean(-1)
        total_rejected_tokens = rejected_mask.sum()
        total_chosen_logits = self.accelerator.gather_for_metrics(total_chosen_logits).sum().item()
        total_chosen_tokens = self.accelerator.gather_for_metrics(total_chosen_tokens).sum().item()
        total_rejected_logits = self.accelerator.gather_for_metrics(total_rejected_logits).sum().item()
        total_rejected_tokens = self.accelerator.gather_for_metrics(total_rejected_tokens).sum().item()
        avg_chosen_logits = total_chosen_logits / total_chosen_tokens if total_chosen_tokens > 0 else 0.0
        avg_rejected_logits = total_rejected_logits / total_rejected_tokens if total_rejected_tokens > 0 else 0.0
        self._metrics[mode]["logits/chosen"].append(avg_chosen_logits)
        self._metrics[mode]["logits/rejected"].append(avg_rejected_logits)

        # Token accuracy for the chosen completions
        predictions = chosen_logits.argmax(dim=-1)
        chosen_mask = shift_completion_mask[: len(shift_completion_mask) // 2].bool()
        chosen_labels = shift_labels[: len(shift_labels) // 2]
        correct_predictions = (predictions == chosen_labels) & chosen_mask
        total_tokens = chosen_mask.sum()
        correct_tokens = correct_predictions.sum()
        correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
        total_tokens = self.accelerator.gather_for_metrics(total_tokens)
        total_sum = total_tokens.sum()
        accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
        self._metrics[mode]["mean_token_accuracy"].append(accuracy)

        # Rewards for chosen and rejected completions
        chosen_rewards = self.beta * (chosen_logps.detach() - ref_chosen_logps)
        rejected_rewards = self.beta * (rejected_logps.detach() - ref_rejected_logps)
        agg_chosen_rewards = self.accelerator.gather(chosen_rewards)
        agg_rejected_rewards = self.accelerator.gather(rejected_rewards)
        self._metrics[mode]["rewards/chosen"].append(agg_chosen_rewards.mean().item())
        self._metrics[mode]["rewards/rejected"].append(agg_rejected_rewards.mean().item())

        # Reward accuracy
        reward_accuracies = (chosen_rewards > rejected_rewards).float()
        agg_reward_accuracies = self.accelerator.gather(reward_accuracies)
        self._metrics[mode]["rewards/accuracies"].append(agg_reward_accuracies.mean().item())

        # Reward margins
        margins = chosen_rewards - rejected_rewards
        agg_margins = self.accelerator.gather(margins)
        self._metrics[mode]["rewards/margins"].append(agg_margins.mean().item())

        # Average log probabilities for chosen and rejected completions
        self._metrics[mode]["logps/chosen"].append(self.accelerator.gather(chosen_logps).mean().item())
        self._metrics[mode]["logps/rejected"].append(self.accelerator.gather(rejected_logps).mean().item())

        return (loss, outputs) if return_outputs else loss

    # Override training step to add activation offloading context.
    def training_step(self, *args, **kwargs):
        with self.maybe_activation_offload_context:
            return super().training_step(*args, **kwargs)

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        mode = "train" if self.model.training else "eval"
        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs.update(metrics)
        super().log(logs, start_time)
        self._metrics[mode].clear()

    # Ensure the model card is saved along with the checkpoint
    def _save_checkpoint(self, model, trial):
        if self.args.hub_model_id is None:
            model_name = Path(self.args.output_dir).name
        else:
            model_name = self.args.hub_model_id.split("/")[-1]
        self.create_model_card(model_name=model_name)
        super()._save_checkpoint(model, trial)
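For the sigmoid case, the loss computed above reduces to the standard DPO objective; a standalone numeric sketch (the tensor values are made up for illustration):

```python
import torch
import torch.nn.functional as F

beta = 0.1
# Summed completion log-probs under the policy and the frozen reference (illustrative values)
chosen_logps = torch.tensor([-12.3, -9.8])
rejected_logps = torch.tensor([-14.1, -10.2])
ref_chosen_logps = torch.tensor([-12.9, -10.0])
ref_rejected_logps = torch.tensor([-13.5, -10.1])

chosen_logratios = chosen_logps - ref_chosen_logps
rejected_logratios = rejected_logps - ref_rejected_logps

# Sigmoid DPO loss: -log(sigmoid(beta * log-ratio margin)), as in compute_loss above
loss = -F.logsigmoid(beta * chosen_logratios - beta * rejected_logratios).mean()

# Implied rewards and reward accuracy, matching the logged "rewards/*" metrics
chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
rejected_rewards = beta * (rejected_logps - ref_rejected_logps)
accuracy = (chosen_rewards > rejected_rewards).float().mean()
print(loss.item(), accuracy.item())
```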
@@ -51,7 +51,7 @@ from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
 from transformers.utils import is_liger_kernel_available, is_peft_available

-from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
+from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_extract_prompt
 from ..models import create_reference_model, prepare_deepspeed
 from ..models.utils import prepare_fsdp
 from .base_trainer import BaseTrainer
@@ -649,6 +649,8 @@ class DPOTrainer(BaseTrainer):
             map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
         dataset = dataset.map(maybe_extract_prompt, **map_kwargs)

+        is_chat = is_conversational(next(iter(dataset)))
+
         # Apply the chat template if needed
         if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
             map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
@@ -669,6 +671,7 @@ class DPOTrainer(BaseTrainer):
                 "max_completion_length": args.max_completion_length,
                 # for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
                 "add_special_tokens": False,
+                "is_chat": is_chat,
             },
             **map_kwargs,
         )
@@ -682,6 +685,7 @@ class DPOTrainer(BaseTrainer):
         max_prompt_length: Optional[int] = None,
         max_completion_length: Optional[int] = None,
         add_special_tokens: bool = True,
+        is_chat: bool = False,
     ) -> dict[str, list[int]]:
         """
         Tokenize a row of the dataset.
@@ -728,8 +732,9 @@ class DPOTrainer(BaseTrainer):
             prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
         if tokenizer.eos_token_id is not None:
             prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
-            chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
-            rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]
+            if not is_chat:
+                chosen_input_ids = chosen_input_ids + [tokenizer.eos_token_id]
+                rejected_input_ids = rejected_input_ids + [tokenizer.eos_token_id]

         # Truncate prompt and completion sequences
         if max_prompt_length is not None:
@@ -1057,7 +1057,7 @@ class SFTTrainer(BaseTrainer):
         elif args.max_length is not None:
             if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                 map_kwargs["desc"] = f"Truncating {dataset_name} dataset"
-            dataset = truncate_dataset(dataset, args.max_length, map_kwargs)
+            dataset = truncate_dataset(dataset, args.max_length, map_kwargs=map_kwargs)
         # For Liger kernel, ensure only the essential columns
         if args.use_liger_kernel:
             collator_expected_keys = {"input_ids", "seq_lengths", "completion_mask", "assistant_masks"}
@@ -19,6 +19,7 @@ import os
 import random
 import socket
 import warnings
+import zlib
 from collections.abc import Mapping, Sequence, Sized
 from dataclasses import dataclass, field
 from importlib.metadata import version
@@ -29,7 +30,6 @@ import numpy as np
 import pandas as pd
 import torch
 import torch.nn.functional as F
-import torch.utils.data
 import transformers
 from accelerate import Accelerator, PartialState, logging
 from accelerate.state import AcceleratorState
@@ -1990,3 +1990,14 @@ def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
     architecture = getattr(transformers, config.architectures[0])
     model = architecture.from_pretrained(model_id, **kwargs)
     return model
+
+
+def hash_module(module: torch.nn.Module) -> str:
+    h = zlib.adler32(b"")
+    for _, tensor in sorted(module.state_dict().items()):
+        tensor = tensor.cpu()
+        h = zlib.adler32(str(tensor.dtype).encode(), h)
+        if tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2):
+            tensor = tensor.to(torch.float32)
+        h = zlib.adler32(tensor.numpy().tobytes(), h)
+    return f"{h:08x}"
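A quick sanity check of the helper added above; a minimal sketch using a toy module (the import path is assumed from this diff):

```python
import torch

from trl.trainer.utils import hash_module  # import path assumed from this diff

module = torch.nn.Linear(4, 2)
fingerprint = hash_module(module)
print(fingerprint)  # 8-hex-digit Adler-32 digest of the dtypes and raw bytes of the weights

# The same weights always hash to the same value, so `datasets.map` caching keyed on
# this digest survives across runs as long as the model state is unchanged.
assert hash_module(module) == fingerprint
```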