Compare commits

...

9 Commits

Author SHA1 Message Date
accf7383a3 Release: 0.19.1 2025-07-08 00:36:20 +00:00
ebb08d4f1f ✂️ [BUG when vllm and prompt_truncation are used]: Strip out pad tokens in truncated prompt text (#3698)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-07-08 00:35:07 +00:00
de44966d78 Fix non-serializable torch.dtype bug in VLLM weight sync (#3690)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-08 00:35:04 +00:00
7a2bb48928 📣 Use explicit version for checking datasets version (#3702) 2025-07-08 00:34:59 +00:00
e302e82bf0 Support datasets 4 (#3688)
Co-authored-by: Quentin Lhoest <quentinlhoest@Quentin-Ls-MacBook-Pro.local>
2025-07-08 00:34:55 +00:00
f2b55e0bb3 [SFT] drop attention_mask if we have position ids for fa2 (#3673)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-07-08 00:34:51 +00:00
9155c8fbb8 Add paranthesis to correct the check. (#3658) 2025-07-08 00:34:47 +00:00
7dbf4777b4 [GRPO] Make sure special tokens aren't lost when truncating prompt. (#3651) 2025-07-08 00:34:42 +00:00
84982ad793 🐛 fix grpo generation_kwargs (#3634)
Signed-off-by: ahatamizadeh <ahatamizadeh@nvidia.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-08 00:34:36 +00:00
9 changed files with 91 additions and 19 deletions

View File

@@ -1,6 +1,6 @@
[metadata]
name = trl
-version = 0.19.0
+version = 0.19.1
description = Train transformer language models with reinforcement learning.
long_description = file: README.md
long_description_content_type = text/markdown

View File

@@ -126,6 +126,53 @@ class TestDataCollatorForLanguageModeling(unittest.TestCase):
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))
def test_packing_drops_attention_mask_for_flash_attention(self):
"""Test that when using packing with position_ids, attention_mask is dropped with fa2."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
# Simulate packed sequences with position_ids that restart (typical of FFD packing)
examples = [
{
"input_ids": [1, 2, 3, 4, 5, 6, 7, 8], # Packed: [1,2,3] + [4,5] + [6,7,8]
"position_ids": [0, 1, 2, 0, 1, 0, 1, 2], # Position IDs restart for each sequence
}
]
result = collator(examples)
# Verify that attention_mask is NOT present - this allows flash attention to use position_ids
self.assertNotIn("attention_mask", result, "attention_mask should be dropped for packing with position_ids")
# Verify essential keys are present
self.assertIn("input_ids", result)
self.assertIn("position_ids", result)
self.assertIn("labels", result)
# Verify the data is correctly processed
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
def test_padding_free_without_position_ids_keeps_attention_mask(self):
"""
Test that padding_free mode without explicit position_ids still creates attention_mask.
"""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
# Examples without position_ids (not packed)
examples = [{"input_ids": [1, 2, 3, 4, 5]}]
result = collator(examples)
# Should still have attention_mask since no packed position_ids
self.assertIn("attention_mask", result, "attention_mask should be present when no packed position_ids")
self.assertIn("position_ids", result)
self.assertIn("input_ids", result)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3, 4]]))
def test_pad_to_multiple_of(self):
"""Test padding to multiple of specified value."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4)
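For context on what these new tests guard: when sequences are packed, the restart points in position_ids already encode the per-sequence boundaries that flash-attention kernels need (the cu_seq_lens mentioned in the collator comment later in this compare). Below is a minimal sketch of how those boundaries can be recovered from restarting position_ids in plain PyTorch; the helper name is made up for illustration and is not part of TRL.

import torch

def cu_seq_lens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids has shape (1, total_len) and restarts at 0 for every packed sequence,
    # e.g. [[0, 1, 2, 0, 1, 0, 1, 2]] for packed lengths 3, 2 and 3.
    pos = position_ids[0]
    starts = torch.nonzero(pos == 0, as_tuple=True)[0]   # indices where a new sequence begins
    total = torch.tensor([pos.numel()], device=pos.device)
    return torch.cat([starts, total]).to(torch.int32)    # cumulative boundaries

print(cu_seq_lens_from_position_ids(torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]])))
# tensor([0, 3, 5, 8], dtype=torch.int32)

An all-ones attention_mask, by contrast, describes one long sequence, which is exactly the wrong cu_seq_lens computation the first test above protects against.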

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "0.19.0"
+__version__ = "0.19.1"
from typing import TYPE_CHECKING

View File

@@ -15,16 +15,26 @@
import logging
from typing import Callable, Literal, Optional, Union
+import datasets
from datasets import Dataset, Value
+from packaging import version
from transformers import AutoTokenizer
from ..trainer.utils import ConstantLengthDataset
-FORMAT_MAPPING = {
-"chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
-"instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
-}
+if version.parse(datasets.__version__) >= version.parse("4.0.0"):
+from datasets import List
+FORMAT_MAPPING = {
+"chatml": List({"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}),
+"instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
+}
+else:
+FORMAT_MAPPING = {
+"chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
+"instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
+}
def conversations_formatting_function(
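For context, FORMAT_MAPPING is compared against a dataset's features to detect a ChatML-style conversational layout; datasets 4.0 adds a dedicated List feature type in place of the old "list of feature dicts" notation, which is why the expected schema is now built per installed version. A rough sketch of the kind of check this enables (the toy dataset below is made up, and the equality result is the expected behavior rather than a guarantee):

import datasets
from datasets import Dataset, Value
from packaging import version

# Build the expected ChatML message schema for the installed datasets version.
if version.parse(datasets.__version__) >= version.parse("4.0.0"):
    from datasets import List
    chatml_feature = List({"content": Value("string"), "role": Value("string")})
else:
    chatml_feature = [{"content": Value("string"), "role": Value("string")}]

ds = Dataset.from_dict(
    {"messages": [[{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]]}
)
# Should print True for a dataset whose "messages" column stores ChatML-style conversations.
print(ds.features["messages"] == chatml_feature)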

View File

@@ -110,21 +110,22 @@ class WeightSyncWorkerExtension:
# The client process that sends updated weights has the highest rank (world_size - 1).
self.client_rank = world_size - 1
-def update_named_param(self, name: str, dtype: torch.dtype, shape: Sequence[int]) -> None:
+def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None:
"""
Receives updated weights from the client process and updates the named parameter in the model.
Args:
name (`str`):
Name of the weight tensor being updated.
-dtype (`torch.dtype`):
-Data type of the weight tensor (e.g., `torch.float32`).
+dtype (`str`):
+Data type of the weight tensor as a string (e.g., `"torch.float32"`).
shape (`Sequence[int]`):
Shape of the weight tensor.
"""
if self.pynccl_comm is None:
raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")
+dtype = getattr(torch, dtype.split(".")[-1])
# Allocate memory for the incoming weight tensor on the correct device.
weight = torch.empty(shape, dtype=dtype, device=self.device)
@@ -560,11 +561,10 @@ def main(script_args: ScriptArguments):
- `shape` (list of `int`): Shape of the weight
"""
-# The function update_named_param is called this way: update_named_param("name", torch.float32, (10, 10))
+# The function update_named_param is called this way: update_named_param("name", "torch.float32", (10, 10))
# So with collective_rpc we need to call it this way:
# llm.collective_rpc("update_named_param", args=("name", torch.float32, (10, 10)))
-dtype = torch.__getattribute__(request.dtype.split(".")[-1])
-kwargs = {"method": "update_named_param", "args": (request.name, dtype, tuple(request.shape))}
+kwargs = {"method": "update_named_param", "args": (request.name, request.dtype, tuple(request.shape))}
for connection in connections:
connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})
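The underlying issue in #3690 is that a torch.dtype object cannot be serialized in the request the trainer sends to the vLLM server, so the dtype now travels as a plain string and is only converted back to a real torch.dtype inside the worker (the added getattr line above). A minimal sketch of that round trip, assuming the client stringifies the dtype with str():

import torch

dtype = torch.bfloat16

# Client side: serialize the dtype as text so it can go into a JSON-style payload.
wire_dtype = str(dtype)  # "torch.bfloat16"

# Worker side: recover the real torch.dtype, mirroring the added
# `dtype = getattr(torch, dtype.split(".")[-1])` line above.
recovered = getattr(torch, wire_dtype.split(".")[-1])
assert recovered is torch.bfloat16

# The recovered dtype is what update_named_param uses to allocate the
# receive buffer for the broadcast weights.
weight = torch.empty((2, 2), dtype=recovered)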

View File

@@ -1327,7 +1327,6 @@ class DPOTrainer(Trainer):
with self.null_ref_context():
ref_outputs = ref_base_model(
input_ids,
-attention_mask=attention_mask,
use_cache=False,
**model_kwargs,
)

View File

@@ -544,7 +544,7 @@ class GRPOConfig(TrainingArguments):
if self.generation_batch_size is None:
self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
-if self.generation_batch_size % self.per_device_train_batch_size * num_processes != 0:
+if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
raise ValueError(
f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
f"({self.per_device_train_batch_size * num_processes})."

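The parenthesis fix (#3658) matters because % and * have equal precedence and group left to right, so the old expression tested ((generation_batch_size % per_device_train_batch_size) * num_processes) != 0 rather than divisibility by the global batch size. A small worked example with made-up values:

generation_batch_size = 8
per_device_train_batch_size = 4
num_processes = 3  # global batch size = 4 * 3 = 12

# Old check: evaluates as (8 % 4) * 3 != 0  ->  0 != 0  ->  False,
# so this invalid configuration slipped through without raising.
old_check = generation_batch_size % per_device_train_batch_size * num_processes != 0

# Fixed check: 8 % (4 * 3) != 0  ->  8 != 0  ->  True, so the
# ValueError above is now raised for a non-divisible batch size.
new_check = generation_batch_size % (per_device_train_batch_size * num_processes) != 0

print(old_check, new_check)  # False True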
View File

@@ -13,6 +13,7 @@
# limitations under the License.
import os
import re
import textwrap
import warnings
from collections import defaultdict, deque
@@ -1062,11 +1063,17 @@ class GRPOTrainer(Trainer):
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
if self.max_prompt_length is not None:
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
+# Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
+# because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
prompts_text = self.processing_class.batch_decode(
-prompt_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
+prompts_text = [
+re.sub(rf"^({re.escape(self.processing_class.pad_token)})+", "", text) for text in prompts_text
+]
# Generate completions using either vLLM or regular generation
if self.use_vllm:
@@ -1124,7 +1131,8 @@ class GRPOTrainer(Trainer):
"max_tokens": self.max_completion_length,
"guided_decoding": guided_decoding,
}
-generation_kwargs.update(self.args.generation_kwargs)
+if self.args.generation_kwargs is not None:
+generation_kwargs.update(self.args.generation_kwargs)
sampling_params = SamplingParams(**generation_kwargs)
if self.vllm_tensor_parallel_size > 1:
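Two separate fixes meet in this file. The truncation fix (#3698) decodes the trimmed prompt with skip_special_tokens=False so that tokens the chat template needs survive, then strips only the leading run of pad tokens introduced by left padding; the generation_kwargs fix (#3634) simply skips the update when the config field is left at its default None. A small sketch of the pad-stripping regex with a made-up pad token and prompt:

import re

pad_token = "<pad>"  # stands in for self.processing_class.pad_token
decoded = "<pad><pad><pad><|im_start|>user\nHello<|im_end|>"

# Remove only the leading run of pad tokens; other special tokens stay intact.
stripped = re.sub(rf"^({re.escape(pad_token)})+", "", decoded)
print(stripped)  # the prompt with its chat-template tokens kept and the pads gone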

View File

@@ -186,7 +186,15 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
# Convert to tensor
input_ids = [torch.tensor(example["input_ids"]) for example in examples]
-attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
+# Check if we have meaningful position_ids from packing (restarting sequences)
+has_packed_position_ids = self.return_position_ids and "position_ids" in examples[0] and self.padding_free
+# For packing with position_ids, we should NOT create attention_mask as it causes
+# flash attention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask
+if not has_packed_position_ids:
+attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
if self.return_position_ids:
if "position_ids" in examples[0]:
position_ids = [torch.tensor(example["position_ids"]) for example in examples]
@@ -205,7 +213,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
output = {}
if self.padding_free:
output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0)
-output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
+if not has_packed_position_ids:
+output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
if self.return_position_ids:
output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0)
output["labels"] = torch.cat(labels, dim=0).unsqueeze(0)
@@ -215,7 +224,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
if "assistant_masks" in examples[0]:
assistant_masks = torch.cat(assistant_masks, dim=0).unsqueeze(0)
output["labels"][assistant_masks == 0] = -100
else:
output["input_ids"] = pad(
input_ids,