Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)

🧺 [3/N] Refactor _generate in GRPO/RLOO: Rely on generator for prompt truncation (#4153)

commit 0e57b4a9df · parent 98488e0946 · committed via GitHub
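In short, this commit stops trimming prompt token-id lists by hand (the `truncate_with_protected_tokens` helper and its tests are deleted below) and lets the generation stack truncate prompts itself: processors are now loaded with `truncation_side="left"`, the transformers generation path tokenizes with `truncation=True` and `max_length=max_prompt_length`, and the vLLM client/server paths forward a new `truncate_prompt_tokens` argument. A minimal tokenizer-level sketch of that delegation (model name and lengths are illustrative assumptions, not taken from the diff):

```python
# Hedged sketch, not from the diff: left truncation handled entirely by the tokenizer.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", truncation_side="left")

ids = tokenizer(
    ["one two three four five six seven eight nine ten"],
    max_length=4,  # plays the role of max_prompt_length
    truncation=True,  # drops tokens from the left, keeping the most recent context
    add_special_tokens=False,
)["input_ids"][0]
print(tokenizer.decode(ids))  # roughly "seven eight nine ten"
```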
@@ -1471,47 +1471,6 @@ class TestGRPOTrainer(TrlTestCase):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
-    @require_vision
-    def test_training_vlm_and_prompt_truncation(self):
-        # If not handled properly, prompt truncation may truncate image token
-        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
-
-        def reward_func(completions, **kwargs):
-            """Reward function that rewards longer completions."""
-            return [float(len(completion[0]["content"])) for completion in completions]
-
-        training_args = GRPOConfig(
-            output_dir=self.tmp_dir,
-            learning_rate=0.1,  # increase the learning rate to speed up the test
-            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
-            num_generations=3,  # reduce the number of generations to reduce memory usage
-            max_completion_length=8,  # reduce the completion length to reduce memory usage
-            max_prompt_length=18,
-            report_to="none",
-        )
-        trainer = GRPOTrainer(
-            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
-            reward_funcs=reward_func,
-            args=training_args,
-            train_dataset=dataset,
-        )
-
-        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
-        trainer.train()
-
-        assert trainer.state.log_history[-1]["train_loss"] is not None
-
-        # Check that the params have changed
-        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
-        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
-        params_to_skip = ("model.visual.",)
-        for n, param in previous_trainable_params.items():
-            if n.startswith(params_to_skip):
-                continue
-            new_param = trainer.model.get_parameter(n)
-            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
     @parameterized.expand(
         [
             ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
@@ -1212,47 +1212,6 @@ class TestRLOOTrainer(TrlTestCase):
             elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
                 assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
 
-    @require_vision
-    def test_training_vlm_and_prompt_truncation(self):
-        # If not handled properly, prompt truncation may truncate image token
-        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
-
-        def reward_func(completions, **kwargs):
-            """Reward function that rewards longer completions."""
-            return [float(len(completion[0]["content"])) for completion in completions]
-
-        training_args = RLOOConfig(
-            output_dir=self.tmp_dir,
-            learning_rate=0.1,  # increase the learning rate to speed up the test
-            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
-            num_generations=3,  # reduce the number of generations to reduce memory usage
-            max_completion_length=8,  # reduce the completion length to reduce memory usage
-            max_prompt_length=18,
-            report_to="none",
-        )
-        trainer = RLOOTrainer(
-            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
-            reward_funcs=reward_func,
-            args=training_args,
-            train_dataset=dataset,
-        )
-
-        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
-        trainer.train()
-
-        assert trainer.state.log_history[-1]["train_loss"] is not None
-
-        # Check that the params have changed
-        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
-        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
-        params_to_skip = ("model.visual.",)
-        for n, param in previous_trainable_params.items():
-            if n.startswith(params_to_skip):
-                continue
-            new_param = trainer.model.get_parameter(n)
-            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
     @parameterized.expand(
         [
             ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
@@ -42,7 +42,6 @@ from trl.trainer.utils import (
     shuffle_sequence_dict,
     split_pixel_values_by_grid,
     split_tensor_dict,
-    truncate_with_protected_tokens,
     unsplit_pixel_values_by_grid,
 )
 
@@ -1009,84 +1008,6 @@ class TestSplitPixelValuesByGrid(TrlTestCase):
         assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))
 
 
-class TestTruncateWithProtectedTokens(TrlTestCase):
-    def test_basic_example(self):
-        """Test the basic example from the problem description."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = [2, 3]
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        expected_ids = [2, 3, 5]
-        assert new_ids == expected_ids
-
-    def test_no_truncation_needed(self):
-        """Test when target length equals current length."""
-        prompt_ids = [1, 2, 3]
-        protected_tokens = [2]
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        assert new_ids == prompt_ids
-
-    def test_no_protected_tokens(self):
-        """Test truncation with no protected tokens (normal right truncation)."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = []
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        expected_ids = [3, 4, 5]  # Last 3 tokens
-        assert new_ids == expected_ids
-
-    def test_all_tokens_protected(self):
-        """Test when all remaining tokens are protected."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = [3, 4, 5]
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        expected_ids = [3, 4, 5]
-        assert new_ids == expected_ids
-
-    def test_too_many_protected_tokens(self):
-        """Test error when too many protected tokens for target length."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = [1, 2, 3, 4]
-        target_length = 3
-
-        with pytest.raises(ValueError):
-            truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-    def test_single_batch_single_token(self):
-        """Test edge case with single batch and single token."""
-        prompt_ids = [5]
-        protected_tokens = [5]
-        target_length = 1
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        assert new_ids == prompt_ids
-
-    def test_order_preservation(self):
-        """Test that relative order is preserved."""
-        prompt_ids = [10, 2, 20, 3, 30, 40]
-        protected_tokens = [2, 3]
-        target_length = 4
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
-        # Order should be: 2, 3, 30, 40 (maintaining original relative positions)
-        expected_ids = [2, 3, 30, 40]
-
-        assert new_ids == expected_ids
-
-
 class TestUnsplitPixelValuesByGrid(TrlTestCase):
     def test_unsplit_correctly(self):
         pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
@@ -182,6 +182,7 @@ class VLLMClient:
         top_k: int = -1,
         min_p: float = 0.0,
         max_tokens: int = 16,
+        truncate_prompt_tokens: Optional[int] = None,
         guided_decoding_regex: Optional[str] = None,
         generation_kwargs: Optional[dict] = None,
     ) -> list[list[int]]:
@@ -207,6 +208,10 @@ class VLLMClient:
                Minimum probability for sampling.
            max_tokens (`int`, *optional*, defaults to `16`):
                Maximum number of tokens to generate for each prompt.
+           truncate_prompt_tokens (`int`, *optional*):
+               If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
+               only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
+               disabled.
            guided_decoding_regex (`str`, *optional*):
                Regular expression to guide the decoding process.
            generation_kwargs (`dict`, *optional*):
@@ -246,6 +251,7 @@ class VLLMClient:
                 "top_k": top_k,
                 "min_p": min_p,
                 "max_tokens": max_tokens,
+                "truncate_prompt_tokens": truncate_prompt_tokens,
                 "guided_decoding_regex": guided_decoding_regex,
                 "generation_kwargs": generation_kwargs or {},
             },
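The new `truncate_prompt_tokens` argument is simply forwarded in the request payload above. A hedged client-side usage sketch; it assumes a TRL vLLM server is already running (e.g. via `trl vllm-serve`), that the client defaults point at it, and that the first parameter is named `prompts` (not shown in this hunk):

```python
# Hedged sketch, not part of the diff: requesting prompt truncation from the client side.
from trl.extras.vllm_client import VLLMClient

client = VLLMClient()  # assumes the default server host/port
completion_ids = client.generate(
    prompts=["Summarize the GRPO algorithm in one sentence."],
    max_tokens=16,
    truncate_prompt_tokens=64,  # keep only the last 64 prompt tokens (left truncation)
)
print(completion_ids[0])  # token ids generated for the first prompt
```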
@@ -495,6 +495,7 @@ def main(script_args: ScriptArguments):
         top_k: int = -1
         min_p: float = 0.0
         max_tokens: int = 16
+        truncate_prompt_tokens: Optional[int] = None
         guided_decoding_regex: Optional[str] = None
         generation_kwargs: dict = field(default_factory=dict)
 
@@ -525,6 +526,9 @@ def main(script_args: ScriptArguments):
        - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
        - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
          completion.
+       - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
+         by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
+         truncation). If set to `None`, truncation is disabled.
        - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
          model will only generate tokens that match this regex pattern.
        - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
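On the server side the field is carried from the request model into the vLLM sampling parameters (next hunk). A hedged sketch of a raw HTTP call against the serve script; the endpoint path, host, port, and response layout are assumptions based on the script's defaults rather than anything shown in this diff:

```python
# Hedged sketch, not part of the diff: calling the serve script with prompt truncation enabled.
import requests

payload = {
    "prompts": ["The sky is"],
    "max_tokens": 16,
    "truncate_prompt_tokens": 32,  # keep only the last 32 prompt tokens; None disables truncation
}
response = requests.post("http://0.0.0.0:8000/generate/", json=payload)
print(response.json())
```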
@@ -575,6 +579,7 @@ def main(script_args: ScriptArguments):
            "top_k": request.top_k,
            "min_p": request.min_p,
            "max_tokens": request.max_tokens,
+           "truncate_prompt_tokens": request.truncate_prompt_tokens,
            "guided_decoding": guided_decoding,
            "logprobs": 0,
        }
@@ -14,7 +14,6 @@
 
 import inspect
 import os
-import re
 import textwrap
 from collections import defaultdict, deque
 from contextlib import nullcontext
@@ -71,7 +70,6 @@ from .utils import (
     shuffle_sequence_dict,
     split_pixel_values_by_grid,
     split_tensor_dict,
-    truncate_with_protected_tokens,
     unsplit_pixel_values_by_grid,
 )
 
@@ -275,7 +273,7 @@ class GRPOTrainer(BaseTrainer):
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
+            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -291,10 +289,6 @@ class GRPOTrainer(BaseTrainer):
         self.pad_token = tokenizer.pad_token
         self.pad_token_id = tokenizer.pad_token_id
         self.eos_token_id = tokenizer.eos_token_id
-        self.image_token = getattr(processing_class, "image_token", None)
-        self.image_token_id = getattr(processing_class, "image_token_id", None)
-        self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
-        self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)
 
         # Reward functions
         if not isinstance(reward_funcs, list):
@@ -1092,58 +1086,12 @@ class GRPOTrainer(BaseTrainer):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        prompt_inputs = self.processing_class(
-            text=prompts_text,
-            return_tensors="pt",
-            padding=True,
-            padding_side="left",
-            add_special_tokens=False,
-            **kwargs,
-        )
-        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-
-        if self.max_prompt_length is not None:
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
-
-            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
-            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
-            # tokens are needed for generation.
-            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
-            protected = [token for token in protected if token is not None]
-            prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]
-
-            prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-            )
-
-            # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
-            # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
-            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
-            # collapse them back into a single token string to match the original chat template in case it originally
-            # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
-            # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
-            # the vision_start_token_id (e.g. <start_of_image>).
-            if self.image_token is not None:
-                escaped_img_token = re.escape(self.image_token)
-                # Search for the image token in the chat template
-                if re.search(escaped_img_token, self.processing_class.chat_template):
-                    prompts_text = [
-                        re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
-                    ]
-                else:
-                    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
-                    if self.vision_end_token_id is not None:
-                        escaped_eoi_token = re.escape(
-                            self.processing_class.tokenizer.decode([self.vision_end_token_id])
-                        )
-                        prompts_text = [
-                            re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
-                        ]
-                    else:
-                        # If vision_end_token_id is None, just remove the image tokens
-                        prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
+        if images is not None:
+            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
 
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
@@ -1185,6 +1133,7 @@ class GRPOTrainer(BaseTrainer):
                     top_k=-1 if self.top_k is None else self.top_k,
                     min_p=0.0 if self.min_p is None else self.min_p,
                     max_tokens=self.max_completion_length,
+                    truncate_prompt_tokens=self.max_prompt_length,
                     guided_decoding_regex=self.guided_decoding_regex,
                     generation_kwargs=self.args.generation_kwargs,
                 )
@@ -1223,6 +1172,7 @@ class GRPOTrainer(BaseTrainer):
                 "top_k": -1 if self.top_k is None else self.top_k,
                 "min_p": 0.0 if self.min_p is None else self.min_p,
                 "max_tokens": self.max_completion_length,
+                "truncate_prompt_tokens": self.max_prompt_length,
                 "guided_decoding": guided_decoding,
                 "logprobs": 0,  # only return the logprob of the generated token
             }
@@ -1319,7 +1269,17 @@ class GRPOTrainer(BaseTrainer):
 
         else:
             # Regular generation path
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            generate_inputs = self.processing_class(
+                text=prompts_text,
+                return_tensors="pt",
+                padding=True,
+                padding_side="left",
+                max_length=self.max_prompt_length,
+                truncation=True,
+                add_special_tokens=False,
+                **kwargs,
+            )
+            generate_inputs = super()._prepare_inputs(generate_inputs)
 
             with (
                 profiling_context(self, "transformers.generate"),
@@ -1330,15 +1290,11 @@ class GRPOTrainer(BaseTrainer):
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
             ):
                 prompt_completion_ids = unwrapped_model.generate(
-                    input_ids=prompt_ids,
-                    attention_mask=prompt_mask,
-                    **forward_kwargs,
-                    generation_config=self.generation_config,
-                    disable_compile=True,
+                    **generate_inputs, generation_config=self.generation_config, disable_compile=True
                 )
             # Compute prompt length and extract completion ids
+            prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
             prompt_length = prompt_ids.size(1)
-            prompt_ids = prompt_completion_ids[:, :prompt_length]
             completion_ids = prompt_completion_ids[:, prompt_length:]
 
             # Mask everything after the first EOS token
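Outside the trainer, the new regular-generation path boils down to: tokenize with left-side truncation capped at `max_prompt_length`, call `generate`, and slice completions off at the now-fixed prompt length. A standalone sketch under assumed model name and lengths (not taken from the diff):

```python
# Hedged sketch, not TRL code: the truncate-then-generate flow this hunk switches to.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative assumption
tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer(
    ["A long prompt whose oldest tokens we are happy to drop ..."],
    return_tensors="pt",
    padding=True,
    padding_side="left",
    max_length=32,  # stands in for max_prompt_length
    truncation=True,  # left truncation happens here, no custom helper needed
    add_special_tokens=False,
)
with torch.no_grad():
    prompt_completion_ids = model.generate(**inputs, max_new_tokens=8)

prompt_length = inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]  # everything after the prompt
print(completion_ids)
```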
@@ -14,7 +14,6 @@
 
 import inspect
 import os
-import re
 import textwrap
 import warnings
 from collections import defaultdict, deque
@@ -71,7 +70,6 @@ from .utils import (
     shuffle_sequence_dict,
     split_pixel_values_by_grid,
     split_tensor_dict,
-    truncate_with_protected_tokens,
     unsplit_pixel_values_by_grid,
 )
 
@@ -394,7 +392,7 @@ class RLOOTrainer(BaseTrainer):
 
         # Processing class
        if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
+            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -410,10 +408,6 @@ class RLOOTrainer(BaseTrainer):
         self.pad_token = tokenizer.pad_token
         self.pad_token_id = tokenizer.pad_token_id
         self.eos_token_id = tokenizer.eos_token_id
-        self.image_token = getattr(processing_class, "image_token", None)
-        self.image_token_id = getattr(processing_class, "image_token_id", None)
-        self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
-        self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)
 
         # Reward functions
         if not isinstance(reward_funcs, list):
@@ -1088,58 +1082,12 @@ class RLOOTrainer(BaseTrainer):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        prompt_inputs = self.processing_class(
-            text=prompts_text,
-            return_tensors="pt",
-            padding=True,
-            padding_side="left",
-            add_special_tokens=False,
-            **kwargs,
-        )
-        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-
-        if self.max_prompt_length is not None:
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
-
-            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
-            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
-            # tokens are needed for generation.
-            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
-            protected = [token for token in protected if token is not None]
-            prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]
-
-            prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-            )
-
-            # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
-            # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
-            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
-            # collapse them back into a single token string to match the original chat template in case it originally
-            # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
-            # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
-            # the vision_start_token_id (e.g. <start_of_image>).
-            if self.image_token is not None:
-                escaped_img_token = re.escape(self.image_token)
-                # Search for the image token in the chat template
-                if re.search(escaped_img_token, self.processing_class.chat_template):
-                    prompts_text = [
-                        re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
-                    ]
-                else:
-                    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
-                    if self.vision_end_token_id is not None:
-                        escaped_eoi_token = re.escape(
-                            self.processing_class.tokenizer.decode([self.vision_end_token_id])
-                        )
-                        prompts_text = [
-                            re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
-                        ]
-                    else:
-                        # If vision_end_token_id is None, just remove the image tokens
-                        prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
+        if images is not None:
+            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
 
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
@@ -1181,6 +1129,7 @@ class RLOOTrainer(BaseTrainer):
                     top_k=-1 if self.top_k is None else self.top_k,
                     min_p=0.0 if self.min_p is None else self.min_p,
                     max_tokens=self.max_completion_length,
+                    truncate_prompt_tokens=self.max_prompt_length,
                     guided_decoding_regex=self.guided_decoding_regex,
                     generation_kwargs=self.args.generation_kwargs,
                 )
@@ -1218,6 +1167,7 @@ class RLOOTrainer(BaseTrainer):
                 "top_k": -1 if self.top_k is None else self.top_k,
                 "min_p": 0.0 if self.min_p is None else self.min_p,
                 "max_tokens": self.max_completion_length,
+                "truncate_prompt_tokens": self.max_prompt_length,
                 "guided_decoding": guided_decoding,
             }
             if self.args.generation_kwargs is not None:
@@ -1305,7 +1255,17 @@ class RLOOTrainer(BaseTrainer):
 
         else:
             # Regular generation path
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            generate_inputs = self.processing_class(
+                text=prompts_text,
+                return_tensors="pt",
+                padding=True,
+                padding_side="left",
+                max_length=self.max_prompt_length,
+                truncation=True,
+                add_special_tokens=False,
+                **kwargs,
+            )
+            generate_inputs = super()._prepare_inputs(generate_inputs)
 
             with (
                 profiling_context(self, "transformers.generate"),
@@ -1316,15 +1276,11 @@ class RLOOTrainer(BaseTrainer):
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
             ):
                 prompt_completion_ids = unwrapped_model.generate(
-                    input_ids=prompt_ids,
-                    attention_mask=prompt_mask,
-                    **forward_kwargs,
-                    generation_config=self.generation_config,
-                    disable_compile=True,
+                    **generate_inputs, generation_config=self.generation_config, disable_compile=True
                 )
             # Compute prompt length and extract completion ids
+            prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
             prompt_length = prompt_ids.size(1)
-            prompt_ids = prompt_completion_ids[:, :prompt_length]
             completion_ids = prompt_completion_ids[:, prompt_length:]
 
             # Mask everything after the first EOS token
@@ -1925,47 +1925,6 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch
     return batch
 
 
-def truncate_with_protected_tokens(ids: list[int], target_length: int, protected_tokens: list[int]) -> list[int]:
-    """
-    Truncate list to target length while preserving protected tokens.
-
-    Args:
-        ids (`list[int]`):
-            Input sequence of token IDs.
-        target_length (`int`):
-            Desired length of the output sequence.
-        protected_tokens (`list[int]`):
-            List of token IDs that should be preserved in the output.
-
-    Returns:
-        `list[int]`: Truncated sequence.
-
-    Raises:
-        `ValueError`: If `len(protected_tokens ∩ seq) > target_length`.
-    """
-    protected_set = set(protected_tokens)
-
-    # Count protected tokens
-    num_protected = sum(1 for t in ids if t in protected_set)
-    if num_protected > target_length:
-        raise ValueError(
-            f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). "
-            f"Please increase target length to at least {num_protected} or disable truncation."
-        )
-    num_non_protected_needed = target_length - num_protected
-    result = []
-
-    # Iterate backward to select all protected tokens and rightmost non-protected tokens
-    for t in reversed(ids):
-        if t in protected_set:
-            result.append(t)
-        elif num_non_protected_needed > 0:
-            result.append(t)
-            num_non_protected_needed -= 1
-    # Reverse to restore original order
-    return result[::-1]
-
-
 TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
 
 
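For reference, the deleted helper kept protected token ids (such as the image and vision marker token ids the trainers passed in) while trimming non-protected tokens from the left. Its core selection logic is restated below as a standalone snippet and checked against two of the test cases removed above; the ValueError guard for too many protected tokens is omitted for brevity:

```python
# Standalone restatement of the removed helper's selection logic, checked against the deleted tests.
def truncate_with_protected_tokens(ids, target_length, protected_tokens):
    protected_set = set(protected_tokens)
    num_needed = target_length - sum(1 for t in ids if t in protected_set)  # budget for non-protected tokens
    result = []
    for t in reversed(ids):  # walk right-to-left: keep every protected token, then the rightmost others
        if t in protected_set:
            result.append(t)
        elif num_needed > 0:
            result.append(t)
            num_needed -= 1
    return result[::-1]  # restore the original relative order

assert truncate_with_protected_tokens([1, 2, 3, 4, 5], 3, [2, 3]) == [2, 3, 5]
assert truncate_with_protected_tokens([10, 2, 20, 3, 30, 40], 4, [2, 3]) == [2, 3, 30, 40]
```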