Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 10:03:51 +08:00)
🧺 [3/N] Refactor _generate in GRPO/RLOO: Rely on generator for prompt truncation (#4153)
committed by GitHub
parent 98488e0946
commit 0e57b4a9df
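In short, the refactor drops the manual `truncate_with_protected_tokens` pass in GRPO/RLOO and lets the generator handle prompt truncation instead: the processor is created with `truncation_side="left"`, the transformers generation path tokenizes with `truncation=True` and `max_length=max_prompt_length`, and the vLLM paths receive `truncate_prompt_tokens=max_prompt_length`. A minimal sketch of the tokenizer-side behavior follows; the checkpoint name, prompt, and length are placeholders, not values from this commit.

from transformers import AutoTokenizer

# Placeholder checkpoint, used only to illustrate left-sided prompt truncation.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", truncation_side="left")

prompt = "System: be concise. User: please summarize the following very long document ..."
inputs = tokenizer(
    prompt,
    return_tensors="pt",
    truncation=True,
    max_length=18,  # plays the role of max_prompt_length
    add_special_tokens=False,
)
# Only the last 18 tokens of the prompt are kept, so the most recent context survives.
print(inputs["input_ids"].shape)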
@@ -1471,47 +1471,6 @@ class TestGRPOTrainer(TrlTestCase):
            new_param = trainer.model.get_parameter(n)
            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

    @require_vision
    def test_training_vlm_and_prompt_truncation(self):
        # If not handled properly, prompt truncation may truncate image token
        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

        def reward_func(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion[0]["content"])) for completion in completions]

        training_args = GRPOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,  # increase the learning rate to speed up the test
            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
            num_generations=3,  # reduce the number of generations to reduce memory usage
            max_completion_length=8,  # reduce the completion length to reduce memory usage
            max_prompt_length=18,
            report_to="none",
        )
        trainer = GRPOTrainer(
            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
            reward_funcs=reward_func,
            args=training_args,
            train_dataset=dataset,
        )

        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        trainer.train()

        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check that the params have changed
        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
        params_to_skip = ("model.visual.",)
        for n, param in previous_trainable_params.items():
            if n.startswith(params_to_skip):
                continue
            new_param = trainer.model.get_parameter(n)
            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

    @parameterized.expand(
        [
            ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
@@ -1212,47 +1212,6 @@ class TestRLOOTrainer(TrlTestCase):
            elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."

    @require_vision
    def test_training_vlm_and_prompt_truncation(self):
        # If not handled properly, prompt truncation may truncate image token
        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

        def reward_func(completions, **kwargs):
            """Reward function that rewards longer completions."""
            return [float(len(completion[0]["content"])) for completion in completions]

        training_args = RLOOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,  # increase the learning rate to speed up the test
            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
            num_generations=3,  # reduce the number of generations to reduce memory usage
            max_completion_length=8,  # reduce the completion length to reduce memory usage
            max_prompt_length=18,
            report_to="none",
        )
        trainer = RLOOTrainer(
            model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
            reward_funcs=reward_func,
            args=training_args,
            train_dataset=dataset,
        )

        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        trainer.train()

        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check that the params have changed
        # Because of the way the tiny models are initialized, the gradient does not flow properly through the
        # vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
        params_to_skip = ("model.visual.",)
        for n, param in previous_trainable_params.items():
            if n.startswith(params_to_skip):
                continue
            new_param = trainer.model.get_parameter(n)
            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

    @parameterized.expand(
        [
            ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
@@ -42,7 +42,6 @@ from trl.trainer.utils import (
    shuffle_sequence_dict,
    split_pixel_values_by_grid,
    split_tensor_dict,
    truncate_with_protected_tokens,
    unsplit_pixel_values_by_grid,
)

@@ -1009,84 +1008,6 @@ class TestSplitPixelValuesByGrid(TrlTestCase):
        assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))


class TestTruncateWithProtectedTokens(TrlTestCase):
    def test_basic_example(self):
        """Test the basic example from the problem description."""
        prompt_ids = [1, 2, 3, 4, 5]
        protected_tokens = [2, 3]
        target_length = 3

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        expected_ids = [2, 3, 5]
        assert new_ids == expected_ids

    def test_no_truncation_needed(self):
        """Test when target length equals current length."""
        prompt_ids = [1, 2, 3]
        protected_tokens = [2]
        target_length = 3

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        assert new_ids == prompt_ids

    def test_no_protected_tokens(self):
        """Test truncation with no protected tokens (plain truncation that keeps the last tokens)."""
        prompt_ids = [1, 2, 3, 4, 5]
        protected_tokens = []
        target_length = 3

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        expected_ids = [3, 4, 5]  # Last 3 tokens
        assert new_ids == expected_ids

    def test_all_tokens_protected(self):
        """Test when all remaining tokens are protected."""
        prompt_ids = [1, 2, 3, 4, 5]
        protected_tokens = [3, 4, 5]
        target_length = 3

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        expected_ids = [3, 4, 5]
        assert new_ids == expected_ids

    def test_too_many_protected_tokens(self):
        """Test error when too many protected tokens for target length."""
        prompt_ids = [1, 2, 3, 4, 5]
        protected_tokens = [1, 2, 3, 4]
        target_length = 3

        with pytest.raises(ValueError):
            truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

    def test_single_batch_single_token(self):
        """Test edge case with single batch and single token."""
        prompt_ids = [5]
        protected_tokens = [5]
        target_length = 1

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        assert new_ids == prompt_ids

    def test_order_preservation(self):
        """Test that relative order is preserved."""
        prompt_ids = [10, 2, 20, 3, 30, 40]
        protected_tokens = [2, 3]
        target_length = 4

        new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)

        # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
        # Order should be: 2, 3, 30, 40 (maintaining original relative positions)
        expected_ids = [2, 3, 30, 40]

        assert new_ids == expected_ids


class TestUnsplitPixelValuesByGrid(TrlTestCase):
    def test_unsplit_correctly(self):
        pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
@@ -182,6 +182,7 @@ class VLLMClient:
        top_k: int = -1,
        min_p: float = 0.0,
        max_tokens: int = 16,
        truncate_prompt_tokens: Optional[int] = None,
        guided_decoding_regex: Optional[str] = None,
        generation_kwargs: Optional[dict] = None,
    ) -> list[list[int]]:
@@ -207,6 +208,10 @@ class VLLMClient:
                Minimum probability for sampling.
            max_tokens (`int`, *optional*, defaults to `16`):
                Maximum number of tokens to generate for each prompt.
            truncate_prompt_tokens (`int`, *optional*):
                If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
                only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
                disabled.
            guided_decoding_regex (`str`, *optional*):
                Regular expression to guide the decoding process.
            generation_kwargs (`dict`, *optional*):
@@ -246,6 +251,7 @@ class VLLMClient:
                "top_k": top_k,
                "min_p": min_p,
                "max_tokens": max_tokens,
                "truncate_prompt_tokens": truncate_prompt_tokens,
                "guided_decoding_regex": guided_decoding_regex,
                "generation_kwargs": generation_kwargs or {},
            },
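For reference, a minimal pure-Python sketch of the left-truncation semantics described for `truncate_prompt_tokens` above. The `-1` case, which defers to the model's maximum length, is not modeled here, and `left_truncate` is a hypothetical helper used only for illustration.

def left_truncate(ids, truncate_prompt_tokens):
    # None disables truncation; an integer k keeps only the last k tokens of the prompt.
    if truncate_prompt_tokens is None:
        return ids
    return ids[-truncate_prompt_tokens:]

assert left_truncate([1, 2, 3, 4, 5], None) == [1, 2, 3, 4, 5]
assert left_truncate([1, 2, 3, 4, 5], 3) == [3, 4, 5]  # left truncation keeps the last 3 tokens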
@@ -495,6 +495,7 @@ def main(script_args: ScriptArguments):
        top_k: int = -1
        min_p: float = 0.0
        max_tokens: int = 16
        truncate_prompt_tokens: Optional[int] = None
        guided_decoding_regex: Optional[str] = None
        generation_kwargs: dict = field(default_factory=dict)

@@ -525,6 +526,9 @@ def main(script_args: ScriptArguments):
            - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
            - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
              completion.
            - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
              by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
              truncation). If set to `None`, truncation is disabled.
            - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
              model will only generate tokens that match this regex pattern.
            - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
@@ -575,6 +579,7 @@ def main(script_args: ScriptArguments):
            "top_k": request.top_k,
            "min_p": request.min_p,
            "max_tokens": request.max_tokens,
            "truncate_prompt_tokens": request.truncate_prompt_tokens,
            "guided_decoding": guided_decoding,
            "logprobs": 0,
        }
@@ -14,7 +14,6 @@

import inspect
import os
import re
import textwrap
from collections import defaultdict, deque
from contextlib import nullcontext

@@ -71,7 +70,6 @@ from .utils import (
    shuffle_sequence_dict,
    split_pixel_values_by_grid,
    split_tensor_dict,
    truncate_with_protected_tokens,
    unsplit_pixel_values_by_grid,
)

@@ -275,7 +273,7 @@ class GRPOTrainer(BaseTrainer):

        # Processing class
        if processing_class is None:
            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")

        # Handle pad token for processors or tokenizers
        if isinstance(processing_class, ProcessorMixin):
@@ -291,10 +289,6 @@ class GRPOTrainer(BaseTrainer):
        self.pad_token = tokenizer.pad_token
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.image_token = getattr(processing_class, "image_token", None)
        self.image_token_id = getattr(processing_class, "image_token_id", None)
        self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
        self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)

        # Reward functions
        if not isinstance(reward_funcs, list):
@@ -1092,58 +1086,12 @@ class GRPOTrainer(BaseTrainer):
            maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
        ]

        prompt_inputs = self.processing_class(
            text=prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
            **kwargs,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}

        if self.max_prompt_length is not None:
            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]

            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
            # tokens are needed for generation.
            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
            protected = [token for token in protected if token is not None]
            prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]

            prompts_text = self.processing_class.batch_decode(
                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
            )

            # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
            # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
            # collapse them back into a single token string to match the original chat template in case it originally
            # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
            # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
            # the vision_start_token_id (e.g. <start_of_image>).
            if self.image_token is not None:
                escaped_img_token = re.escape(self.image_token)
                # Search for the image token in the chat template
                if re.search(escaped_img_token, self.processing_class.chat_template):
                    prompts_text = [
                        re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
                    ]
                else:
                    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
                    if self.vision_end_token_id is not None:
                        escaped_eoi_token = re.escape(
                            self.processing_class.tokenizer.decode([self.vision_end_token_id])
                        )
                        prompts_text = [
                            re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
                        ]
                    else:
                        # If vision_end_token_id is None, just remove the image tokens
                        prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
        if images is not None:
            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
            prompt_inputs = super()._prepare_inputs(prompt_inputs)
            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
        else:
            forward_kwargs = {}

        # Generate completions using either vLLM or regular generation
        if self.use_vllm:
@@ -1185,6 +1133,7 @@ class GRPOTrainer(BaseTrainer):
                    top_k=-1 if self.top_k is None else self.top_k,
                    min_p=0.0 if self.min_p is None else self.min_p,
                    max_tokens=self.max_completion_length,
                    truncate_prompt_tokens=self.max_prompt_length,
                    guided_decoding_regex=self.guided_decoding_regex,
                    generation_kwargs=self.args.generation_kwargs,
                )
@@ -1223,6 +1172,7 @@ class GRPOTrainer(BaseTrainer):
                "top_k": -1 if self.top_k is None else self.top_k,
                "min_p": 0.0 if self.min_p is None else self.min_p,
                "max_tokens": self.max_completion_length,
                "truncate_prompt_tokens": self.max_prompt_length,
                "guided_decoding": guided_decoding,
                "logprobs": 0,  # only return the logprob of the generated token
            }
@@ -1319,7 +1269,17 @@ class GRPOTrainer(BaseTrainer):

        else:
            # Regular generation path
            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
            generate_inputs = self.processing_class(
                text=prompts_text,
                return_tensors="pt",
                padding=True,
                padding_side="left",
                max_length=self.max_prompt_length,
                truncation=True,
                add_special_tokens=False,
                **kwargs,
            )
            generate_inputs = super()._prepare_inputs(generate_inputs)

            with (
                profiling_context(self, "transformers.generate"),
@@ -1330,15 +1290,11 @@ class GRPOTrainer(BaseTrainer):
                FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
            ):
                prompt_completion_ids = unwrapped_model.generate(
                    input_ids=prompt_ids,
                    attention_mask=prompt_mask,
                    **forward_kwargs,
                    generation_config=self.generation_config,
                    disable_compile=True,
                    **generate_inputs, generation_config=self.generation_config, disable_compile=True
                )
            # Compute prompt length and extract completion ids
            prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

            # Mask everything after the first EOS token
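To make the new regular-generation path concrete, here is a minimal standalone sketch of the same flow outside the trainer. The checkpoint name, prompts, and lengths are placeholders, and trainer-specific pieces such as `_prepare_inputs`, FSDP, profiling, and vision inputs are omitted: tokenize with left-sided truncation to the prompt budget, call `generate`, then slice the completion off by prompt length.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, truncation_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # left padding needs a pad token
model = AutoModelForCausalLM.from_pretrained(model_name)

prompts_text = ["Tell me a short story about a robot.", "List three uses of a paperclip."]
max_prompt_length = 18

generate_inputs = tokenizer(
    prompts_text,
    return_tensors="pt",
    padding=True,
    padding_side="left",
    max_length=max_prompt_length,
    truncation=True,
    add_special_tokens=False,
)

with torch.no_grad():
    prompt_completion_ids = model.generate(**generate_inputs, max_new_tokens=8, do_sample=False)

# The generator returns prompt + completion; slice off the prompt part by its (padded) length.
prompt_length = generate_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]
print(tokenizer.batch_decode(completion_ids, skip_special_tokens=True))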
@@ -14,7 +14,6 @@

import inspect
import os
import re
import textwrap
import warnings
from collections import defaultdict, deque

@@ -71,7 +70,6 @@ from .utils import (
    shuffle_sequence_dict,
    split_pixel_values_by_grid,
    split_tensor_dict,
    truncate_with_protected_tokens,
    unsplit_pixel_values_by_grid,
)

@@ -394,7 +392,7 @@ class RLOOTrainer(BaseTrainer):

        # Processing class
        if processing_class is None:
            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
            processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")

        # Handle pad token for processors or tokenizers
        if isinstance(processing_class, ProcessorMixin):
@@ -410,10 +408,6 @@ class RLOOTrainer(BaseTrainer):
        self.pad_token = tokenizer.pad_token
        self.pad_token_id = tokenizer.pad_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.image_token = getattr(processing_class, "image_token", None)
        self.image_token_id = getattr(processing_class, "image_token_id", None)
        self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
        self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)

        # Reward functions
        if not isinstance(reward_funcs, list):
@@ -1088,58 +1082,12 @@ class RLOOTrainer(BaseTrainer):
            maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
        ]

        prompt_inputs = self.processing_class(
            text=prompts_text,
            return_tensors="pt",
            padding=True,
            padding_side="left",
            add_special_tokens=False,
            **kwargs,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}

        if self.max_prompt_length is not None:
            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]

            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
            # tokens are needed for generation.
            protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
            protected = [token for token in protected if token is not None]
            prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]

            prompts_text = self.processing_class.batch_decode(
                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
            )

            # The chat template sometimes inserts a single image token into the prompt text. However, when this text is
            # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
            # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
            # collapse them back into a single token string to match the original chat template in case it originally
            # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
            # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
            # the vision_start_token_id (e.g. <start_of_image>).
            if self.image_token is not None:
                escaped_img_token = re.escape(self.image_token)
                # Search for the image token in the chat template
                if re.search(escaped_img_token, self.processing_class.chat_template):
                    prompts_text = [
                        re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
                    ]
                else:
                    # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
                    if self.vision_end_token_id is not None:
                        escaped_eoi_token = re.escape(
                            self.processing_class.tokenizer.decode([self.vision_end_token_id])
                        )
                        prompts_text = [
                            re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
                        ]
                    else:
                        # If vision_end_token_id is None, just remove the image tokens
                        prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
        if images is not None:
            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
            prompt_inputs = super()._prepare_inputs(prompt_inputs)
            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
        else:
            forward_kwargs = {}

        # Generate completions using either vLLM or regular generation
        if self.use_vllm:
@@ -1181,6 +1129,7 @@ class RLOOTrainer(BaseTrainer):
                    top_k=-1 if self.top_k is None else self.top_k,
                    min_p=0.0 if self.min_p is None else self.min_p,
                    max_tokens=self.max_completion_length,
                    truncate_prompt_tokens=self.max_prompt_length,
                    guided_decoding_regex=self.guided_decoding_regex,
                    generation_kwargs=self.args.generation_kwargs,
                )
@@ -1218,6 +1167,7 @@ class RLOOTrainer(BaseTrainer):
                "top_k": -1 if self.top_k is None else self.top_k,
                "min_p": 0.0 if self.min_p is None else self.min_p,
                "max_tokens": self.max_completion_length,
                "truncate_prompt_tokens": self.max_prompt_length,
                "guided_decoding": guided_decoding,
            }
            if self.args.generation_kwargs is not None:
@@ -1305,7 +1255,17 @@ class RLOOTrainer(BaseTrainer):

        else:
            # Regular generation path
            prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
            generate_inputs = self.processing_class(
                text=prompts_text,
                return_tensors="pt",
                padding=True,
                padding_side="left",
                max_length=self.max_prompt_length,
                truncation=True,
                add_special_tokens=False,
                **kwargs,
            )
            generate_inputs = super()._prepare_inputs(generate_inputs)

            with (
                profiling_context(self, "transformers.generate"),
@@ -1316,15 +1276,11 @@ class RLOOTrainer(BaseTrainer):
                FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
            ):
                prompt_completion_ids = unwrapped_model.generate(
                    input_ids=prompt_ids,
                    attention_mask=prompt_mask,
                    **forward_kwargs,
                    generation_config=self.generation_config,
                    disable_compile=True,
                    **generate_inputs, generation_config=self.generation_config, disable_compile=True
                )
            # Compute prompt length and extract completion ids
            prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
            prompt_length = prompt_ids.size(1)
            prompt_ids = prompt_completion_ids[:, :prompt_length]
            completion_ids = prompt_completion_ids[:, prompt_length:]

            # Mask everything after the first EOS token
@@ -1925,47 +1925,6 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch
    return batch


def truncate_with_protected_tokens(ids: list[int], target_length: int, protected_tokens: list[int]) -> list[int]:
    """
    Truncate list to target length while preserving protected tokens.

    Args:
        ids (`list[int]`):
            Input sequence of token IDs.
        target_length (`int`):
            Desired length of the output sequence.
        protected_tokens (`list[int]`):
            List of token IDs that should be preserved in the output.

    Returns:
        `list[int]`: Truncated sequence.

    Raises:
        `ValueError`: If `len(protected_tokens ∩ seq) > target_length`.
    """
    protected_set = set(protected_tokens)

    # Count protected tokens
    num_protected = sum(1 for t in ids if t in protected_set)
    if num_protected > target_length:
        raise ValueError(
            f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). "
            f"Please increase target length to at least {num_protected} or disable truncation."
        )
    num_non_protected_needed = target_length - num_protected
    result = []

    # Iterate backward to select all protected tokens and rightmost non-protected tokens
    for t in reversed(ids):
        if t in protected_set:
            result.append(t)
        elif num_non_protected_needed > 0:
            result.append(t)
            num_non_protected_needed -= 1
    # Reverse to restore original order
    return result[::-1]
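For reference, the helper removed by this commit behaves as exercised in the tests above; a short standalone usage sketch:

prompt_ids = [1, 2, 3, 4, 5]
protected_tokens = [2, 3]  # e.g. image/vision marker token IDs that must survive truncation
# Keeps all protected tokens plus the rightmost non-protected tokens, preserving relative order.
assert truncate_with_protected_tokens(prompt_ids, 3, protected_tokens) == [2, 3, 5]
# With no protected tokens, this reduces to keeping the last `target_length` tokens.
assert truncate_with_protected_tokens([1, 2, 3, 4, 5], 3, []) == [3, 4, 5]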


TListOrMapping = TypeVar("TListOrMapping", list, Mapping)