Mirror of https://github.com/huggingface/trl.git
synced 2025-10-21 02:53:59 +08:00
Compare commits
5 Commits
refactor_g...v0.18.2
| Author | SHA1 | Date |
|---|---|---|
| | a21a925e30 | |
| | 1a6717661c | |
| | 2c49300910 | |
| | e530486c26 | |
| | 1bae58c292 | |
@@ -1,6 +1,6 @@
 [metadata]
 name = trl
-version = 0.18.0
+version = 0.18.2
 description = Train transformer language models with reinforcement learning.
 long_description = file: README.md
 long_description_content_type = text/markdown
||||
@@ -89,7 +89,6 @@ dev =
    %(quantization)s
    %(scikit)s
    %(test)s
    %(vllm)s
    %(vlm)s

[options.entry_points]
@@ -1178,3 +1178,34 @@ class GRPOTrainerTester(unittest.TestCase):
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    def test_training_multiple_dataloader_workers(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=8,  # reduce the completion length to reduce memory usage
+                dataloader_num_workers=2,  # use multiple dataloader workers
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.18.0"
+__version__ = "0.18.2"

 from typing import TYPE_CHECKING

@@ -18,6 +18,7 @@ import warnings
 from collections import defaultdict, deque
 from collections.abc import Sized
 from contextlib import nullcontext
+from functools import partial
 from typing import Any, Callable, Optional, Union

 import datasets
@@ -277,6 +278,11 @@ def nanmax(tensor: torch.Tensor) -> torch.Tensor:
     return torch.max(tensor[~torch.isnan(tensor)])


+def identity(x):
+    """Do we really need docs for this?"""
+    return x
+
+
 class GRPOTrainer(Trainer):
     """
     Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
@@ -483,10 +489,6 @@ class GRPOTrainer(Trainer):
                 reward_processing_classes[i] = reward_processing_class
         self.reward_processing_classes = reward_processing_classes

-        # Data collator
-        def data_collator(features):  # No data collation is needed in GRPO
-            return features
-
         # Training arguments
         self.max_prompt_length = args.max_prompt_length
         self.max_completion_length = args.max_completion_length  # = |o_i| in the GRPO paper
@@ -541,7 +543,7 @@
         super().__init__(
             model=model,
             args=args,
-            data_collator=data_collator,
+            data_collator=identity,  # No data collation is needed in GRPO
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,
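Taken together, the two hunks above replace the `data_collator` closure defined inside `__init__` with the module-level `identity` function. The diff only notes that "No data collation is needed in GRPO"; a plausible extra motivation, given the new multi-worker test above, is that `DataLoader` must pickle its `collate_fn` when `dataloader_num_workers > 0` on spawn-based platforms, and a function defined inside another function cannot be pickled. A minimal standalone sketch of that difference (not TRL code):

```python
import pickle


def identity(x):
    """Module-level function: picklable by reference, safe to hand to DataLoader workers."""
    return x


def make_local_collator():
    def data_collator(features):  # local closure, like the one removed in this diff
        return features
    return data_collator


pickle.dumps(identity)  # works
try:
    pickle.dumps(make_local_collator())
except (AttributeError, pickle.PicklingError) as err:
    print(f"local collator cannot be pickled: {err}")
```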
@@ -750,7 +752,13 @@
         if not isinstance(train_dataset, torch.utils.data.IterableDataset):
             dataloader_params["sampler"] = self._get_train_sampler()
             dataloader_params["drop_last"] = self.args.dataloader_drop_last
-            dataloader_params["worker_init_fn"] = seed_worker
+            if version.parse(transformers.__version__) >= version.parse("4.52.0"):
+                # from transformers 4.52.0, the `seed_worker` requires the `num_workers` and `rank` arguments
+                dataloader_params["worker_init_fn"] = partial(
+                    seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
+                )
+            else:
+                dataloader_params["worker_init_fn"] = seed_worker
             dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

         return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
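The hunk above keeps the old behaviour on older transformers releases and, from 4.52.0 onwards, pre-binds the extra `num_workers` and `rank` arguments that `seed_worker` now expects. A small standalone sketch of the `functools.partial` pattern, using a stand-in `seed_worker` rather than the real one from transformers:

```python
from functools import partial


def seed_worker(worker_id, num_workers, rank):
    # stand-in for transformers' seed_worker; the real one seeds per-worker RNGs
    print(f"worker {worker_id}/{num_workers} on rank {rank}")


# The DataLoader only ever calls worker_init_fn(worker_id), so the extra
# arguments are bound up front with partial.
worker_init_fn = partial(seed_worker, num_workers=2, rank=0)
worker_init_fn(0)  # -> worker 0/2 on rank 0
```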
@@ -1229,7 +1237,7 @@
         # Identify sequences that terminated with EOS and log their lengths
         agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
         term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
-        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(completion_lengths)
+        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
         self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
         if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
             term_completion_lengths = torch.zeros(1, device=device)
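This one-line fix makes the numerator and denominator consistent: `term_completion_lengths` is derived from the gathered `agg_completion_lengths`, so dividing by the per-process `completion_lengths` understates the denominator whenever more than one process is involved and can even push the ratio below zero. A toy illustration with made-up numbers (2 processes, 4 completions each):

```python
import torch

agg_completion_lengths = torch.tensor([8, 8, 5, 3, 8, 2, 8, 6])  # gathered from all ranks
agg_terminated_with_eos = torch.tensor([False, True, True, True, False, True, False, True])
term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]  # 5 terminated with EOS

local_count = 4  # len(completion_lengths) on a single process
print(1 - len(term_completion_lengths) / local_count)                  # -0.25, not a valid ratio
print(1 - len(term_completion_lengths) / len(agg_completion_lengths))  # 0.375, fraction clipped
```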