mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
495 lines
20 KiB
Python
495 lines
20 KiB
Python
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import json
|
|
import os
|
|
from unittest.mock import call, patch
|
|
|
|
from datasets import load_dataset
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments
|
|
from transformers.testing_utils import require_wandb
|
|
from transformers.trainer_utils import get_last_checkpoint
|
|
from transformers.utils import is_peft_available
|
|
|
|
from trl import (
|
|
BasePairwiseJudge,
|
|
BEMACallback,
|
|
DPOConfig,
|
|
DPOTrainer,
|
|
LogCompletionsCallback,
|
|
MergeModelCallback,
|
|
WinRateCallback,
|
|
)
|
|
from trl.mergekit_utils import MergeConfig
|
|
|
|
from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft
|
|
|
|
|
|
if is_peft_available():
|
|
from peft import LoraConfig
|
|
|
|
|
|
class HalfPairwiseJudge(BasePairwiseJudge):
|
|
"""Naive pairwise judge that always returns [1, 0] for two prompts"""
|
|
|
|
def judge(self, prompts, completions, shuffle_order=True, return_scores=False):
|
|
# just check that the batch size is 2
|
|
assert len(prompts) == 2
|
|
if return_scores:
|
|
return [0.3, 0.9]
|
|
return [1, 0]
|
|
|
|
|
|
class TrainerWithRefModel(Trainer):
|
|
# This is a dummy class to test the callback. Compared to the Trainer class, it only has an additional
|
|
# ref_model attribute
|
|
def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processing_class):
|
|
super().__init__(
|
|
model=model,
|
|
args=args,
|
|
train_dataset=train_dataset,
|
|
eval_dataset=eval_dataset,
|
|
processing_class=processing_class,
|
|
)
|
|
self.ref_model = ref_model
|
|
|
|
|
|
class TestWinRateCallback(TrlTestCase):
|
|
def setup_method(self):
|
|
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
|
|
dataset["train"] = dataset["train"].select(range(8))
|
|
self.expected_winrates = [
|
|
{"eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
|
|
{"eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
|
|
{"eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
|
|
{"eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
|
|
{"eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
|
|
{"eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
|
|
{"eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
|
|
]
|
|
|
|
def tokenize_function(examples):
|
|
out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
|
|
out["labels"] = out["input_ids"].copy()
|
|
return out
|
|
|
|
self.dataset = dataset.map(tokenize_function, batched=True)
|
|
|
|
self.generation_config = GenerationConfig(max_length=32)
|
|
self.judge = HalfPairwiseJudge()
|
|
|
|
def test_basic(self):
|
|
training_args = TrainingArguments(
|
|
output_dir=self.tmp_dir,
|
|
eval_strategy="steps",
|
|
eval_steps=2, # evaluate every 2 steps
|
|
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
|
|
per_device_eval_batch_size=2,
|
|
report_to="none",
|
|
)
|
|
trainer = TrainerWithRefModel(
|
|
model=self.model,
|
|
ref_model=self.ref_model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
eval_dataset=self.dataset["test"],
|
|
processing_class=self.tokenizer,
|
|
)
|
|
win_rate_callback = WinRateCallback(
|
|
judge=self.judge, trainer=trainer, generation_config=self.generation_config
|
|
)
|
|
trainer.add_callback(win_rate_callback)
|
|
trainer.train()
|
|
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
|
|
for history_row, expected_row in zip(winrate_history, self.expected_winrates):
|
|
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
|
|
|
|
def test_without_ref_model(self):
|
|
# Same as before, but without the ref_model attribute. It should use the model attribute instead
|
|
training_args = TrainingArguments(
|
|
output_dir=self.tmp_dir,
|
|
eval_strategy="steps",
|
|
eval_steps=2, # evaluate every 2 steps
|
|
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
|
|
per_device_eval_batch_size=2,
|
|
report_to="none",
|
|
)
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
eval_dataset=self.dataset["test"],
|
|
processing_class=self.tokenizer,
|
|
)
|
|
win_rate_callback = WinRateCallback(
|
|
judge=self.judge, trainer=trainer, generation_config=self.generation_config
|
|
)
|
|
trainer.add_callback(win_rate_callback)
|
|
trainer.train()
|
|
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
|
|
for history_row, expected_row in zip(winrate_history, self.expected_winrates):
|
|
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
|
|
|
|
def test_soft_judge(self):
|
|
"""Test that the soft judge functionality works correctly"""
|
|
training_args = TrainingArguments(
|
|
output_dir=self.tmp_dir,
|
|
eval_strategy="steps",
|
|
eval_steps=2, # evaluate every 2 steps
|
|
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
|
|
per_device_eval_batch_size=2,
|
|
report_to="none",
|
|
)
|
|
trainer = TrainerWithRefModel(
|
|
model=self.model,
|
|
ref_model=self.ref_model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
eval_dataset=self.dataset["test"],
|
|
processing_class=self.tokenizer,
|
|
)
|
|
win_rate_callback = WinRateCallback(
|
|
judge=self.judge, trainer=trainer, generation_config=self.generation_config, use_soft_judge=True
|
|
)
|
|
trainer.add_callback(win_rate_callback)
|
|
trainer.train()
|
|
|
|
# Expected values based on judge returning [0.3, 0.9] for each pair
|
|
expected_soft_winrates = [
|
|
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.0, "step": 0},
|
|
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 0.5, "step": 2},
|
|
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.0, "step": 4},
|
|
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 1.5, "step": 6},
|
|
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.0, "step": 8},
|
|
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 2.5, "step": 10},
|
|
{"eval_avg_win_prob": 0.4, "eval_win_rate": 0.5, "epoch": 3.0, "step": 12},
|
|
]
|
|
|
|
winrate_history = [
|
|
{k: h[k] for k in ["eval_avg_win_prob", "eval_win_rate", "epoch", "step"]}
|
|
for h in trainer.state.log_history
|
|
if "eval_avg_win_prob" in h
|
|
]
|
|
for history_row, expected_row in zip(winrate_history, expected_soft_winrates):
|
|
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
|
|
|
|
@require_peft
|
|
def test_lora(self):
|
|
peft_config = LoraConfig(
|
|
r=16,
|
|
lora_alpha=32,
|
|
lora_dropout=0.05,
|
|
bias="none",
|
|
task_type="CAUSAL_LM",
|
|
)
|
|
self.model.add_adapter(peft_config)
|
|
training_args = TrainingArguments(
|
|
output_dir=self.tmp_dir,
|
|
eval_strategy="steps",
|
|
eval_steps=2, # evaluate every 2 steps
|
|
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
|
|
per_device_eval_batch_size=2,
|
|
report_to="none",
|
|
)
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
eval_dataset=self.dataset["test"],
|
|
processing_class=self.tokenizer,
|
|
)
|
|
win_rate_callback = WinRateCallback(
|
|
judge=self.judge, trainer=trainer, generation_config=self.generation_config
|
|
)
|
|
trainer.add_callback(win_rate_callback)
|
|
trainer.train()
|
|
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
|
|
for history_row, expected_row in zip(winrate_history, self.expected_winrates):
|
|
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
|
|
|
|
|
|
class TestLogCompletionsCallback(TrlTestCase):
|
|
def setup_method(self):
|
|
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")
|
|
dataset["train"] = dataset["train"].select(range(8))
|
|
|
|
def tokenize_function(examples):
|
|
out = self.tokenizer(examples["prompt"], padding="max_length", max_length=16, truncation=True)
|
|
out["labels"] = out["input_ids"].copy()
|
|
return out
|
|
|
|
self.dataset = dataset.map(tokenize_function, batched=True)
|
|
|
|
self.generation_config = GenerationConfig(max_length=32)
|
|
|
|
@require_wandb
|
|
def test_basic_wandb(self):
|
|
import wandb
|
|
|
|
training_args = TrainingArguments(
|
|
output_dir=self.tmp_dir,
|
|
eval_strategy="steps",
|
|
eval_steps=2, # evaluate every 2 steps
|
|
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
|
|
per_device_eval_batch_size=2,
|
|
report_to="wandb",
|
|
)
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
eval_dataset=self.dataset["test"],
|
|
processing_class=self.tokenizer,
|
|
)
|
|
completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
|
|
trainer.add_callback(completions_callback)
|
|
trainer.train()
|
|
|
|
# Get the current run
|
|
completions_path = wandb.run.summary.completions["path"]
|
|
json_path = os.path.join(wandb.run.dir, completions_path)
|
|
with open(json_path) as f:
|
|
completions = json.load(f)
|
|
|
|
# Check that the columns are correct
|
|
assert "step" in completions["columns"]
|
|
assert "prompt" in completions["columns"]
|
|
assert "completion" in completions["columns"]
|
|
|
|
# Check that the prompt is in the log
|
|
assert self.dataset["test"][0]["prompt"] in completions["data"][0]
|
|
|
|
@require_comet
|
|
def test_basic_comet(self):
|
|
import comet_ml
|
|
|
|
training_args = TrainingArguments(
|
|
output_dir=self.tmp_dir,
|
|
eval_strategy="steps",
|
|
eval_steps=2, # evaluate every 2 steps
|
|
per_device_train_batch_size=2, # 8 samples in total so 4 batches of 2 per epoch
|
|
per_device_eval_batch_size=2,
|
|
report_to="comet_ml",
|
|
)
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
eval_dataset=self.dataset["test"],
|
|
processing_class=self.tokenizer,
|
|
)
|
|
completions_callback = LogCompletionsCallback(trainer, self.generation_config, num_prompts=2)
|
|
trainer.add_callback(completions_callback)
|
|
trainer.train()
|
|
|
|
# close experiment to make sure all pending data are flushed
|
|
experiment = comet_ml.get_running_experiment()
|
|
assert experiment is not None
|
|
experiment.end()
|
|
|
|
# get experiment assets and check that all required tables was logged
|
|
steps = len(self.dataset["train"]) + len(self.dataset["test"])
|
|
tables_logged = int(steps / 2) + 1 # +1 to include zero step
|
|
|
|
api_experiment = comet_ml.APIExperiment(previous_experiment=experiment.id)
|
|
tables = api_experiment.get_asset_list("dataframe")
|
|
assert tables is not None
|
|
assert len(tables) == tables_logged
|
|
assert all(table["fileName"] == "completions.csv" for table in tables)
|
|
|
|
|
|
@require_mergekit
|
|
class TestMergeModelCallback(TrlTestCase):
|
|
def setup_method(self):
|
|
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
|
|
|
|
def test_callback(self):
|
|
training_args = DPOConfig(
|
|
output_dir=self.tmp_dir,
|
|
num_train_epochs=1,
|
|
report_to="none",
|
|
save_strategy="steps",
|
|
save_steps=1,
|
|
)
|
|
config = MergeConfig()
|
|
merge_callback = MergeModelCallback(config)
|
|
trainer = DPOTrainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset,
|
|
processing_class=self.tokenizer,
|
|
callbacks=[merge_callback],
|
|
)
|
|
trainer.train()
|
|
last_checkpoint = get_last_checkpoint(self.tmp_dir)
|
|
merged_path = os.path.join(last_checkpoint, "merged")
|
|
assert os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint."
|
|
|
|
def test_every_checkpoint(self):
|
|
training_args = DPOConfig(
|
|
output_dir=self.tmp_dir,
|
|
num_train_epochs=1,
|
|
report_to="none",
|
|
save_strategy="steps",
|
|
save_steps=1,
|
|
)
|
|
config = MergeConfig()
|
|
merge_callback = MergeModelCallback(config, merge_at_every_checkpoint=True)
|
|
trainer = DPOTrainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset,
|
|
processing_class=self.tokenizer,
|
|
callbacks=[merge_callback],
|
|
)
|
|
trainer.train()
|
|
|
|
checkpoints = sorted(
|
|
[os.path.join(self.tmp_dir, cp) for cp in os.listdir(self.tmp_dir) if cp.startswith("checkpoint-")]
|
|
)
|
|
|
|
for checkpoint in checkpoints:
|
|
merged_path = os.path.join(checkpoint, "merged")
|
|
assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."
|
|
|
|
|
|
class TestBEMACallback(TrlTestCase):
|
|
def setup_method(self):
|
|
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
|
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling")
|
|
|
|
def tokenize_function(examples, tokenizer):
|
|
out = tokenizer(examples["text"], padding="max_length", max_length=17)
|
|
out["labels"] = out["input_ids"].copy()
|
|
return out
|
|
|
|
self.dataset = dataset.map(
|
|
tokenize_function, fn_kwargs={"tokenizer": self.tokenizer}, remove_columns=["text"], batched=True
|
|
)
|
|
|
|
def test_model_saved(self):
|
|
"""Test that BEMACallback saves the BEMA model."""
|
|
training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
|
|
bema_callback = BEMACallback(update_freq=2)
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
processing_class=self.tokenizer,
|
|
callbacks=[bema_callback],
|
|
)
|
|
trainer.train()
|
|
|
|
# Check that the BEMA model was saved and can be loaded
|
|
bema_path = os.path.join(self.tmp_dir, "bema")
|
|
assert os.path.isdir(bema_path), "BEMA directory was not created"
|
|
AutoModelForCausalLM.from_pretrained(bema_path)
|
|
|
|
def test_update_frequency_0(self):
|
|
"""Test that BEMA callback respects the update frequency."""
|
|
training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
|
|
bema_callback = BEMACallback(update_freq=2)
|
|
|
|
with patch.object(bema_callback, "_update_bema_weights") as mock_update:
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
processing_class=self.tokenizer,
|
|
callbacks=[bema_callback],
|
|
)
|
|
|
|
trainer.train()
|
|
|
|
# Total 9 steps (17 samples, batch size 8, 3 epochs).
|
|
# BEMA starts after step 0 and updates every 2 steps → updates at 2, 4, 5, 8
|
|
assert mock_update.call_args_list == [call(2), call(4), call(6), call(8)]
|
|
|
|
def test_update_frequency_1(self):
|
|
"""Test that BEMA callback respects the update frequency."""
|
|
training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
|
|
bema_callback = BEMACallback(update_freq=3)
|
|
|
|
with patch.object(bema_callback, "_update_bema_weights") as mock_update:
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
processing_class=self.tokenizer,
|
|
callbacks=[bema_callback],
|
|
)
|
|
|
|
trainer.train()
|
|
|
|
# Total 9 steps (17 samples, batch size 8, 3 epochs).
|
|
# BEMA starts after step 0 and updates every 3 steps → updates at 3, 6, 9
|
|
assert mock_update.call_args_list == [call(3), call(6), call(9)]
|
|
|
|
def test_update_frequency_2(self):
|
|
"""Test that BEMA callback respects the update frequency."""
|
|
training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
|
|
bema_callback = BEMACallback(update_freq=2, update_after=3)
|
|
|
|
with patch.object(bema_callback, "_update_bema_weights") as mock_update:
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
processing_class=self.tokenizer,
|
|
callbacks=[bema_callback],
|
|
)
|
|
|
|
trainer.train()
|
|
|
|
# Total 9 steps (17 samples, batch size 8, 3 epochs).
|
|
# BEMA starts after step 3 and updates every 2 steps → updates at 5, 7, 9
|
|
assert mock_update.call_args_list == [call(5), call(7), call(9)]
|
|
|
|
def test_no_bema(self):
|
|
"""Test that BEMACallback works without BEMA updates."""
|
|
training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
|
|
bema_callback = BEMACallback(update_freq=2, bias_power=0.0)
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
processing_class=self.tokenizer,
|
|
callbacks=[bema_callback],
|
|
)
|
|
trainer.train()
|
|
|
|
def test_no_ema(self):
|
|
"""Test that BEMACallback works without EMA updates."""
|
|
training_args = TrainingArguments(output_dir=self.tmp_dir, report_to="none")
|
|
bema_callback = BEMACallback(update_freq=2, ema_power=0.0)
|
|
trainer = Trainer(
|
|
model=self.model,
|
|
args=training_args,
|
|
train_dataset=self.dataset["train"],
|
|
processing_class=self.tokenizer,
|
|
callbacks=[bema_callback],
|
|
)
|
|
trainer.train()
|