Compare commits

...

5 Commits

Author SHA1 Message Date
1a678abea7 style 2025-03-14 21:11:58 +00:00
14b51cb919 default to static 2025-03-14 21:10:36 +00:00
c38589e4d5 Merge branch 'main' into static-cache-grpo 2025-03-14 06:06:47 -07:00
a4d433771a disabling gradient chekpt for gen 2025-03-07 11:12:45 +00:00
7ab86cae02 static cache grpo 2025-03-06 16:11:59 +00:00
2 changed files with 22 additions and 5 deletions

View File

@ -72,7 +72,7 @@ class GRPOConfig(TrainingArguments):
Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
tokens.
cache_implementation (`str` or `None`, *optional*, defaults to `None`):
cache_implementation (`str` or `None`, *optional*, defaults to `"static"`):
Implementation of the cache method for faster generation when use_vllm is set to False.
> Parameters that control generation acceleration powered by vLLM

View File

@ -18,7 +18,7 @@ import os
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Sized, Union
from typing import Any, Callable, Generator, Optional, Sized, Union
from unittest.mock import patch
import torch
@ -163,6 +163,22 @@ class RepeatRandomSampler(Sampler):
return self.num_samples * self.mini_repeat_count * self.repeat_count
@contextlib.contextmanager
def disable_gradient_checkpointing(model: PreTrainedModel) -> Generator[None, None, None]:
    """
    Temporarily disables gradient checkpointing in the model, if it is enabled.

    It is useful when using the model to generate completions, while training it with gradient checkpointing.

    Args:
        model (`PreTrainedModel`): Model to disable gradient checkpointing for.
    """
    value = model.base_model.gradient_checkpointing
    model.base_model.gradient_checkpointing = False
    try:
        yield
    finally:
        # Restore the original setting even if generation raises; without the
        # try/finally an exception inside the `with` body would leave gradient
        # checkpointing permanently disabled for the rest of training.
        model.base_model.gradient_checkpointing = value
class GRPOTrainer(Trainer):
"""
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
@ -770,9 +786,10 @@ class GRPOTrainer(Trainer):
else:
# Regular generation path
with unwrap_model_for_generation(self.model_wrapped, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
)
with disable_gradient_checkpointing(unwrapped_model):
prompt_completion_ids = unwrapped_model.generate(
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
)
# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)