Mirror of https://github.com/huggingface/trl.git (synced 2025-10-21 02:53:59 +08:00)

Compare commits: openenv-in ... v0.11.3 (9 commits)
Commits:
1661bc295e
fdbcaaea3d
22567cd1d4
00b537eefd
01142bb7c4
d3fb4860e2
86ad7a7e85
3aa1996986
e4935e13fc
setup.py (2 changes)

@@ -73,7 +73,7 @@ import os

from setuptools import find_packages, setup


-__version__ = "0.11.0.dev0"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+__version__ = "0.11.3"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)

REQUIRED_PKGS = [
    "torch>=1.4.0",
RLOO trainer tests

@@ -13,8 +13,14 @@
# limitations under the License.
import platform
import subprocess
import tempfile
import unittest

import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl import RLOOConfig, RLOOTrainer


def test():

@@ -26,6 +32,8 @@ python examples/scripts/rloo/rloo.py \
    --gradient_accumulation_steps 1 \
    --total_episodes 10 \
    --model_name_or_path EleutherAI/pythia-14m \
    --sft_model_path EleutherAI/pythia-14m \
    --reward_model_path EleutherAI/pythia-14m \
    --missing_eos_penalty 1.0 \
    --save_strategy no \
    --stop_token eos

@@ -71,3 +79,42 @@ def test_rloo_reward():
    baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
    vec_advantages = rlhf_reward - baseline
    torch.testing.assert_close(vec_advantages.flatten(), advantages)


class RLOOTrainerTester(unittest.TestCase):
    def setUp(self):
        self.sft_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
        self.reward_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"

        self.policy_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
        self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.reward_model_id)
        self.policy_ref_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)

        self.tokenizer = AutoTokenizer.from_pretrained(self.sft_model_id, padding_side="left")
        self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    def test_rloo_checkpoint(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = RLOOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                total_episodes=1,
                report_to="none",
            )

            dummy_text = {"content": "Hello World!", "role": "user"}
            dummy_data = self.tokenizer.apply_chat_template(dummy_text)
            dummy_dataset = Dataset.from_dict({"input_ids": dummy_data})

            trainer = RLOOTrainer(
                config=training_args,
                policy=self.policy_model,
                reward_model=self.reward_model,
                ref_policy=self.policy_ref_model,
                processing_class=self.tokenizer,
                train_dataset=dummy_dataset,
                eval_dataset=dummy_dataset,
            )

            trainer._save_checkpoint(trainer.model, trial=None)
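The `test_rloo_reward` context above is the RLOO leave-one-out baseline: each completion is compared against the mean reward of the other `rloo_k - 1` completions sampled for the same prompt. A minimal numeric sketch of that computation (shapes and values are made up for illustration):

import torch

# Hypothetical rewards: rloo_k = 2 completions per prompt, 3 prompts.
# Rows index the k completions, columns index the prompts.
rlhf_reward = torch.tensor([[1.0, 2.0, 3.0],
                            [3.0, 4.0, 5.0]])
rloo_k = 2

# Leave-one-out baseline: for each completion, average the rewards of the
# other k - 1 completions of the same prompt.
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = rlhf_reward - baseline

print(baseline)    # tensor([[3., 4., 5.], [1., 2., 3.]])
print(advantages)  # tensor([[-2., -2., -2.], [ 2.,  2.,  2.]])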
Trainer argument tests

@@ -16,6 +16,7 @@ import tempfile
import unittest

from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl import (

@@ -27,12 +28,18 @@ from trl import (
    DPOTrainer,
    KTOConfig,
    KTOTrainer,
    NashMDConfig,
    NashMDTrainer,
    OnlineDPOConfig,
    OnlineDPOTrainer,
    ORPOConfig,
    ORPOTrainer,
    RewardConfig,
    RewardTrainer,
    SFTConfig,
    SFTTrainer,
    XPOConfig,
    XPOTrainer,
)


@@ -219,8 +226,30 @@ class TrainerArgTester(unittest.TestCase):
        self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
        self.assertEqual(trainer.args.dataset_num_proc, 4)

-    def test_online_dpo(self):
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
    @parameterized.expand([(False,), (True,)])
    def test_nash_md(self, mixtures_coef_list):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = NashMDConfig(
                tmp_dir,
                mixture_coef=0.5 if not mixtures_coef_list else [0.5, 0.6],
            )
            model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
            ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
            reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
            trainer = NashMDTrainer(
                args=training_args,
                tokenizer=tokenizer,
                model=model,
                ref_model=ref_model,
                reward_model=reward_model,
                train_dataset=dataset,
            )
            self.assertEqual(trainer.args.mixture_coef, 0.5 if not mixtures_coef_list else [0.5, 0.6])

    @parameterized.expand([(False,), (True,)])
    def test_online_dpo(self, beta_list):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            args = OnlineDPOConfig(

@@ -228,13 +257,14 @@ class TrainerArgTester(unittest.TestCase):
                max_new_tokens=42,
                temperature=0.5,
                missing_eos_penalty=0.33,
-                beta=0.6,
+                beta=0.6 if not beta_list else [0.6, 0.7],
                loss_type="hinge",
                dataset_num_proc=4,
            )
            model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
            ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
            reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
            trainer = OnlineDPOTrainer(
                args=args,
                tokenizer=tokenizer,

@@ -246,7 +276,7 @@ class TrainerArgTester(unittest.TestCase):
            self.assertEqual(trainer.args.max_new_tokens, 42)
            self.assertEqual(trainer.args.temperature, 0.5)
            self.assertEqual(trainer.args.missing_eos_penalty, 0.33)
-            self.assertEqual(trainer.args.beta, 0.6)
+            self.assertEqual(trainer.args.beta, 0.6 if not beta_list else [0.6, 0.7])
            self.assertEqual(trainer.args.loss_type, "hinge")
            self.assertEqual(trainer.args.dataset_num_proc, 4)

@@ -278,6 +308,27 @@ class TrainerArgTester(unittest.TestCase):
            self.assertEqual(trainer.args.disable_dropout, False)
            self.assertEqual(trainer.args.label_pad_token_id, -99)

    def test_reward(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = RewardConfig(
                tmp_dir,
                max_length=256,
                dataset_num_proc=4,
                center_rewards_coefficient=0.1,
            )
            model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
            trainer = RewardTrainer(
                model=model,
                args=training_args,
                train_dataset=dataset,
                tokenizer=tokenizer,
            )
            self.assertEqual(trainer.args.max_length, 256)
            self.assertEqual(trainer.args.dataset_num_proc, 4)
            self.assertEqual(trainer.args.center_rewards_coefficient, 0.1)

    def test_sft(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:

@@ -308,3 +359,25 @@ class TrainerArgTester(unittest.TestCase):
            self.assertEqual(trainer.args.eval_packing, True)
            self.assertEqual(trainer.args.num_of_sequences, 32)
            self.assertEqual(trainer.args.chars_per_token, 4.2)

    @parameterized.expand([(False,), (True,)])
    def test_xpo(self, alpha_list):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = XPOConfig(
                tmp_dir,
                alpha=0.5 if not alpha_list else [0.5, 0.6],
            )
            model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
            ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
            reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
            trainer = XPOTrainer(
                args=training_args,
                tokenizer=tokenizer,
                model=model,
                ref_model=ref_model,
                reward_model=reward_model,
                train_dataset=dataset,
            )
            self.assertEqual(trainer.args.alpha, 0.5 if not alpha_list else [0.5, 0.6])
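Several of the new tests rely on `parameterized.expand` to run one test body per flag value, which is how the scalar and list variants of `mixture_coef`, `beta`, and `alpha` are both covered. A minimal sketch of the pattern outside TRL (class and test names are illustrative):

import unittest

from parameterized import parameterized


class ExampleTester(unittest.TestCase):
    # The decorator turns this into two test cases: one with use_list=False,
    # one with use_list=True, mirroring the scalar-vs-list configs above.
    @parameterized.expand([(False,), (True,)])
    def test_coef(self, use_list):
        coef = 0.5 if not use_list else [0.5, 0.6]
        if use_list:
            self.assertIsInstance(coef, list)
        else:
            self.assertIsInstance(coef, float)


if __name__ == "__main__":
    unittest.main()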
Trainer utils tests

@@ -15,12 +15,13 @@
import unittest

import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl.trainer.model_config import ModelConfig
-from trl.trainer.utils import decode_and_strip_padding, get_peft_config, pad
+from trl.trainer.utils import DataCollatorForChatML, decode_and_strip_padding, get_peft_config, pad


if is_peft_available():

@@ -126,3 +127,77 @@ class TestDecodeAndStripPadding(unittest.TestCase):
        inputs = self.tokenizer(["Hello", "Hello"], padding=False, return_tensors="pt")
        decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer)
        self.assertEqual(decoded, ["Hello", "Hello"])


class TestDataCollatorForChatML(unittest.TestCase):
    def setUp(self):
        # Initialize the tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Define token IDs
        self.bos_token_id = self.tokenizer.bos_token_id if self.tokenizer.bos_token_id is not None else 1
        self.eos_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 2
        # Token ID for "true", the last assistant's response in the example:
        self.ignore_index = -100
        self.max_length = 1024
        self.messages_key = "messages"

        # Example input
        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
        self.examples = dataset.to_list()

        # Initialize the data collator
        self.collator = DataCollatorForChatML(
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            ignore_index=self.ignore_index,
        )

    def test_data_collator_for_chatml(self):
        # Process the data
        data = self.collator(self.examples)

        # Decode input_ids and labels for verification
        input_ids = data["input_ids"][0].tolist()
        labels = data["labels"][0].tolist()
        prompt_only = data["prompts"][0].tolist()

        # Verify that input_ids start with optional padding tokens and a single BOS token and there are no extra ones
        first_non_pad = next(token for token in input_ids if token != self.tokenizer.pad_token_id)
        self.assertEqual(
            first_non_pad, self.bos_token_id, "The first non-padding token of input_ids should be BOS token."
        )
        self.assertEqual(input_ids.count(self.bos_token_id), 1, "There should be exactly one BOS token in input_ids.")

        # Verify that the assistant's response token is present in input_ids and not in the prompt_only
        last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
        last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
        response_in_input_ids = all(token in input_ids for token in last_assistant_response_tokens)
        self.assertTrue(response_in_input_ids, "The assistant's response should be present in input_ids.")

        # Check if the last assistant's response tokens are not in prompt_only
        response_in_prompt = all(token in prompt_only for token in last_assistant_response_tokens)
        self.assertFalse(response_in_prompt, "The assistant's response should not be present in prompt_only.")

        # Verify that EOS token is at the end of input_ids
        self.assertEqual(input_ids[-1], self.eos_token_id, "The last token of input_ids should be EOS token.")

        # Verify that the labels preserved the target string (last_assistant_response)
        last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
        last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)

        # Find the start and end of the last assistant's response in the labels
        response_start = next(i for i, label in enumerate(labels) if label != self.ignore_index)
        response_end = next(i for i in range(len(labels) - 1, -1, -1) if labels[i] != self.ignore_index)

        actual_response = labels[response_start : response_end - 1]
        self.assertEqual(
            actual_response,
            last_assistant_response_tokens,
            "The labels should preserve the last assistant's response tokens.",
        )

        # Verify that EOS token is at the end of labels
        self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.")
trl/__init__.py

@@ -14,7 +14,7 @@

# flake8: noqa

-__version__ = "0.11.0.dev0"
+__version__ = "0.11.3"

from typing import TYPE_CHECKING
from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
DPO trainer

@@ -104,9 +104,7 @@ def _tokenize(
        _append_prompt_tokens_to_batch(batch, prompt_tokens)

    else:
-        _tokenize_encoder_decoder(
-            batch, tokenizer, features["prompt"], features["chosen"], features["rejected"], args, model
-        )
+        _tokenize_encoder_decoder(batch, tokenizer, features["prompt"], features["chosen"], features["rejected"], args)

    return dict(batch)

@@ -253,7 +251,6 @@ def _tokenize_encoder_decoder(
    chosen: List[str],
    rejected: List[str],
    args: DPOConfig,
-    model: Optional[PreTrainedModel],
) -> None:
    chosen_tokens = tokenizer(chosen, truncation=True, max_length=args.max_completion_length, add_special_tokens=True)
    rejected_tokens = tokenizer(

@@ -266,23 +263,6 @@ def _tokenize_encoder_decoder(
    batch["prompt_input_ids"] = prompt_tokens["input_ids"]
    batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]

-    if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
-        # Ensure the sequences are of the same length
-        max_length = max(len(seq) for seq in batch["chosen_labels"] + batch["rejected_labels"])
-        batch["chosen_labels"] = [
-            seq + [tokenizer.pad_token_id] * (max_length - len(seq)) for seq in batch["chosen_labels"]
-        ]
-        batch["rejected_labels"] = [
-            seq + [tokenizer.pad_token_id] * (max_length - len(seq)) for seq in batch["rejected_labels"]
-        ]
-
-        batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
-            labels=torch.tensor(batch["rejected_labels"])
-        )
-        batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
-            labels=torch.tensor(batch["chosen_labels"])
-        )


def _build_tokenized_answer(
    prompt: str,

@@ -1114,9 +1094,6 @@ class DPOTrainer(Trainer):
            concatenated_batch["concatenated_attention_mask"] = (
                batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
            )
-            concatenated_batch["concatenated_decoder_input_ids"] = torch.cat(
-                [batch["chosen_decoder_input_ids"], batch["rejected_decoder_input_ids"]], dim=0
-            ).to(device=device)

        if is_vision_model:
            concatenated_batch["pixel_values"] = torch.cat(

@@ -1365,7 +1342,6 @@ class DPOTrainer(Trainer):

        if self.is_encoder_decoder:
            model_kwargs["labels"] = concatenated_batch["concatenated_labels"]
-            model_kwargs["decoder_input_ids"] = concatenated_batch.get("concatenated_decoder_input_ids")

        if self.is_vision_model:
            model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
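Context for the removed encoder-decoder lines: when `labels` are passed to a transformers seq2seq model and `decoder_input_ids` are not, the model derives the decoder inputs from the labels itself (a right shift starting from the decoder start token), which is the same thing `prepare_decoder_input_ids_from_labels` computes, so precomputing them during tokenization is not required. A small sketch of that equivalence, assuming the `t5-small` checkpoint is available:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

enc = tokenizer(["translate English to German: Hello"], return_tensors="pt")
labels = tokenizer(["Hallo"], return_tensors="pt").input_ids

# No decoder_input_ids passed: the model shifts the labels right internally,
# which is what the removed prepare_decoder_input_ids_from_labels block did.
out = model(**enc, labels=labels)

# The same shift, done explicitly:
decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=labels)
print(decoder_input_ids)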
GKD trainer

@@ -124,13 +124,18 @@ class GKDTrainer(SFTTrainer):
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

-        # Compute the interpolated log probabilities
-        interpolated_log_probs = beta * student_log_probs + (1 - beta) * teacher_log_probs
+        # Compute the log of the mixture distribution
+        # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
+        beta = torch.tensor(beta, dtype=student_log_probs.dtype)
+        mixture_log_probs = torch.logsumexp(
+            torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
+            dim=0,
+        )

        # Compute KL divergences using F.kl_div
        # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
-        kl_teacher = F.kl_div(interpolated_log_probs, teacher_log_probs, reduction="none", log_target=True)
-        kl_student = F.kl_div(interpolated_log_probs, student_log_probs, reduction="none", log_target=True)
+        kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
+        kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

        # Compute the Generalized Jensen-Shannon Divergence
        jsd = beta * kl_teacher + (1 - beta) * kl_student
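The replacement above computes the log of the mixture distribution log(beta * p_student + (1 - beta) * p_teacher) with `logsumexp`, rather than interpolating the log-probabilities directly, which is a different quantity (the log of an unnormalized geometric mixture). A small numeric check of the identity, with made-up logits:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_logits = torch.randn(2, 5)
teacher_logits = torch.randn(2, 5)
beta = torch.tensor(0.3)

student_log_probs = F.log_softmax(student_logits, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

# Log of the mixture, computed in a numerically stable way
mixture_log_probs = torch.logsumexp(
    torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
    dim=0,
)

# Same quantity computed naively in probability space
expected = torch.log(beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp())
torch.testing.assert_close(mixture_log_probs, expected)

# Interpolating log-probs instead gives a different result in general
interpolated = beta * student_log_probs + (1 - beta) * teacher_log_probs
print(torch.allclose(interpolated, expected))  # False in general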
NashMDConfig

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from dataclasses import dataclass
-from typing import List, Union
+from dataclasses import dataclass, field
+from typing import List

from trl.trainer.online_dpo_config import OnlineDPOConfig

@@ -32,4 +32,9 @@ class NashMDConfig(OnlineDPOConfig):
        epochs.
    """

-    mixture_coef: Union[float, List[float]] = 0.5
+    mixture_coef: List[float] = field(default_factory=lambda: [0.5])
+
+    def __post_init__(self):
+        super().__post_init__()
+        if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1:
+            self.mixture_coef = self.mixture_coef[0]
OnlineDPOConfig

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from dataclasses import dataclass
-from typing import List, Literal, Optional, Union
+from dataclasses import dataclass, field
+from typing import List, Literal, Optional

from transformers import TrainingArguments

@@ -63,7 +63,12 @@ class OnlineDPOConfig(TrainingArguments):
    max_new_tokens: int = 64
    temperature: float = 0.9
    missing_eos_penalty: Optional[float] = None
-    beta: Union[float, List[float]] = 0.1
+    beta: List[float] = field(default_factory=lambda: [0.1])
    loss_type: Literal["sigmoid", "ipo"] = "sigmoid"
    dataset_num_proc: Optional[int] = None
    disable_dropout: bool = True
+
+    def __post_init__(self):
+        super().__post_init__()
+        if hasattr(self.beta, "__len__") and len(self.beta) == 1:
+            self.beta = self.beta[0]
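With this change `beta` (and likewise `mixture_coef` and `alpha` in the subclasses) defaults to a one-element list that `__post_init__` collapses back to a scalar, while a longer list is kept and, per the config docstrings, read as one value per epoch with the last entry reused for the remaining epochs. A quick sketch of the resulting behaviour (the output directory is a throwaway path):

from trl import OnlineDPOConfig

# A one-element list (or a plain float) behaves as before: a scalar beta.
cfg = OnlineDPOConfig(output_dir="/tmp/online-dpo", beta=[0.1])
print(cfg.beta)  # 0.1

# A longer list is kept as-is and used as a per-epoch schedule.
cfg = OnlineDPOConfig(output_dir="/tmp/online-dpo", beta=[0.6, 0.7])
print(cfg.beta)  # [0.6, 0.7]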
RLOO trainer

@@ -39,7 +39,7 @@ from transformers import (
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
-from transformers.trainer_callback import CallbackHandler, PrinterCallback
+from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback

from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (

@@ -149,10 +149,6 @@ class RLOOTrainer(Trainer):
        #########
        ### trainer specifics
        #########
-        self.state = OnlineTrainerState(
-            is_local_process_zero=self.is_local_process_zero(),
-            is_world_process_zero=self.is_world_process_zero(),
-        )
        default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
        self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
        self.callback_handler = CallbackHandler(

@@ -160,6 +156,14 @@ class RLOOTrainer(Trainer):
        )
        self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
        self.control = TrainerControl()
+        self.state = OnlineTrainerState(
+            is_local_process_zero=self.is_local_process_zero(),
+            is_world_process_zero=self.is_world_process_zero(),
+            stateful_callbacks=[
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ],
+        )
+
        self.current_flos = 0
        self.hp_search_backend = None
        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
DataCollatorForChatML (trainer utils)

@@ -240,6 +240,7 @@ class DataCollatorForChatML:
    tokenizer: PreTrainedTokenizerBase
    ignore_index: int = -100
    max_length: int = None
+    prompt_key: str = "prompt"
    messages_key: str = "messages"

    def __post_init__(self):

@@ -250,67 +251,69 @@ class DataCollatorForChatML:
            self.max_length = min(self.tokenizer.model_max_length, 1024)

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
-        prompts = []
-        completions = []
-
-        for example in examples:
-            messages = example[self.messages_key]
-            formatted_chat = self.tokenizer.apply_chat_template(messages, tokenize=False)
-
-            # Split the formatted chat into prompt and completion
-            assistant_messages = [msg for msg in messages if msg["role"] == "assistant"]
-            last_assistant_message = assistant_messages[-1]["content"]
-            prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]
-            completion = last_assistant_message
-
-            prompts.append(prompt)
-            completions.append(completion)
-
-        # Tokenize prompts and completions
-        tokenized_prompts = self.tokenizer(
-            prompts, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
-        )
-        tokenized_completions = self.tokenizer(
-            completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
-        )
-
-        # Combine prompts and completions
        input_ids = []
        attention_mask = []
+        prompts_input_ids = []
+        prompt_attention_mask = []
        labels = []

-        for prompt, completion in zip(tokenized_prompts["input_ids"], tokenized_completions["input_ids"]):
-            combined_input_ids = prompt + completion
-            combined_attention_mask = [1] * len(combined_input_ids)
+        for example in examples:
+            formatted_prompt = example.get(self.prompt_key, None)
+            if formatted_prompt is None:
+                prompt = example[self.messages_key][:-1]
+                formatted_prompt = self.tokenizer.apply_chat_template(
+                    prompt, tokenize=False, add_generation_prompt=True
+                )

-            # Create labels for one-token ahead task, masking the prompt
-            combined_labels = [self.ignore_index] * len(prompt) + completion[:-1]
-            combined_labels.append(self.tokenizer.eos_token_id)  # Add EOS token as final target
+            if "input_ids" not in example:
+                message = example[self.messages_key]
+                formatted_message = self.tokenizer.apply_chat_template(
+                    message, tokenize=False, add_generation_prompt=True
+                )
+                tokenized_message = self.tokenizer(
+                    formatted_message,
+                    truncation=True,
+                    max_length=self.max_length,
+                    padding=False,
+                    return_tensors=None,
+                    add_special_tokens=False,
+                )
+                input_ids.append(tokenized_message["input_ids"])
+                attention_mask.append(tokenized_message["attention_mask"])
+            else:
+                input_ids.append(example["input_ids"])
+                attention_mask.append(example["attention_mask"])

-            input_ids.append(combined_input_ids)
-            attention_mask.append(combined_attention_mask)
-            labels.append(combined_labels)
+            tokenized_prompt = self.tokenizer(
+                formatted_prompt,
+                truncation=True,
+                max_length=len(input_ids[-1]),
+                padding=False,
+                return_tensors=None,
+                add_special_tokens=False,
+            )

-        # first convert to list of tensors
-        input_ids = [torch.tensor(ids) for ids in input_ids]
-        attention_mask = [torch.tensor(mask) for mask in attention_mask]
-        labels = [torch.tensor(label) for label in labels]
+            prompts_input_ids.append(tokenized_prompt["input_ids"])
+            prompt_attention_mask.append(tokenized_prompt["attention_mask"])

-        # pad the input_ids, attention_mask and labels to the same length across the batch
+            # Create the labels that will have all but the completion tokens of the example["input_ids"] set to ignore_index
+            label = [self.ignore_index] * len(input_ids[-1])
+            completion_start_idx = len(tokenized_prompt["input_ids"])
+            label[completion_start_idx:] = input_ids[-1][completion_start_idx:]
+            labels.append(label)

+        # convert to list of tensors and pad
+        input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
+        attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask]
+        labels = [torch.tensor(label, dtype=torch.long) for label in labels]
        input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
        attention_mask = pad(attention_mask, padding_side="left", padding_value=0)
        labels = pad(labels, padding_side="left", padding_value=self.ignore_index)

-        # pad the tokenized_prompts on the left to the same length convert to tensor first
-        prompts_input_ids = [torch.tensor(ids) for ids in tokenized_prompts["input_ids"]]
+        prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids]
+        prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask]
        prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)

-        # prompt attention mask
-        prompt_attention_mask = pad(
-            [torch.tensor([1] * len(ids)) for ids in tokenized_prompts["input_ids"]],
-            padding_side="left",
-            padding_value=0,
-        )
+        prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0)

        return {
            "input_ids": input_ids,
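A usage sketch of the rewritten collator, mirroring what the new test checks; the tokenizer checkpoint and messages here are illustrative, and any tokenizer with a chat template should behave the same way:

from transformers import AutoTokenizer

from trl.trainer.utils import DataCollatorForChatML

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForChatML(tokenizer=tokenizer, max_length=512)

examples = [
    {
        "messages": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]
    }
]

batch = collator(examples)
# Per the test above: "input_ids" and "labels" cover the full conversation
# (labels are -100 outside the final assistant completion), while "prompts"
# holds only the prompt part of each example.
print(batch.keys())
print((batch["labels"][0] != -100).sum(), "completion tokens supervised")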
XPOConfig

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from dataclasses import dataclass
-from typing import List, Union
+from dataclasses import dataclass, field
+from typing import List

from trl.trainer.online_dpo_config import OnlineDPOConfig

@@ -30,4 +30,9 @@ class XPOConfig(OnlineDPOConfig):
        Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch and the last alpha is used for the rest of the epochs.
    """

-    alpha: Union[float, List[float]] = 1e-5
+    alpha: List[float] = field(default_factory=lambda: [1e-5])
+
+    def __post_init__(self):
+        super().__post_init__()
+        if hasattr(self.alpha, "__len__") and len(self.alpha) == 1:
+            self.alpha = self.alpha[0]