Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00
Compare commits: 7e9c6e45d5...grpo-log-e (2 commits: 9c46200672, ac6dc65fdd)
train_grpo.py (new file, 51 lines)

@@ -0,0 +1,51 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer


dataset = load_dataset("open-r1/DAPO-Math-17k-Processed-R1-Distill-Qwen-Math-7B-v03.00-step-000008190-filter", split="train")

def make_conversation(example, prompt_column: str = "prompt"):
    prompt = []

    if prompt_column not in example:
        raise ValueError(f"Dataset Question Field Error: {prompt_column} is not supported.")

    prompt.append({"role": "user", "content": example[prompt_column]})
    return {"prompt": prompt}

dataset = dataset.map(make_conversation)

if "messages" in dataset.column_names:
    dataset = dataset.remove_columns("messages")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(output_dir="data/Qwen/Qwen2.5-0.5B-Instruct", logging_steps=1, gradient_accumulation_steps=2, num_generations=4,
    max_completion_length=4000, max_steps=20, gradient_checkpointing=False, beta=0.0, per_device_train_batch_size=2,
    replay_buffer_class="SSRReplayBuffer",
    ssr_capacity_scalar=4,
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
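Not part of the commit: a standalone sanity check of the script's two helpers (simplified copies, with the column check dropped), useful for confirming the prompt format and the length-based reward before launching a run. Note that when prompts are conversational, GRPOTrainer typically passes each completion to the reward function as a list of message dicts rather than a raw string, so a character-length reward would usually read completion[0]["content"]; plain strings are used below purely for illustration.

```python
# Simplified copies of the script's helpers, for a quick standalone check.
def make_conversation(example, prompt_column: str = "prompt"):
    # Wrap the raw prompt text into the conversational format expected by GRPOTrainer
    return {"prompt": [{"role": "user", "content": example[prompt_column]}]}

def reward_len(completions, **kwargs):
    # Reward is highest (0) for completions exactly 20 characters long
    return [-abs(20 - len(completion)) for completion in completions]

print(make_conversation({"prompt": "What is 7 * 6?"}))
# {'prompt': [{'role': 'user', 'content': 'What is 7 * 6?'}]}

print(reward_len(["42", "The answer is 42.", "x" * 40]))
# [-18, -3, -20]
```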
trl/trainer/grpo_trainer.py

@@ -60,6 +60,7 @@ from .utils import (
     pad,
     print_prompt_completions_sample,
     selective_log_softmax,
+    calculate_entropy,
 )
 
 
@@ -834,6 +835,7 @@ class GRPOTrainer(Trainer):
     def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep, batch_size=None) -> torch.Tensor:
         batch_size = batch_size or input_ids.size(0)  # Chunk inputs into smaller batches to reduce memory peak
         all_logps = []
+        entropies = []
         for i in range(0, input_ids.size(0), batch_size):
             input_ids_batch = input_ids[i : i + batch_size]
             attention_mask_batch = attention_mask[i : i + batch_size]
@@ -852,7 +854,9 @@
             logits = logits / self.temperature
             logps = selective_log_softmax(logits, input_ids_batch)  # compute logprobs for the input tokens
             all_logps.append(logps)
-        return torch.cat(all_logps, dim=0)
+            entropy = calculate_entropy(logits, attention_mask_batch[:, -logits_to_keep:])
+            entropies.append(entropy)
+        return torch.cat(all_logps, dim=0), torch.cat(entropies, dim=0)
 
     def _sync_fsdp_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
         """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
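For orientation (not from the diff): after this hunk, `_get_per_token_logps` returns a `(logps, entropies)` pair with one value per kept token. The sketch below reproduces the per-token log-prob step on random tensors using a simplified stand-in for TRL's `selective_log_softmax`; all shapes and names here are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def selective_log_softmax(logits, index):
    # Simplified stand-in: log-softmax over the vocabulary, then pick out the
    # log-probability of each realized token id.
    logps = F.log_softmax(logits, dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

batch, logits_to_keep, vocab = 4, 6, 32
logits = torch.randn(batch, logits_to_keep, vocab)
input_ids_batch = torch.randint(vocab, (batch, logits_to_keep))

logps = selective_log_softmax(logits, input_ids_batch)
print(logps.shape)  # torch.Size([4, 6]) -- one log-prob per kept token
```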
@@ -1125,7 +1129,7 @@ class GRPOTrainer(Trainer):
                 # old_per_token_logps == per_token_logps, so we can skip it's computation here, and use
                 # per_token_logps.detach() instead.
                 if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:
-                    old_per_token_logps = self._get_per_token_logps(
+                    old_per_token_logps, _ = self._get_per_token_logps(
                         self.model, prompt_completion_ids, attention_mask, logits_to_keep, batch_size
                     )
                 else:
@@ -1268,12 +1272,12 @@ class GRPOTrainer(Trainer):
         if self.beta != 0.0:
             with torch.no_grad():
                 if self.ref_model is not None:
-                    ref_per_token_logps = self._get_per_token_logps(
+                    ref_per_token_logps, ref_entropy = self._get_per_token_logps(
                         self.ref_model, input_ids, attention_mask, logits_to_keep
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter():
-                        ref_per_token_logps = self._get_per_token_logps(
+                        ref_per_token_logps, ref_entropy = self._get_per_token_logps(
                             self.model, input_ids, attention_mask, logits_to_keep
                         )
 
@@ -1321,18 +1325,18 @@ class GRPOTrainer(Trainer):
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
         logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
 
-        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+        per_token_logps, entropy = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
 
         # Compute the KL divergence between the model and the reference model
         if self.beta != 0.0:
             with torch.no_grad():
                 if self.ref_model is not None:
-                    ref_per_token_logps = self._get_per_token_logps(
+                    ref_per_token_logps, ref_entropy = self._get_per_token_logps(
                         self.ref_model, input_ids, attention_mask, logits_to_keep
                     )
                 else:
                     with self.accelerator.unwrap_model(self.model).disable_adapter():
-                        ref_per_token_logps = self._get_per_token_logps(
+                        ref_per_token_logps, ref_entropy = self._get_per_token_logps(
                             self.model, input_ids, attention_mask, logits_to_keep
                         )
         per_token_kl = (
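The body of `per_token_kl` is cut off by this hunk. As a hedged sketch only: GRPO implementations in TRL commonly estimate the per-token KL between the policy and the reference model with the k3 estimator; the helper below shows that formulation under the assumption that this branch keeps it, and is not this branch's exact code.

```python
import torch

def k3_per_token_kl(per_token_logps: torch.Tensor, ref_per_token_logps: torch.Tensor) -> torch.Tensor:
    # k3 estimator: exp(ref - logp) - (ref - logp) - 1, which is non-negative and
    # unbiased for KL(pi || pi_ref) when tokens are sampled from pi.
    delta = ref_per_token_logps - per_token_logps
    return torch.exp(delta) - delta - 1

logp = torch.log(torch.tensor([0.5, 0.25]))
ref = torch.log(torch.tensor([0.4, 0.4]))
print(k3_per_token_kl(logp, ref))  # tensor([0.0231, 0.1300])
```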
@@ -1395,6 +1399,8 @@ class GRPOTrainer(Trainer):
         self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
         gathered_clip_ratio = self.accelerator.gather_for_metrics(clip_ratio)
         self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
+        gathered_entropy = self.accelerator.gather_for_metrics(entropy.detach())
+        self._metrics[mode]["entropy"].append(gathered_entropy.mean().item())
         return loss
 
     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
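A small aside on the new metric: `calculate_entropy` zeroes entries at masked positions, so `gathered_entropy.mean()` averages over padded completion positions as well. If the intent were a mean over real tokens only, a masked mean is the usual variant; the helper below is a hypothetical sketch, not something this diff adds, and `masked_mean_entropy` is an invented name.

```python
import torch

def masked_mean_entropy(entropy: torch.Tensor, completion_mask: torch.Tensor) -> torch.Tensor:
    # entropy: (batch, seq_len), already zeroed where completion_mask == 0
    # completion_mask: (batch, seq_len) of 0/1 values
    return entropy.sum() / completion_mask.sum().clamp(min=1)

entropy = torch.tensor([[0.7, 0.2, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 0, 0]])
print(entropy.mean().item())                      # ~0.225, averages padded zeros too
print(masked_mean_entropy(entropy, mask).item())  # ~0.45, averages real tokens only
```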
trl/trainer/utils.py

@@ -1683,6 +1683,15 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
     else:
         return mask, *tensors
 
+def calculate_entropy(logits, attention_mask):
+    entropy_values = []
+    for row_logits, mask_row in zip(logits, attention_mask):
+        probs = F.softmax(row_logits, dim=-1)
+        entropy_row = -torch.sum(probs * torch.log(probs + 1e-20), dim=-1) * mask_row
+        entropy_values.append(entropy_row)
+    entropy_values = torch.stack(entropy_values)
+    return entropy_values
+
 
 def selective_log_softmax(logits, index):
     """
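Not part of the diff: the same masked token-level entropy can be computed without the per-row Python loop, using the identity H = logsumexp(logits) - sum(softmax(logits) * logits). A minimal vectorized sketch, assuming the same (batch, seq_len, vocab) logits and (batch, seq_len) mask shapes:

```python
import torch
import torch.nn.functional as F

def calculate_entropy_vectorized(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Per-token entropy via logsumexp(logits) - sum(softmax(logits) * logits),
    # zeroed at masked positions; matches the loop-based version above up to its 1e-20 epsilon.
    probs = F.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1)
    return entropy * attention_mask

logits = torch.randn(2, 5, 11)   # (batch, seq_len, vocab)
mask = torch.ones(2, 5)          # (batch, seq_len)
print(calculate_entropy_vectorized(logits, mask).shape)  # torch.Size([2, 5])
```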