# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest
from unittest.mock import patch

import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
from trl.trainer.grpo_trainer import RepeatRandomSampler

from .testing_utils import require_vllm


if is_peft_available():
    from peft import LoraConfig, PeftModel


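# The tests below pin down the behaviour of RepeatRandomSampler: indices are drawn in shuffled batches of
# `batch_size` unique indices, each index is repeated `mini_repeat_count` times in a row, and each whole batch
# is repeated `repeat_count` times before moving on (an incomplete trailing batch is dropped). A minimal usage
# sketch, assuming the standard torch sampler protocol exercised by these tests:
#
#     sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
#     indices = list(sampler)  # e.g. [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 0, 0, 1, 1, ...]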
class RepeatRandomSamplerTester(unittest.TestCase):
    def test_sampler(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2)
        # Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]
        sampled = list(sampler)
        # Check that the length is doubled
        assert len(sampled) == 2 * len(dataset)
        # Check that all indexes are present
        assert set(sampled) == set(range(len(dataset)))
        # Check that each element is repeated twice
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))

    def test_sampler_no_repeat(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1)
        # Should output something like [4, 3, 0, 1, 2, 6, 5]
        sampled = list(sampler)
        # Check that the length is the same
        assert len(sampled) == len(dataset)
        # Check that all indexes are present
        assert set(sampled) == set(range(len(dataset)))

    def test_sampler_with_batch_size(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g", "h"]
        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
        # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7]
        sampled = list(sampler)
        # Check that the length is doubled
        assert len(sampled) == 2 * len(dataset)
        # Check that all indexes are present
        assert set(sampled) == set(range(len(dataset)))
        # Check that each element is repeated as expected
        assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))

    def test_sampler_with_batch_size_and_drop(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
        # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6]
        sampled = list(sampler)
        # Check that the length is doubled
        assert len(sampled) == 2 * (
            len(dataset) - 1
        )  # one element is dropped, because it's not enough to form a batch
        # Check that the sampled indexes are a subset of the dataset indexes
        assert set(sampled).issubset(set(range(len(dataset))))
        # Check that each element is repeated as expected
        assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))

    def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
        # Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
        #                               1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
        sampled = list(sampler)
        # Check that the length is quadrupled
        assert len(sampled) == 4 * (len(dataset) - 1)  # 1 element is dropped, because it's not enough to form a batch
        # Check that the sampled indexes are a subset of the dataset indexes
        assert set(sampled).issubset(set(range(len(dataset))))
        # Check that each element is repeated as expected
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
        # Check that the batch is repeated as expected
        assert sampled[0:6] == sampled[6:12]
        assert sampled[12:18] == sampled[18:24]

    def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatRandomSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
        # Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3,
        #                               0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
        #                               2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6]
        sampled = list(sampler)
        # Check that the length is sextupled
        assert len(sampled) == 6 * (len(dataset) - 1)  # 1 element is dropped, because it's not enough to form a batch
        # Check that the sampled indexes are a subset of the dataset indexes
        assert set(sampled).issubset(set(range(len(dataset))))
        # Check that each element is repeated as expected
        assert all(sampled[i] == sampled[i + 1] == sampled[i + 2] for i in range(0, len(sampled), 3))
        # Check that the batch is repeated as expected
        assert sampled[0:6] == sampled[6:12]
        assert sampled[12:18] == sampled[18:24]
        assert sampled[24:30] == sampled[30:36]

    def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
        dataset = ["a", "b", "c", "d", "e", "f", "g"]
        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
        # Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3,
        #                               0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
        #                               2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6]
        sampled = list(sampler)
        # Check that the length is sextupled
        assert len(sampled) == 6 * (len(dataset) - 1)  # 1 element is dropped, because it's not enough to form a batch
        # Check that the sampled indexes are a subset of the dataset indexes
        assert set(sampled).issubset(set(range(len(dataset))))
        # Check that each element is repeated as expected
        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
        # Check that the batch is repeated as expected
        assert sampled[0:4] == sampled[4:8] == sampled[8:12]
        assert sampled[12:16] == sampled[16:20] == sampled[20:24]
        assert sampled[24:28] == sampled[28:32] == sampled[32:36]


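# The trainer tests below exercise the `reward_funcs` interface: a reward can be a sequence-classification
# model (or its Hub model ID), or a plain callable that receives `completions` (plus any extra dataset columns
# via **kwargs) and returns one float, or None, per completion. Several rewards can be passed as a list and,
# optionally, weighted via `reward_weights` in GRPOConfig. Note that the tests pick
# per_device_train_batch_size=3 together with num_generations=3 so that each prompt's group of generations
# fits exactly into one batch; the effective batch size must be divisible by num_generations.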
class GRPOTrainerTester(unittest.TestCase):
    def test_init_minimal(self):
        # Test that GRPOTrainer can be instantiated with only model, reward_funcs and train_dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        GRPOTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            train_dataset=dataset,
        )

    @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
    def test_training(self, config_name):
        dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_with_eval(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                per_device_eval_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                eval_strategy="steps",
                eval_steps=2,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
            )

            trainer.train()

    def test_training_multiple_iterations(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                num_iterations=2,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @require_peft
    def test_training_peft(self):
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model=model,
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
                peft_config=LoraConfig(),
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the peft params have changed and the base model params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model params to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
                elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")

    @require_peft
    def test_training_peft_with_gradient_checkpointing(self):
        """Test that training works with PEFT and gradient checkpointing enabled."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        model = AutoModelForCausalLM.from_pretrained(
            "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            torch_dtype=torch.float32,  # Use float32 for testing to avoid precision issues
            use_cache=False,  # Required for gradient checkpointing
        )

        lora_config = LoraConfig(
            r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                gradient_checkpointing=True,  # Enable gradient checkpointing
                report_to="none",
            )
            trainer = GRPOTrainer(
                model=model,
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
                peft_config=lora_config,
            )

            # Verify the model was wrapped as a PeftModel (gradient checkpointing itself is enabled via GRPOConfig)
            self.assertIsInstance(trainer.model, PeftModel)

            # Store initial parameters to check which ones change
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that only LoRA parameters have changed, base model parameters remain unchanged
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if "lora" in n.lower():  # LoRA parameters should change
                    self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
                else:  # Base model parameters should not change
                    self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")

    def test_training_different_reward_model(self):
        # Use a reward model different from the model: different chat template, tokenization, etc.
        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
        reward_model_id = "trl-internal-testing/tiny-LlamaForSequenceClassification-3.2"
        reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id)
        reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_id)
        # By default, the trainer uses the eos token as the padding token. However, for Llama models, the eos token
        # appears in the chat template. Using it as a pad token disrupts the reward calculation, as the calculation
        # considers the score of the last token before the first pad token. To ensure correct reward calculations,
        # we use a separate pad token instead.
        reward_tokenizer.pad_token = "<|finetune_right_pad_id|>"

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_model,
                args=training_args,
                train_dataset=dataset,
                reward_processing_classes=reward_tokenizer,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_reward_func_standard(self):
        # Test if trainer can handle reward function with standard format
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_func,
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_reward_func_conversational(self):
        # Test if trainer can handle reward function with conversational format
        dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

        def reward_func(completions, **kwargs):
            """Reward function that gives higher scores to longer completion content."""
            completion_contents = [completion[0]["content"] for completion in completions]
            return [float(len(content)) for content in completion_contents]
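
        # Note: in the conversational format each completion is a list of message dicts
        # (e.g. [{"role": "assistant", "content": "..."}]), hence the ["content"] lookup in
        # reward_func above; in the standard format a completion is a plain string.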

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_func,
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_multiple_reward_funcs(self):
        # Test that GRPOTrainer can be instantiated with multiple reward functions
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func1(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        def reward_func2(completions, **kwargs):
            """Reward function that rewards completions with more unique letters."""
            return [float(len(set(completion))) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[reward_func1, reward_func2],
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_multiple_reward_funcs_with_None_output(self):
        """Test that training works when one reward function returns valid rewards while another returns None."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def applicable_reward_func(completions, **kwargs):
            """A reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        def non_applicable_reward_func(completions, **kwargs):
            """A reward function that returns None for all inputs, as it is not applicable to this sample."""
            return [None] * len(completions)

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,
                per_device_train_batch_size=3,
                num_generations=3,
                max_completion_length=8,
                report_to="none",
            )

            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[
                    applicable_reward_func,
                    non_applicable_reward_func,
                ],  # One applicable, one non-applicable
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {
                n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
            }

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_multiple_reward_funcs_with_weights(self):
        """Test that GRPOTrainer can handle multiple reward functions with weights."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func1(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        def reward_func2(completions, **kwargs):
            """Reward function that rewards completions with more unique letters."""
            return [float(len(set(completion))) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
                reward_weights=[0.7, 0.3],  # weight of reward_func1 and reward_func2 respectively
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[reward_func1, reward_func2],
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            # Check that training logs contain both reward metrics
            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
            self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1])
            self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1])
            self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1])
            self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_multiple_mixed_reward_funcs(self):
        # Test if the trainer can handle a mix of reward functions and reward models
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        def reward_func(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion)) for completion in completions]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=[reward_func, "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"],
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_reward_func_additional_column(self):
        # Test if trainer can handle reward functions that rely on additional columns in the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        # Add a column to the dataset (dummy example, the column could be anything)
        some_values = list(range(len(dataset)))
        dataset = dataset.add_column("some_values", some_values)

        def reward_func(completions, some_values, **kwargs):
            """Dummy reward function based on the distance between the completion length and the value in some_values."""
            return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)]

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs=reward_func,
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @require_vllm
    @unittest.skip("We should add a mock for the vLLM server.")
    def test_training_vllm(self):
        """Test that training works with vLLM for generation."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
                use_vllm=True,
            )
            trainer = GRPOTrainer(
                model="Qwen/Qwen2.5-0.5B-Instruct",  # tiny model is too small for vLLM
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_with_sync_ref_model(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                sync_ref_model=True,
                ref_model_sync_steps=2,  # reduce sync steps to ensure a sync happens
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

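    # When beta == 0.0, the KL regularization term is dropped from the GRPO loss, so no reference model is
    # needed at all; the test below only checks that training still runs and updates the policy in that
    # configuration.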
    def test_beta_zero_no_ref_model_and_no_kl(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                beta=0.0,  # set beta to 0 to test the case where the reference model is not used
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @unittest.skip("We should add a mock for the vLLM server.")
    @require_peft
    @require_vllm
    def test_training_vllm_and_peft(self):
        """Test that training works with vLLM for generation and a PEFT adapter."""
        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # tiny model is too small for vLLM
        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
                use_vllm=True,
            )
            lora_config = LoraConfig(
                target_modules="all-linear",
                # test with non-default modules as it adds extra keys in the state_dict that we need to handle
                modules_to_save=["embed_tokens", "lm_head"],
            )
            trainer = GRPOTrainer(
                model=model,
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
                peft_config=lora_config,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the peft params have changed and the base model params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                if n in base_param_names:  # We expect the base model params to be the same
                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
                elif "base_layer" not in n and "original_module" not in n:
                    # We expect the peft params to be different (except for the base layer)
                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")

    @require_vllm
    @unittest.skip("We should add a mock for the vLLM server.")
    def test_training_vllm_guided_decoding(self):
        """Test that training works with vLLM for generation with guided decoding."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
                use_vllm=True,
                vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
            )
            trainer = GRPOTrainer(
                model="Qwen/Qwen2.5-0.5B-Instruct",  # tiny model is too small for vLLM
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_with_additional_generation_kwargs(self):
        """Test that training works with additional generation kwargs."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
                top_p=0.9,
                top_k=10,
                min_p=0.01,
                repetition_penalty=1.1,
            )

            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @require_vllm
    @unittest.skip("We should add a mock for the vLLM server.")
    def test_training_vllm_with_additional_generation_kwargs(self):
        """Test that training works with vLLM and additional generation kwargs."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                report_to="none",
                use_vllm=True,
                top_p=0.9,
                top_k=10,
                min_p=0.01,
                repetition_penalty=1.1,
            )

            trainer = GRPOTrainer(
                model="Qwen/Qwen2.5-0.5B-Instruct",  # tiny model is too small for vLLM
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_no_scale_rewards(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                scale_rewards=False,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    @patch("transformers.generation.utils.GenerationMixin.generate")
    def test_training_with_mask_truncated_completions(self, mock_generate):
        """Test that training works with mask_truncated_completions=True parameter."""

        # We mock the generate method because the model's random weights make it extremely unlikely to produce a
        # sequence containing the EOS token within the allowed max_completion_length. As a result, all tokens are
        # masked in the loss, the model doesn't update, and the final check (which verifies the update) fails.
        def fake_generate(prompt_ids, **kwargs):
            # pad_token_id = 151643; eos_token_id = 151645
            completions_ids = torch.tensor(
                [
                    [1, 2, 3, 4, 5, 6, 7, 8],  # this one is truncated
                    [9, 10, 11, 151645, 151643, 151643, 151643, 151643],  # this one contains eos
                    [12, 13, 14, 15, 16, 17, 18, 151645],  # particular case, eos is generated just within the limit
                ],
                device=prompt_ids.device,
            )
            return torch.cat([prompt_ids, completions_ids], dim=1)

        mock_generate.side_effect = fake_generate

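        # With mask_truncated_completions=True, completions that do not contain an EOS token within
        # max_completion_length (the first fake completion above) are excluded from the loss, while the two
        # completions that do contain EOS keep the update alive, which is what the final check verifies.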

        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                mask_truncated_completions=True,  # Enable masking of truncated completions
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_with_mask_truncated_completions_all_masked(self):
        """
        Test that when all generated completions are truncated (i.e., none contain an EOS token), and
        mask_truncated_completions=True, the model receives no effective learning signal and therefore does not update
        its parameters.

        Here, we don't mock the generate method; instead, we rely on the fact that the probability of the model
        generating the EOS token is extremely low, so all generated completions are truncated.
        """
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=8,  # reduce the completion length to reduce memory usage
                mask_truncated_completions=True,  # Enable masking of truncated completions
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )

            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have not changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.")