🧺 [3/N] Refactor _generate in GRPO/RLOO: Rely on generator for prompt truncation (#4153)

Quentin Gallouédec
2025-10-10 10:02:11 -05:00
committed by GitHub
parent 98488e0946
commit 0e57b4a9df
8 changed files with 55 additions and 334 deletions

View File

@@ -1471,47 +1471,6 @@ class TestGRPOTrainer(TrlTestCase):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
 
-    @require_vision
-    def test_training_vlm_and_prompt_truncation(self):
-        # If not handled properly, prompt truncation may truncate image token
-        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
-
-        def reward_func(completions, **kwargs):
-            """Reward function that rewards longer completions."""
-            return [float(len(completion[0]["content"])) for completion in completions]
-
-        training_args = GRPOConfig(
-            output_dir=self.tmp_dir,
-            learning_rate=0.1,  # increase the learning rate to speed up the test
-            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
-            num_generations=3,  # reduce the number of generations to reduce memory usage
-            max_completion_length=8,  # reduce the completion length to reduce memory usage
-            max_prompt_length=18,
-            report_to="none",
-        )
-        trainer = GRPOTrainer(
-            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
-            reward_funcs=reward_func,
-            args=training_args,
-            train_dataset=dataset,
-        )
-
-        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
-        trainer.train()
-
-        assert trainer.state.log_history[-1]["train_loss"] is not None
-
-        # Check that the params have changed
-        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
-        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
-        params_to_skip = ("model.visual.",)
-        for n, param in previous_trainable_params.items():
-            if n.startswith(params_to_skip):
-                continue
-            new_param = trainer.model.get_parameter(n)
-            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
     @parameterized.expand(
         [
             ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),

View File

@@ -1212,47 +1212,6 @@ class TestRLOOTrainer(TrlTestCase):
             elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
                 assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
 
-    @require_vision
-    def test_training_vlm_and_prompt_truncation(self):
-        # If not handled properly, prompt truncation may truncate image token
-        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
-
-        def reward_func(completions, **kwargs):
-            """Reward function that rewards longer completions."""
-            return [float(len(completion[0]["content"])) for completion in completions]
-
-        training_args = RLOOConfig(
-            output_dir=self.tmp_dir,
-            learning_rate=0.1,  # increase the learning rate to speed up the test
-            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
-            num_generations=3,  # reduce the number of generations to reduce memory usage
-            max_completion_length=8,  # reduce the completion length to reduce memory usage
-            max_prompt_length=18,
-            report_to="none",
-        )
-        trainer = RLOOTrainer(
-            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
-            reward_funcs=reward_func,
-            args=training_args,
-            train_dataset=dataset,
-        )
-
-        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
-        trainer.train()
-
-        assert trainer.state.log_history[-1]["train_loss"] is not None
-
-        # Check that the params have changed
-        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
-        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
-        params_to_skip = ("model.visual.",)
-        for n, param in previous_trainable_params.items():
-            if n.startswith(params_to_skip):
-                continue
-            new_param = trainer.model.get_parameter(n)
-            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
     @parameterized.expand(
         [
            ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),

View File

@@ -42,7 +42,6 @@ from trl.trainer.utils import (
     shuffle_sequence_dict,
     split_pixel_values_by_grid,
     split_tensor_dict,
-    truncate_with_protected_tokens,
     unsplit_pixel_values_by_grid,
 )
@@ -1009,84 +1008,6 @@ class TestSplitPixelValuesByGrid(TrlTestCase):
         assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))
 
 
-class TestTruncateWithProtectedTokens(TrlTestCase):
-    def test_basic_example(self):
-        """Test the basic example from the problem description."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = [2, 3]
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        expected_ids = [2, 3, 5]
-        assert new_ids == expected_ids
-
-    def test_no_truncation_needed(self):
-        """Test when target length equals current length."""
-        prompt_ids = [1, 2, 3]
-        protected_tokens = [2]
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        assert new_ids == prompt_ids
-
-    def test_no_protected_tokens(self):
-        """Test truncation with no protected tokens (normal right truncation)."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = []
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        expected_ids = [3, 4, 5]  # Last 3 tokens
-        assert new_ids == expected_ids
-
-    def test_all_tokens_protected(self):
-        """Test when all remaining tokens are protected."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = [3, 4, 5]
-        target_length = 3
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        expected_ids = [3, 4, 5]
-        assert new_ids == expected_ids
-
-    def test_too_many_protected_tokens(self):
-        """Test error when too many protected tokens for target length."""
-        prompt_ids = [1, 2, 3, 4, 5]
-        protected_tokens = [1, 2, 3, 4]
-        target_length = 3
-
-        with pytest.raises(ValueError):
-            truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-    def test_single_batch_single_token(self):
-        """Test edge case with single batch and single token."""
-        prompt_ids = [5]
-        protected_tokens = [5]
-        target_length = 1
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        assert new_ids == prompt_ids
-
-    def test_order_preservation(self):
-        """Test that relative order is preserved."""
-        prompt_ids = [10, 2, 20, 3, 30, 40]
-        protected_tokens = [2, 3]
-        target_length = 4
-
-        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
-
-        # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
-        # Order should be: 2, 3, 30, 40 (maintaining original relative positions)
-        expected_ids = [2, 3, 30, 40]
-        assert new_ids == expected_ids
-
-
 class TestUnsplitPixelValuesByGrid(TrlTestCase):
     def test_unsplit_correctly(self):
         pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]

View File

@@ -182,6 +182,7 @@ class VLLMClient:
         top_k: int = -1,
         min_p: float = 0.0,
         max_tokens: int = 16,
+        truncate_prompt_tokens: Optional[int] = None,
         guided_decoding_regex: Optional[str] = None,
         generation_kwargs: Optional[dict] = None,
     ) -> list[list[int]]:
@@ -207,6 +208,10 @@ class VLLMClient:
                 Minimum probability for sampling.
             max_tokens (`int`, *optional*, defaults to `16`):
                 Maximum number of tokens to generate for each prompt.
+            truncate_prompt_tokens (`int`, *optional*):
+                If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
+                only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
+                disabled.
             guided_decoding_regex (`str`, *optional*):
                 Regular expression to guide the decoding process.
             generation_kwargs (`dict`, *optional*):
@@ -246,6 +251,7 @@ class VLLMClient:
                 "top_k": top_k,
                 "min_p": min_p,
                 "max_tokens": max_tokens,
+                "truncate_prompt_tokens": truncate_prompt_tokens,
                 "guided_decoding_regex": guided_decoding_regex,
                 "generation_kwargs": generation_kwargs or {},
             },
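
A hedged usage sketch of the client-side API after this change. It assumes a TRL vLLM server is already running (for example via the trl vllm-serve CLI); the default host/port connection is an assumption:

    from trl.extras.vllm_client import VLLMClient

    client = VLLMClient()
    completion_ids = client.generate(
        prompts=["A very long prompt ..."],
        max_tokens=16,
        truncate_prompt_tokens=18,  # keep only the last 18 prompt tokens (left truncation)
    )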

View File

@@ -495,6 +495,7 @@ def main(script_args: ScriptArguments):
         top_k: int = -1
         min_p: float = 0.0
         max_tokens: int = 16
+        truncate_prompt_tokens: Optional[int] = None
         guided_decoding_regex: Optional[str] = None
         generation_kwargs: dict = field(default_factory=dict)
@@ -525,6 +526,9 @@ def main(script_args: ScriptArguments):
             - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
             - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
               completion.
+            - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
+              by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
+              truncation). If set to `None`, truncation is disabled.
             - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
               model will only generate tokens that match this regex pattern.
             - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
@@ -575,6 +579,7 @@ def main(script_args: ScriptArguments):
             "top_k": request.top_k,
             "min_p": request.min_p,
             "max_tokens": request.max_tokens,
+            "truncate_prompt_tokens": request.truncate_prompt_tokens,
             "guided_decoding": guided_decoding,
             "logprobs": 0,
         }
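
The new request field is forwarded straight into vLLM's sampling parameters. A hedged sketch of the equivalent standalone call, assuming vLLM's SamplingParams accepts truncate_prompt_tokens as the server-side dict above relies on:

    from vllm import SamplingParams

    sampling_params = SamplingParams(
        max_tokens=16,
        truncate_prompt_tokens=18,  # vLLM keeps only the last 18 prompt tokens
    )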

View File

@@ -14,7 +14,6 @@
 import inspect
 import os
-import re
 import textwrap
 from collections import defaultdict, deque
 from contextlib import nullcontext
@@ -71,7 +70,6 @@ from .utils import (
     shuffle_sequence_dict,
     split_pixel_values_by_grid,
     split_tensor_dict,
-    truncate_with_protected_tokens,
     unsplit_pixel_values_by_grid,
 )
@@ -275,7 +273,7 @@ class GRPOTrainer(BaseTrainer):
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
+            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -291,10 +289,6 @@ class GRPOTrainer(BaseTrainer):
         self.pad_token = tokenizer.pad_token
         self.pad_token_id = tokenizer.pad_token_id
         self.eos_token_id = tokenizer.eos_token_id
-        self.image_token = getattr(processing_class, "image_token", None)
-        self.image_token_id = getattr(processing_class, "image_token_id", None)
-        self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
-        self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)
 
         # Reward functions
         if not isinstance(reward_funcs, list):
@@ -1092,58 +1086,12 @@ class GRPOTrainer(BaseTrainer):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        prompt_inputs = self.processing_class(
-            text=prompts_text,
-            return_tensors="pt",
-            padding=True,
-            padding_side="left",
-            add_special_tokens=False,
-            **kwargs,
-        )
-        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-
-        if self.max_prompt_length is not None:
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
-            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
-            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
-            # tokens are needed for generation.
-            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
-            protected = [token for token in protected if token is not None]
-            prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]
-
-            prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-            )
-
-            # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
-            # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
-            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
-            # collapse them back into a single token string to match the original chat template in case it originally
-            # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
-            # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
-            # the vision_start_token_id (e.g. <start_of_image>).
-            if self.image_token is not None:
-                escaped_img_token = re.escape(self.image_token)
-                # Search for the image token in the chat template
-                if re.search(escaped_img_token, self.processing_class.chat_template):
-                    prompts_text = [
-                        re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
-                    ]
-                else:
-                    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
-                    if self.vision_end_token_id is not None:
-                        escaped_eoi_token = re.escape(
-                            self.processing_class.tokenizer.decode([self.vision_end_token_id])
-                        )
-                        prompts_text = [
-                            re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
-                        ]
-                    else:
-                        # If vision_end_token_id is None, just remove the image tokens
-                        prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
+        if images is not None:
+            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
 
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
@@ -1185,6 +1133,7 @@ class GRPOTrainer(BaseTrainer):
                         top_k=-1 if self.top_k is None else self.top_k,
                         min_p=0.0 if self.min_p is None else self.min_p,
                         max_tokens=self.max_completion_length,
+                        truncate_prompt_tokens=self.max_prompt_length,
                         guided_decoding_regex=self.guided_decoding_regex,
                         generation_kwargs=self.args.generation_kwargs,
                     )
@@ -1223,6 +1172,7 @@ class GRPOTrainer(BaseTrainer):
                 "top_k": -1 if self.top_k is None else self.top_k,
                 "min_p": 0.0 if self.min_p is None else self.min_p,
                 "max_tokens": self.max_completion_length,
+                "truncate_prompt_tokens": self.max_prompt_length,
                 "guided_decoding": guided_decoding,
                 "logprobs": 0,  # only return the logprob of the generated token
             }
@@ -1319,7 +1269,17 @@ class GRPOTrainer(BaseTrainer):
         else:
             # Regular generation path
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            generate_inputs = self.processing_class(
+                text=prompts_text,
+                return_tensors="pt",
+                padding=True,
+                padding_side="left",
+                max_length=self.max_prompt_length,
+                truncation=True,
+                add_special_tokens=False,
+                **kwargs,
+            )
+            generate_inputs = super()._prepare_inputs(generate_inputs)
 
             with (
                 profiling_context(self, "transformers.generate"),
@@ -1330,15 +1290,11 @@ class GRPOTrainer(BaseTrainer):
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
             ):
                 prompt_completion_ids = unwrapped_model.generate(
-                    input_ids=prompt_ids,
-                    attention_mask=prompt_mask,
-                    **forward_kwargs,
-                    generation_config=self.generation_config,
-                    disable_compile=True,
+                    **generate_inputs, generation_config=self.generation_config, disable_compile=True
                 )
             # Compute prompt length and extract completion ids
+            prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
             prompt_length = prompt_ids.size(1)
-            prompt_ids = prompt_completion_ids[:, :prompt_length]
             completion_ids = prompt_completion_ids[:, prompt_length:]
 
         # Mask everything after the first EOS token
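
A hedged, self-contained sketch of the tokenizer-side truncation this path now relies on: loading the processing class with truncation_side="left" and passing max_length/truncation at call time keeps only the last max_prompt_length tokens. The model name below is only an example:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", truncation_side="left")
    enc = tok(
        ["word " * 100],        # a prompt much longer than max_length
        return_tensors="pt",
        padding=True,
        padding_side="left",
        max_length=18,
        truncation=True,        # drops tokens from the left, keeping the last 18
        add_special_tokens=False,
    )
    assert enc["input_ids"].shape[1] == 18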

View File

@@ -14,7 +14,6 @@
 import inspect
 import os
-import re
 import textwrap
 import warnings
 from collections import defaultdict, deque
@@ -71,7 +70,6 @@ from .utils import (
     shuffle_sequence_dict,
     split_pixel_values_by_grid,
     split_tensor_dict,
-    truncate_with_protected_tokens,
     unsplit_pixel_values_by_grid,
 )
@@ -394,7 +392,7 @@ class RLOOTrainer(BaseTrainer):
         # Processing class
         if processing_class is None:
-            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
+            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
 
         # Handle pad token for processors or tokenizers
         if isinstance(processing_class, ProcessorMixin):
@@ -410,10 +408,6 @@ class RLOOTrainer(BaseTrainer):
         self.pad_token = tokenizer.pad_token
         self.pad_token_id = tokenizer.pad_token_id
         self.eos_token_id = tokenizer.eos_token_id
-        self.image_token = getattr(processing_class, "image_token", None)
-        self.image_token_id = getattr(processing_class, "image_token_id", None)
-        self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
-        self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)
 
         # Reward functions
         if not isinstance(reward_funcs, list):
@@ -1088,58 +1082,12 @@ class RLOOTrainer(BaseTrainer):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        prompt_inputs = self.processing_class(
-            text=prompts_text,
-            return_tensors="pt",
-            padding=True,
-            padding_side="left",
-            add_special_tokens=False,
-            **kwargs,
-        )
-        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-
-        if self.max_prompt_length is not None:
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
-            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
-            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
-            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
-            # tokens are needed for generation.
-            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
-            protected = [token for token in protected if token is not None]
-            prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]
-
-            prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
-            )
-
-            # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
-            # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
-            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
-            # collapse them back into a single token string to match the original chat template in case it originally
-            # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
-            # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
-            # the vision_start_token_id (e.g. <start_of_image>).
-            if self.image_token is not None:
-                escaped_img_token = re.escape(self.image_token)
-                # Search for the image token in the chat template
-                if re.search(escaped_img_token, self.processing_class.chat_template):
-                    prompts_text = [
-                        re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
-                    ]
-                else:
-                    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
-                    if self.vision_end_token_id is not None:
-                        escaped_eoi_token = re.escape(
-                            self.processing_class.tokenizer.decode([self.vision_end_token_id])
-                        )
-                        prompts_text = [
-                            re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
-                        ]
-                    else:
-                        # If vision_end_token_id is None, just remove the image tokens
-                        prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
+        if images is not None:
+            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
 
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
@@ -1181,6 +1129,7 @@ class RLOOTrainer(BaseTrainer):
                         top_k=-1 if self.top_k is None else self.top_k,
                         min_p=0.0 if self.min_p is None else self.min_p,
                         max_tokens=self.max_completion_length,
+                        truncate_prompt_tokens=self.max_prompt_length,
                         guided_decoding_regex=self.guided_decoding_regex,
                         generation_kwargs=self.args.generation_kwargs,
                     )
@@ -1218,6 +1167,7 @@ class RLOOTrainer(BaseTrainer):
                 "top_k": -1 if self.top_k is None else self.top_k,
                 "min_p": 0.0 if self.min_p is None else self.min_p,
                 "max_tokens": self.max_completion_length,
+                "truncate_prompt_tokens": self.max_prompt_length,
                 "guided_decoding": guided_decoding,
             }
             if self.args.generation_kwargs is not None:
@@ -1305,7 +1255,17 @@ class RLOOTrainer(BaseTrainer):
         else:
             # Regular generation path
-            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+            generate_inputs = self.processing_class(
+                text=prompts_text,
+                return_tensors="pt",
+                padding=True,
+                padding_side="left",
+                max_length=self.max_prompt_length,
+                truncation=True,
+                add_special_tokens=False,
+                **kwargs,
+            )
+            generate_inputs = super()._prepare_inputs(generate_inputs)
 
             with (
                 profiling_context(self, "transformers.generate"),
@@ -1316,15 +1276,11 @@ class RLOOTrainer(BaseTrainer):
                 FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
             ):
                 prompt_completion_ids = unwrapped_model.generate(
-                    input_ids=prompt_ids,
-                    attention_mask=prompt_mask,
-                    **forward_kwargs,
-                    generation_config=self.generation_config,
-                    disable_compile=True,
+                    **generate_inputs, generation_config=self.generation_config, disable_compile=True
                 )
             # Compute prompt length and extract completion ids
+            prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
            prompt_length = prompt_ids.size(1)
-            prompt_ids = prompt_completion_ids[:, :prompt_length]
             completion_ids = prompt_completion_ids[:, prompt_length:]
 
         # Mask everything after the first EOS token

View File

@@ -1925,47 +1925,6 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch
     return batch
 
 
-def truncate_with_protected_tokens(ids: list[int], target_length: int, protected_tokens: list[int]) -> list[int]:
-    """
-    Truncate list to target length while preserving protected tokens.
-
-    Args:
-        ids (`list[int]`):
-            Input sequence of token IDs.
-        target_length (`int`):
-            Desired length of the output sequence.
-        protected_tokens (`list[int]`):
-            List of token IDs that should be preserved in the output.
-
-    Returns:
-        `list[int]`: Truncated sequence.
-
-    Raises:
-        `ValueError`: If `len(protected_tokens ∩ seq) > target_length`.
-    """
-    protected_set = set(protected_tokens)
-
-    # Count protected tokens
-    num_protected = sum(1 for t in ids if t in protected_set)
-
-    if num_protected > target_length:
-        raise ValueError(
-            f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). "
-            f"Please increase target length to at least {num_protected} or disable truncation."
-        )
-
-    num_non_protected_needed = target_length - num_protected
-
-    result = []
-    # Iterate backward to select all protected tokens and rightmost non-protected tokens
-    for t in reversed(ids):
-        if t in protected_set:
-            result.append(t)
-        elif num_non_protected_needed > 0:
-            result.append(t)
-            num_non_protected_needed -= 1
-
-    # Reverse to restore original order
-    return result[::-1]
-
-
 TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
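
For reference, a minimal standalone illustration of the contract the removed helper satisfied (logic and expected values taken from the deleted function and its deleted tests above), contrasted with the plain left truncation the generator now applies:

    def truncate_keep_protected(ids, target_length, protected):
        """Keep all protected tokens plus the rightmost non-protected tokens, preserving order."""
        protected = set(protected)
        non_protected_needed = target_length - sum(t in protected for t in ids)
        kept = []
        for t in reversed(ids):          # walk right-to-left
            if t in protected:
                kept.append(t)           # protected tokens are always kept
            elif non_protected_needed > 0:
                kept.append(t)           # keep the rightmost non-protected tokens
                non_protected_needed -= 1
        return kept[::-1]                # restore the original relative order

    assert truncate_keep_protected([1, 2, 3, 4, 5], 3, [2, 3]) == [2, 3, 5]
    assert truncate_keep_protected([10, 2, 20, 3, 30, 40], 4, [2, 3]) == [2, 3, 30, 40]
    # Plain left truncation, by contrast, simply keeps the last target_length tokens:
    assert [1, 2, 3, 4, 5][-3:] == [3, 4, 5]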