Compare commits

...

5 Commits

4 changed files with 48 additions and 10 deletions

View File

@@ -1,6 +1,6 @@
 [metadata]
 name = trl
-version = 0.18.0
+version = 0.18.2
 description = Train transformer language models with reinforcement learning.
 long_description = file: README.md
 long_description_content_type = text/markdown
@@ -89,7 +89,6 @@ dev =
     %(quantization)s
     %(scikit)s
     %(test)s
-    %(vllm)s
     %(vlm)s
 
 [options.entry_points]

View File

@@ -1178,3 +1178,34 @@ class GRPOTrainerTester(unittest.TestCase):
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    def test_training_multiple_dataloader_workers(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=8,  # reduce the completion length to reduce memory usage
+                dataloader_num_workers=2,  # use multiple dataloader workers
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

View File

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.18.0"
+__version__ = "0.18.2"
 
 from typing import TYPE_CHECKING
 

View File

@@ -18,6 +18,7 @@ import warnings
 from collections import defaultdict, deque
 from collections.abc import Sized
 from contextlib import nullcontext
+from functools import partial
 from typing import Any, Callable, Optional, Union
 
 import datasets
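Note: the new functools.partial import is used further down in this file to pre-bind the extra arguments that newer transformers versions expect for the worker init function. A minimal, self-contained sketch of the pattern (worker_init and its arguments here are illustrative stand-ins, not the transformers API):

from functools import partial

def worker_init(worker_id, num_workers, rank):
    # illustrative stand-in for a worker seeding function
    print(f"init worker {worker_id} of {num_workers} on rank {rank}")

# bind the arguments that the DataLoader will not supply by itself
init_fn = partial(worker_init, num_workers=2, rank=0)
init_fn(0)  # equivalent to worker_init(0, num_workers=2, rank=0)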
@@ -277,6 +278,11 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor:
     return torch.max(tensor[~torch.isnan(tensor)])
 
 
+def identity(x):
+    """Do we really need docs for this?"""
+    return x
+
+
 class GRPOTrainer(Trainer):
     """
     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
@@ -483,10 +489,6 @@ class GRPOTrainer(Trainer):
                 reward_processing_classes[i] = reward_processing_class
         self.reward_processing_classes = reward_processing_classes
 
-        # Data collator
-        def data_collator(features):  # No data collation is needed in GRPO
-            return features
-
         # Training arguments
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
@@ -541,7 +543,7 @@
         super().__init__(
             model=model,
             args=args,
-            data_collator=data_collator,
+            data_collator=identity,  # No data collation is needed in GRPO
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,
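Note: replacing the data_collator closure defined inside __init__ with the module-level identity function is relevant once dataloader_num_workers > 0. One plausible motivation (an assumption, not stated in the diff): on platforms where DataLoader workers are spawned rather than forked, objects handed to the workers must be picklable, and a function defined inside another function is not. A minimal sketch of that failure mode, independent of TRL (make_closure and its nested data_collator are illustrative names):

import pickle

def identity(x):  # module-level, so it can be pickled by reference
    return x

def make_closure():
    def data_collator(features):  # defined locally, so it cannot be pickled
        return features
    return data_collator

pickle.dumps(identity)  # works
try:
    pickle.dumps(make_closure())
except (pickle.PicklingError, AttributeError) as err:
    print(f"local collator is not picklable: {err}")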
@@ -750,7 +752,13 @@ class GRPOTrainer(Trainer):
         if not isinstance(train_dataset, torch.utils.data.IterableDataset):
             dataloader_params["sampler"] = self._get_train_sampler()
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
-            dataloader_params["worker_init_fn"] = seed_worker
+            if version.parse(transformers.__version__) >= version.parse("4.52.0"):
+                # from transformers 4.52.0, the `seed_worker` requires the `num_workers` and `rank` arguments
+                dataloader_params["worker_init_fn"] = partial(
+                    seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
+                )
+            else:
+                dataloader_params["worker_init_fn"] = seed_worker
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
 
         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
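Note: with the version-gated worker_init_fn in place, multi-worker data loading can be enabled through the usual training-arguments field. A hedged usage sketch, assuming the public GRPOConfig/GRPOTrainer API; the dataset, model, and reward function below are illustrative placeholders, not part of this change:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder dataset

training_args = GRPOConfig(
    output_dir="grpo-output",
    dataloader_num_workers=2,  # with older TRL this could fail under transformers >= 4.52.0
    report_to="none",
)

def reward_len(completions, **kwargs):
    # toy reward: prefer shorter completions
    return [-float(len(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # placeholder model
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()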
@@ -1229,7 +1237,7 @@ class GRPOTrainer(Trainer):
         # Identify sequences that terminated with EOS and log their lengths
         agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
         term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
-        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(completion_lengths)
+        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
         self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
         if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
             term_completion_lengths = torch.zeros(1, device=device)
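Note: the one-line fix swaps the denominator from the per-process completion_lengths to the gathered agg_completion_lengths, the tensor that term_completion_lengths is actually sliced from. A toy worked example (numbers invented for illustration) shows how the old denominator distorts the metric in a multi-process run:

import torch

# 2 processes x 4 completions each -> 8 gathered completions, 6 of which end with EOS
completion_lengths = torch.tensor([5, 8, 8, 6])                    # local to one process
agg_completion_lengths = torch.tensor([5, 8, 8, 6, 7, 8, 4, 3])    # after accelerator.gather
agg_terminated_with_eos = torch.tensor([True, False, True, True, True, False, True, True])
term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]

old_ratio = 1 - len(term_completion_lengths) / len(completion_lengths)      # 1 - 6/4 = -0.5 (nonsense)
new_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)  # 1 - 6/8 = 0.25
print(old_ratio, new_ratio)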