Compare commits

...

9 Commits

Author SHA1 Message Date
1661bc295e [GKD] interpolate in prob. space (#2204)
* interpolate in prob. space

* better var names

* use logsumexp

* set beta dtype

* beta tensor
2024-10-10 12:41:34 +00:00
fdbcaaea3d Version 0.11.2 -> 0.11.3 2024-10-10 12:30:46 +00:00
22567cd1d4 Update incorrect data processing in DataCollatorForChatML (#2172)
* Update incorrect data processing in DataCollatorForChatML

Fix the extra BOS token and the absence of an EOS token in the returned input_ids, and potentially the absence of a target string in the returned labels.

* Update trl/trainer/utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* style

* move comment

* add test for DataCollatorForChatML

* update comment with more details

* update assert reports and comments, and add verification that the last token of input_ids is the EOS token

* new line at the end of file for code quality

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* update tests

* fix test

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* formatting

* fix typo

* simplify

* Revert "simplify"

This reverts commit 7e4006c87265665183032932ca05dffef567e38b.

* tokenize full messages

* don't add eos

* eos is already the last token

* simplify DataCollatorForChatML

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-10 12:28:19 +00:00
00b537eefd Drop decoder_input_ids in DPOTrainer (#2208) 2024-10-10 12:24:50 +00:00
01142bb7c4 Version 0.11.1 -> 0.11.2 2024-10-07 15:59:17 +00:00
d3fb4860e2 Fix RLOO checkpointing (#2114)
* Fix RLOO checkpointing for transformers>=4.45.0

* Add missing import

* Fix pre-commit issues

* Added test for RLOO checkpointing

* Ensure that tokenizer matches SFT and Reward model

* Pre-commit formatting

* processing class

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-07 15:47:38 +00:00
86ad7a7e85 Version 0.11.0 -> 0.11.1 2024-09-24 15:47:35 +00:00
3aa1996986 [online-dpo] allow parse-args as list of floats (#2108)
* use a separate argument for list of floats

* do super first

* fix docstrings

* typos

* use list of floats only

* check if it has len

* fix docstring

* fix suggestion

* fix default

* Update trl/trainer/online_dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/xpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

* additional tests

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-09-24 15:43:39 +00:00
e4935e13fc Release v0.11.0 2024-09-19 09:49:54 +02:00
12 changed files with 296 additions and 98 deletions

View File

@@ -73,7 +73,7 @@ import os
from setuptools import find_packages, setup
__version__ = "0.11.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.11.3" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"torch>=1.4.0",

View File

@@ -13,8 +13,14 @@
# limitations under the License.
import platform
import subprocess
import tempfile
import unittest
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import RLOOConfig, RLOOTrainer
def test():
@@ -26,6 +32,8 @@ python examples/scripts/rloo/rloo.py \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--sft_model_path EleutherAI/pythia-14m \
--reward_model_path EleutherAI/pythia-14m \
--missing_eos_penalty 1.0 \
--save_strategy no \
--stop_token eos
@@ -71,3 +79,42 @@ def test_rloo_reward():
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
vec_advantages = rlhf_reward - baseline
torch.testing.assert_close(vec_advantages.flatten(), advantages)
class RLOOTrainerTester(unittest.TestCase):
def setUp(self):
self.sft_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
self.reward_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
self.policy_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.reward_model_id)
self.policy_ref_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
self.tokenizer = AutoTokenizer.from_pretrained(self.sft_model_id, padding_side="left")
self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
def test_rloo_checkpoint(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = RLOOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
total_episodes=1,
report_to="none",
)
dummy_text = {"content": "Hello World!", "role": "user"}
dummy_data = self.tokenizer.apply_chat_template(dummy_text)
dummy_dataset = Dataset.from_dict({"input_ids": dummy_data})
trainer = RLOOTrainer(
config=training_args,
policy=self.policy_model,
reward_model=self.reward_model,
ref_policy=self.policy_ref_model,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
trainer._save_checkpoint(trainer.model, trial=None)
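Note: the leave-one-out baseline asserted in test_rloo_reward above can be sanity-checked with a minimal standalone sketch (toy numbers only, not part of this diff): each completion's baseline is the mean reward of the other rloo_k - 1 completions sampled for the same prompt.

import torch

rloo_k = 4
# toy rewards for rloo_k completions of a single prompt (not from the diff)
rlhf_reward = torch.tensor([1.0, 2.0, 3.0, 4.0])

# vectorized form used in the test above
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = rlhf_reward - baseline

# the same quantity computed one completion at a time
for i in range(rloo_k):
    others = torch.cat([rlhf_reward[:i], rlhf_reward[i + 1 :]])
    assert torch.isclose(advantages[i], rlhf_reward[i] - others.mean())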

View File

@@ -16,6 +16,7 @@ import tempfile
import unittest
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import (
@@ -27,12 +28,18 @@ from trl import (
DPOTrainer,
KTOConfig,
KTOTrainer,
NashMDConfig,
NashMDTrainer,
OnlineDPOConfig,
OnlineDPOTrainer,
ORPOConfig,
ORPOTrainer,
RewardConfig,
RewardTrainer,
SFTConfig,
SFTTrainer,
XPOConfig,
XPOTrainer,
)
@@ -219,8 +226,30 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
self.assertEqual(trainer.args.dataset_num_proc, 4)
def test_online_dpo(self):
tokenizer = AutoTokenizer.from_pretrained("gpt2")
@parameterized.expand([(False,), (True,)])
def test_nash_md(self, mixtures_coef_list):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = NashMDConfig(
tmp_dir,
mixture_coef=0.5 if not mixtures_coef_list else [0.5, 0.6],
)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = NashMDTrainer(
args=training_args,
tokenizer=tokenizer,
model=model,
ref_model=ref_model,
reward_model=reward_model,
train_dataset=dataset,
)
self.assertEqual(trainer.args.mixture_coef, 0.5 if not mixtures_coef_list else [0.5, 0.6])
@parameterized.expand([(False,), (True,)])
def test_online_dpo(self, beta_list):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
args = OnlineDPOConfig(
@@ -228,13 +257,14 @@ class TrainerArgTester(unittest.TestCase):
max_new_tokens=42,
temperature=0.5,
missing_eos_penalty=0.33,
beta=0.6,
beta=0.6 if not beta_list else [0.6, 0.7],
loss_type="hinge",
dataset_num_proc=4,
)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = OnlineDPOTrainer(
args=args,
tokenizer=tokenizer,
@@ -246,7 +276,7 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.max_new_tokens, 42)
self.assertEqual(trainer.args.temperature, 0.5)
self.assertEqual(trainer.args.missing_eos_penalty, 0.33)
self.assertEqual(trainer.args.beta, 0.6)
self.assertEqual(trainer.args.beta, 0.6 if not beta_list else [0.6, 0.7])
self.assertEqual(trainer.args.loss_type, "hinge")
self.assertEqual(trainer.args.dataset_num_proc, 4)
@@ -278,6 +308,27 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.disable_dropout, False)
self.assertEqual(trainer.args.label_pad_token_id, -99)
def test_reward(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = RewardConfig(
tmp_dir,
max_length=256,
dataset_num_proc=4,
center_rewards_coefficient=0.1,
)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = RewardTrainer(
model=model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.dataset_num_proc, 4)
self.assertEqual(trainer.args.center_rewards_coefficient, 0.1)
def test_sft(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -308,3 +359,25 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.eval_packing, True)
self.assertEqual(trainer.args.num_of_sequences, 32)
self.assertEqual(trainer.args.chars_per_token, 4.2)
@parameterized.expand([(False,), (True,)])
def test_xpo(self, alpha_list):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = XPOConfig(
tmp_dir,
alpha=0.5 if not alpha_list else [0.5, 0.6],
)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = XPOTrainer(
args=training_args,
tokenizer=tokenizer,
model=model,
ref_model=ref_model,
reward_model=reward_model,
train_dataset=dataset,
)
self.assertEqual(trainer.args.alpha, 0.5 if not alpha_list else [0.5, 0.6])

View File

@@ -15,12 +15,13 @@
import unittest
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl.trainer.model_config import ModelConfig
from trl.trainer.utils import decode_and_strip_padding, get_peft_config, pad
from trl.trainer.utils import DataCollatorForChatML, decode_and_strip_padding, get_peft_config, pad
if is_peft_available():
@@ -126,3 +127,77 @@ class TestDecodeAndStripPadding(unittest.TestCase):
inputs = self.tokenizer(["Hello", "Hello"], padding=False, return_tensors="pt")
decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer)
self.assertEqual(decoded, ["Hello", "Hello"])
class TestDataCollatorForChatML(unittest.TestCase):
def setUp(self):
# Initialize the tokenizer
self.tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Define token IDs
self.bos_token_id = self.tokenizer.bos_token_id if self.tokenizer.bos_token_id is not None else 1
self.eos_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 2
# Token ID for "true", the last assistant's response in the example:
self.ignore_index = -100
self.max_length = 1024
self.messages_key = "messages"
# Example input
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
self.examples = dataset.to_list()
# Initialize the data collator
self.collator = DataCollatorForChatML(
tokenizer=self.tokenizer,
max_length=self.max_length,
ignore_index=self.ignore_index,
)
def test_data_collator_for_chatml(self):
# Process the data
data = self.collator(self.examples)
# Decode input_ids and labels for verification
input_ids = data["input_ids"][0].tolist()
labels = data["labels"][0].tolist()
prompt_only = data["prompts"][0].tolist()
# Verify that input_ids start with optional padding tokens and a single BOS token and there are no extra ones
first_non_pad = next(token for token in input_ids if token != self.tokenizer.pad_token_id)
self.assertEqual(
first_non_pad, self.bos_token_id, "The first non-padding token of input_ids should be BOS token."
)
self.assertEqual(input_ids.count(self.bos_token_id), 1, "There should be exactly one BOS token in input_ids.")
# Verify that the assistant's response token is present in input_ids and not in the prompt_only
last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
response_in_input_ids = all(token in input_ids for token in last_assistant_response_tokens)
self.assertTrue(response_in_input_ids, "The assistant's response should be present in input_ids.")
# Check if the last assistant's response tokens are not in prompt_only
response_in_prompt = all(token in prompt_only for token in last_assistant_response_tokens)
self.assertFalse(response_in_prompt, "The assistant's response should not be present in prompt_only.")
# Verify that EOS token is at the end of input_ids
self.assertEqual(input_ids[-1], self.eos_token_id, "The last token of input_ids should be EOS token.")
# Verify that the labels preserved the target string (last_assistant_response)
last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
# Find the start and end of the last assistant's response in the labels
response_start = next(i for i, label in enumerate(labels) if label != self.ignore_index)
response_end = next(i for i in range(len(labels) - 1, -1, -1) if labels[i] != self.ignore_index)
actual_response = labels[response_start : response_end - 1]
self.assertEqual(
actual_response,
last_assistant_response_tokens,
"The labels should preserve the last assistant's response tokens.",
)
# Verify that EOS token is at the end of labels
self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.")

View File

@@ -14,7 +14,7 @@
# flake8: noqa
__version__ = "0.11.0.dev0"
__version__ = "0.11.3"
from typing import TYPE_CHECKING
from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable

View File

@@ -104,9 +104,7 @@ def _tokenize(
_append_prompt_tokens_to_batch(batch, prompt_tokens)
else:
_tokenize_encoder_decoder(
batch, tokenizer, features["prompt"], features["chosen"], features["rejected"], args, model
)
_tokenize_encoder_decoder(batch, tokenizer, features["prompt"], features["chosen"], features["rejected"], args)
return dict(batch)
@@ -253,7 +251,6 @@ def _tokenize_encoder_decoder(
chosen: List[str],
rejected: List[str],
args: DPOConfig,
model: Optional[PreTrainedModel],
) -> None:
chosen_tokens = tokenizer(chosen, truncation=True, max_length=args.max_completion_length, add_special_tokens=True)
rejected_tokens = tokenizer(
@@ -266,23 +263,6 @@
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
# Ensure the sequences are of the same length
max_length = max(len(seq) for seq in batch["chosen_labels"] + batch["rejected_labels"])
batch["chosen_labels"] = [
seq + [tokenizer.pad_token_id] * (max_length - len(seq)) for seq in batch["chosen_labels"]
]
batch["rejected_labels"] = [
seq + [tokenizer.pad_token_id] * (max_length - len(seq)) for seq in batch["rejected_labels"]
]
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=torch.tensor(batch["rejected_labels"])
)
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels=torch.tensor(batch["chosen_labels"])
)
def _build_tokenized_answer(
prompt: str,
@@ -1114,9 +1094,6 @@ class DPOTrainer(Trainer):
concatenated_batch["concatenated_attention_mask"] = (
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
)
concatenated_batch["concatenated_decoder_input_ids"] = torch.cat(
[batch["chosen_decoder_input_ids"], batch["rejected_decoder_input_ids"]], dim=0
).to(device=device)
if is_vision_model:
concatenated_batch["pixel_values"] = torch.cat(
@@ -1365,7 +1342,6 @@ class DPOTrainer(Trainer):
if self.is_encoder_decoder:
model_kwargs["labels"] = concatenated_batch["concatenated_labels"]
model_kwargs["decoder_input_ids"] = concatenated_batch.get("concatenated_decoder_input_ids")
if self.is_vision_model:
model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]

View File

@@ -124,13 +124,18 @@ class GKDTrainer(SFTTrainer):
student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
# Compute the interpolated log probabilities
interpolated_log_probs = beta * student_log_probs + (1 - beta) * teacher_log_probs
# Compute the log of the mixture distribution
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
mixture_log_probs = torch.logsumexp(
torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
dim=0,
)
# Compute KL divergences using F.kl_div
# PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
kl_teacher = F.kl_div(interpolated_log_probs, teacher_log_probs, reduction="none", log_target=True)
kl_student = F.kl_div(interpolated_log_probs, student_log_probs, reduction="none", log_target=True)
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
# Compute the Generalized Jensen-Shannon Divergence
jsd = beta * kl_teacher + (1 - beta) * kl_student
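Note: a minimal standalone check (toy tensors, not TRL code) of what the logsumexp above computes — it is a numerically stable way of taking the log of the mixture beta * p_student + (1 - beta) * p_teacher, which is then used as the first argument of both KL terms of the generalized JSD.

import torch
import torch.nn.functional as F

# toy tensors for illustration only
torch.manual_seed(0)
beta = torch.tensor(0.3)
student_log_probs = F.log_softmax(torch.randn(2, 5), dim=-1)
teacher_log_probs = F.log_softmax(torch.randn(2, 5), dim=-1)

# stable form, as in the diff above
mixture_log_probs = torch.logsumexp(
    torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
    dim=0,
)

# direct (less stable) computation in probability space, for comparison
direct = torch.log(beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp())
torch.testing.assert_close(mixture_log_probs, direct)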

View File

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Union
from dataclasses import dataclass, field
from typing import List
from trl.trainer.online_dpo_config import OnlineDPOConfig
@@ -32,4 +32,9 @@ class NashMDConfig(OnlineDPOConfig):
epochs.
"""
mixture_coef: Union[float, List[float]] = 0.5
mixture_coef: List[float] = field(default_factory=lambda: [0.5])
def __post_init__(self):
super().__post_init__()
if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1:
self.mixture_coef = self.mixture_coef[0]

View File

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Literal, Optional, Union
from dataclasses import dataclass, field
from typing import List, Literal, Optional
from transformers import TrainingArguments
@@ -63,7 +63,12 @@ class OnlineDPOConfig(TrainingArguments):
max_new_tokens: int = 64
temperature: float = 0.9
missing_eos_penalty: Optional[float] = None
beta: Union[float, List[float]] = 0.1
beta: List[float] = field(default_factory=lambda: [0.1])
loss_type: Literal["sigmoid", "ipo"] = "sigmoid"
dataset_num_proc: Optional[int] = None
disable_dropout: bool = True
def __post_init__(self):
super().__post_init__()
if hasattr(self.beta, "__len__") and len(self.beta) == 1:
self.beta = self.beta[0]
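Note: a standalone sketch of the intended command-line behaviour (ToyConfig is a hypothetical stand-in, not the real OnlineDPOConfig): declaring the field as a list lets HfArgumentParser accept one or several values, and the __post_init__ above collapses a single-element list back to a plain float. The same pattern is applied to mixture_coef and alpha below.

from dataclasses import dataclass, field
from typing import List

from transformers import HfArgumentParser


@dataclass
class ToyConfig:
    # hypothetical stand-in for OnlineDPOConfig, kept minimal for illustration
    beta: List[float] = field(default_factory=lambda: [0.1])

    def __post_init__(self):
        if hasattr(self.beta, "__len__") and len(self.beta) == 1:
            self.beta = self.beta[0]


parser = HfArgumentParser(ToyConfig)
(single,) = parser.parse_args_into_dataclasses(["--beta", "0.6"])
(schedule,) = parser.parse_args_into_dataclasses(["--beta", "0.6", "0.7"])
print(single.beta)    # 0.6, collapsed to a float
print(schedule.beta)  # [0.6, 0.7], a per-epoch schedule with the last value reused afterwards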

View File

@@ -39,7 +39,7 @@ from transformers import (
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, PrinterCallback
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
@@ -149,10 +149,6 @@ class RLOOTrainer(Trainer):
#########
### trainer specifics
#########
self.state = OnlineTrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
)
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
@@ -160,6 +156,14 @@ class RLOOTrainer(Trainer):
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
self.control = TrainerControl()
self.state = OnlineTrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
],
)
self.current_flos = 0
self.hp_search_backend = None
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None

View File

@@ -240,6 +240,7 @@ class DataCollatorForChatML:
tokenizer: PreTrainedTokenizerBase
ignore_index: int = -100
max_length: int = None
prompt_key: str = "prompt"
messages_key: str = "messages"
def __post_init__(self):
@@ -250,67 +251,69 @@ class DataCollatorForChatML:
self.max_length = min(self.tokenizer.model_max_length, 1024)
def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
prompts = []
completions = []
for example in examples:
messages = example[self.messages_key]
formatted_chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
# Split the formatted chat into prompt and completion
assistant_messages = [msg for msg in messages if msg["role"] == "assistant"]
last_assistant_message = assistant_messages[-1]["content"]
prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]
completion = last_assistant_message
prompts.append(prompt)
completions.append(completion)
# Tokenize prompts and completions
tokenized_prompts = self.tokenizer(
prompts, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)
tokenized_completions = self.tokenizer(
completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)
# Combine prompts and completions
input_ids = []
attention_mask = []
prompts_input_ids = []
prompt_attention_mask = []
labels = []
for prompt, completion in zip(tokenized_prompts["input_ids"], tokenized_completions["input_ids"]):
combined_input_ids = prompt + completion
combined_attention_mask = [1] * len(combined_input_ids)
for example in examples:
formatted_prompt = example.get(self.prompt_key, None)
if formatted_prompt is None:
prompt = example[self.messages_key][:-1]
formatted_prompt = self.tokenizer.apply_chat_template(
prompt, tokenize=False, add_generation_prompt=True
)
# Create labels for one-token ahead task, masking the prompt
combined_labels = [self.ignore_index] * len(prompt) + completion[:-1]
combined_labels.append(self.tokenizer.eos_token_id) # Add EOS token as final target
if "input_ids" not in example:
message = example[self.messages_key]
formatted_message = self.tokenizer.apply_chat_template(
message, tokenize=False, add_generation_prompt=True
)
tokenized_message = self.tokenizer(
formatted_message,
truncation=True,
max_length=self.max_length,
padding=False,
return_tensors=None,
add_special_tokens=False,
)
input_ids.append(tokenized_message["input_ids"])
attention_mask.append(tokenized_message["attention_mask"])
else:
input_ids.append(example["input_ids"])
attention_mask.append(example["attention_mask"])
input_ids.append(combined_input_ids)
attention_mask.append(combined_attention_mask)
labels.append(combined_labels)
tokenized_prompt = self.tokenizer(
formatted_prompt,
truncation=True,
max_length=len(input_ids[-1]),
padding=False,
return_tensors=None,
add_special_tokens=False,
)
# first convert to list of tensors
input_ids = [torch.tensor(ids) for ids in input_ids]
attention_mask = [torch.tensor(mask) for mask in attention_mask]
labels = [torch.tensor(label) for label in labels]
prompts_input_ids.append(tokenized_prompt["input_ids"])
prompt_attention_mask.append(tokenized_prompt["attention_mask"])
# pad the input_ids, attention_mask and labels to the same length across the batch
# Create the labels that will have all but the completion tokens of the example["input_ids"] set to ignore_index
label = [self.ignore_index] * len(input_ids[-1])
completion_start_idx = len(tokenized_prompt["input_ids"])
label[completion_start_idx:] = input_ids[-1][completion_start_idx:]
labels.append(label)
# convert to list of tensors and pad
input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask]
labels = [torch.tensor(label, dtype=torch.long) for label in labels]
input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
attention_mask = pad(attention_mask, padding_side="left", padding_value=0)
labels = pad(labels, padding_side="left", padding_value=self.ignore_index)
# pad the tokenized_prompts on the left to the same length convert to tensor first
prompts_input_ids = [torch.tensor(ids) for ids in tokenized_prompts["input_ids"]]
prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids]
prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask]
prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
# prompt attention mask
prompt_attention_mask = pad(
[torch.tensor([1] * len(ids)) for ids in tokenized_prompts["input_ids"]],
padding_side="left",
padding_value=0,
)
prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0)
return {
"input_ids": input_ids,

View File

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Union
from dataclasses import dataclass, field
from typing import List
from trl.trainer.online_dpo_config import OnlineDPOConfig
@@ -30,4 +30,9 @@ class XPOConfig(OnlineDPOConfig):
Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch and the last alpha is used for the rest of the epochs.
"""
alpha: Union[float, List[float]] = 1e-5
alpha: List[float] = field(default_factory=lambda: [1e-5])
def __post_init__(self):
super().__post_init__()
if hasattr(self.alpha, "__len__") and len(self.alpha) == 1:
self.alpha = self.alpha[0]