Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00

Compare commits: aa25c2697c ... v0.19-rele (9 commits)
Commits (SHA1): accf7383a3, ebb08d4f1f, de44966d78, 7a2bb48928, e302e82bf0, f2b55e0bb3, 9155c8fbb8, 7dbf4777b4, 84982ad793
@@ -1,6 +1,6 @@
 [metadata]
 name = trl
-version = 0.19.0
+version = 0.19.1
 description = Train transformer language models with reinforcement learning.
 long_description = file: README.md
 long_description_content_type = text/markdown

@@ -126,6 +126,53 @@ class TestDataCollatorForLanguageModeling(unittest.TestCase):
         torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
         torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))

+    def test_packing_drops_attention_mask_for_flash_attention(self):
+        """Test that when using packing with position_ids, attention_mask is dropped with fa2."""
+        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
+
+        # Simulate packed sequences with position_ids that restart (typical of FFD packing)
+        examples = [
+            {
+                "input_ids": [1, 2, 3, 4, 5, 6, 7, 8],  # Packed: [1,2,3] + [4,5] + [6,7,8]
+                "position_ids": [0, 1, 2, 0, 1, 0, 1, 2],  # Position IDs restart for each sequence
+            }
+        ]
+
+        result = collator(examples)
+
+        # Verify that attention_mask is NOT present - this allows flash attention to use position_ids
+        self.assertNotIn("attention_mask", result, "attention_mask should be dropped for packing with position_ids")
+
+        # Verify essential keys are present
+        self.assertIn("input_ids", result)
+        self.assertIn("position_ids", result)
+        self.assertIn("labels", result)
+
+        # Verify the data is correctly processed
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
+        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]]))
+        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
+
+    def test_padding_free_without_position_ids_keeps_attention_mask(self):
+        """
+        Test that padding_free mode without explicit position_ids still creates attention_mask.
+        """
+        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
+
+        # Examples without position_ids (not packed)
+        examples = [{"input_ids": [1, 2, 3, 4, 5]}]
+
+        result = collator(examples)
+
+        # Should still have attention_mask since no packed position_ids
+        self.assertIn("attention_mask", result, "attention_mask should be present when no packed position_ids")
+        self.assertIn("position_ids", result)
+        self.assertIn("input_ids", result)
+
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
+        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
+        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3, 4]]))
+
     def test_pad_to_multiple_of(self):
         """Test padding to multiple of specified value."""
         collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.19.0"
+__version__ = "0.19.1"

 from typing import TYPE_CHECKING

@@ -15,16 +15,26 @@
 import logging
 from typing import Callable, Literal, Optional, Union

+import datasets
 from datasets import Dataset, Value
+from packaging import version
 from transformers import AutoTokenizer

 from ..trainer.utils import ConstantLengthDataset


-FORMAT_MAPPING = {
-    "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
-    "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
-}
+if version.parse(datasets.__version__) >= version.parse("4.0.0"):
+    from datasets import List
+
+    FORMAT_MAPPING = {
+        "chatml": List({"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}),
+        "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
+    }
+else:
+    FORMAT_MAPPING = {
+        "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
+        "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
+    }


 def conversations_formatting_function(

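For context on the version gate above: newer datasets releases (4.0+) expose a dedicated List feature type, so the mapping is now built to match whichever representation the installed version produces. Below is a minimal standalone sketch of declaring the same ChatML-style schema both ways, using a hypothetical `messages` column name (illustration only, not code from this diff):

    import datasets
    from datasets import Features, Value
    from packaging import version

    role_content = {"role": Value(dtype="string", id=None), "content": Value(dtype="string", id=None)}

    if version.parse(datasets.__version__) >= version.parse("4.0.0"):
        from datasets import List

        # datasets>=4.0: sequence-of-dict features are expressed with the List feature type
        chatml_features = Features({"messages": List(role_content)})
    else:
        # older datasets: the same schema is written as a plain Python list of feature dicts
        chatml_features = Features({"messages": [role_content]})

    print(chatml_features)
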
@@ -110,21 +110,22 @@ class WeightSyncWorkerExtension:
         # The client process that sends updated weights has the highest rank (world_size - 1).
         self.client_rank = world_size - 1

-    def update_named_param(self, name: str, dtype: torch.dtype, shape: Sequence[int]) -> None:
+    def update_named_param(self, name: str, dtype: str, shape: Sequence[int]) -> None:
         """
         Receives updated weights from the client process and updates the named parameter in the model.

         Args:
             name (`str`):
                 Name of the weight tensor being updated.
-            dtype (`torch.dtype`):
-                Data type of the weight tensor (e.g., `torch.float32`).
+            dtype (`str`):
+                Data type of the weight tensor as a string (e.g., `"torch.float32"`).
             shape (`Sequence[int]`):
                 Shape of the weight tensor.
         """
         if self.pynccl_comm is None:
             raise RuntimeError("Communicator not initialized. Call `init_communicator` first.")

+        dtype = getattr(torch, dtype.split(".")[-1])
         # Allocate memory for the incoming weight tensor on the correct device.
         weight = torch.empty(shape, dtype=dtype, device=self.device)

@@ -560,11 +561,10 @@ def main(script_args: ScriptArguments):
             - `shape` (list of `int`): Shape of the weight

        """
-        # The function update_named_param is called this way: update_named_param("name", torch.float32, (10, 10))
+        # The function update_named_param is called this way: update_named_param("name", "torch.float32", (10, 10))
         # So with collective_rpc we need to call it this way:
         # llm.collective_rpc("update_named_param", args=("name", torch.float32, (10, 10)))
-        dtype = torch.__getattribute__(request.dtype.split(".")[-1])
-        kwargs = {"method": "update_named_param", "args": (request.name, dtype, tuple(request.shape))}
+        kwargs = {"method": "update_named_param", "args": (request.name, request.dtype, tuple(request.shape))}
         for connection in connections:
             connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs})

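A note on the two hunks above: the worker now receives the dtype as a plain string, exactly as it arrives in the request payload, and converts it back to a `torch.dtype` itself rather than the server resolving it before the `collective_rpc` call. A minimal standalone sketch of that round trip (helper names are hypothetical, not repo code):

    import torch

    def dtype_to_str(dtype: torch.dtype) -> str:
        # torch.float32 -> "torch.float32", which serializes cleanly in a request body
        return str(dtype)

    def str_to_dtype(name: str) -> torch.dtype:
        # Same conversion as in update_named_param above: keep the part after "torch."
        return getattr(torch, name.split(".")[-1])

    assert str_to_dtype(dtype_to_str(torch.bfloat16)) is torch.bfloat16
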
@@ -1327,7 +1327,6 @@ class DPOTrainer(Trainer):
                with self.null_ref_context():
                    ref_outputs = ref_base_model(
                        input_ids,
                        attention_mask=attention_mask,
                        use_cache=False,
                        **model_kwargs,
                    )

@@ -544,7 +544,7 @@ class GRPOConfig(TrainingArguments):
         if self.generation_batch_size is None:
             self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation

-        if self.generation_batch_size % self.per_device_train_batch_size * num_processes != 0:
+        if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
             raise ValueError(
                 f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
                 f"({self.per_device_train_batch_size * num_processes})."

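The fix above is purely about operator precedence: `%` and `*` bind equally tightly and associate left to right, so the old expression evaluated `(generation_batch_size % per_device_train_batch_size) * num_processes` instead of checking divisibility by the global batch size. A quick illustration with made-up numbers:

    generation_batch_size = 12
    per_device_train_batch_size = 4
    num_processes = 2

    # Old check: parsed as (12 % 4) * 2 != 0  ->  0 != 0  ->  False, so no error is raised
    old_check = generation_batch_size % per_device_train_batch_size * num_processes != 0

    # Fixed check: 12 % (4 * 2) != 0  ->  4 != 0  ->  True, so the ValueError fires as intended
    new_check = generation_batch_size % (per_device_train_batch_size * num_processes) != 0

    print(old_check, new_check)  # False True
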
@@ -13,6 +13,7 @@
 # limitations under the License.

 import os
+import re
 import textwrap
 import warnings
 from collections import defaultdict, deque

@@ -1062,11 +1063,17 @@ class GRPOTrainer(Trainer):
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]

         if self.max_prompt_length is not None:
+            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
+            # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text,
+            # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation).
             prompt_ids = prompt_ids[:, -self.max_prompt_length :]
             prompt_mask = prompt_mask[:, -self.max_prompt_length :]
             prompts_text = self.processing_class.batch_decode(
-                prompt_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+                prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
             )
+            prompts_text = [
+                re.sub(rf"^({re.escape(self.processing_class.pad_token)})+", "", text) for text in prompts_text
+            ]

         # Generate completions using either vLLM or regular generation
         if self.use_vllm:

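The new comments spell out the trade-off: decoding with skip_special_tokens=True would also drop special tokens the chat template needs for generation, so the prompt is decoded with skip_special_tokens=False and only the leading run of pad tokens (left padding that survives the truncation) is stripped. A standalone illustration of the regex, with hypothetical token strings:

    import re

    pad_token = "<|pad|>"  # stands in for processing_class.pad_token
    decoded = "<|pad|><|pad|><|im_start|>user Hello<|im_end|>"

    # Same pattern as in the hunk above: remove only the leading run of pad tokens,
    # leaving every other special token untouched.
    cleaned = re.sub(rf"^({re.escape(pad_token)})+", "", decoded)
    print(cleaned)  # <|im_start|>user Hello<|im_end|>
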
@@ -1124,7 +1131,8 @@ class GRPOTrainer(Trainer):
                 "max_tokens": self.max_completion_length,
                 "guided_decoding": guided_decoding,
             }
-            generation_kwargs.update(self.args.generation_kwargs)
+            if self.args.generation_kwargs is not None:
+                generation_kwargs.update(self.args.generation_kwargs)
             sampling_params = SamplingParams(**generation_kwargs)

             if self.vllm_tensor_parallel_size > 1:

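The guard above matters because args.generation_kwargs is None when the user does not set it, and dict.update(None) raises a TypeError; updating only when a dict was actually provided keeps the default path working. A tiny illustration:

    generation_kwargs = {"max_tokens": 256}

    extra = None  # what args.generation_kwargs looks like when the user sets nothing
    if extra is not None:
        generation_kwargs.update(extra)  # calling generation_kwargs.update(None) would raise TypeError

    print(generation_kwargs)
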
@@ -186,7 +186,15 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
         # Convert to tensor
         input_ids = [torch.tensor(example["input_ids"]) for example in examples]
-        attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
+
+        # Check if we have meaningful position_ids from packing (restarting sequences)
+        has_packed_position_ids = self.return_position_ids and "position_ids" in examples[0] and self.padding_free
+
+        # For packing with position_ids, we should NOT create attention_mask as it causes
+        # flash attention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask
+        if not has_packed_position_ids:
+            attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
+
         if self.return_position_ids:
             if "position_ids" in examples[0]:
                 position_ids = [torch.tensor(example["position_ids"]) for example in examples]

@@ -205,7 +213,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         output = {}
         if self.padding_free:
             output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0)
-            output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
+            if not has_packed_position_ids:
+                output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
             if self.return_position_ids:
                 output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0)
             output["labels"] = torch.cat(labels, dim=0).unsqueeze(0)

@@ -215,7 +224,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            if "assistant_masks" in examples[0]:
                assistant_masks = torch.cat(assistant_masks, dim=0).unsqueeze(0)
                output["labels"][assistant_masks == 0] = -100

        else:
            output["input_ids"] = pad(
                input_ids,

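To make the reasoning in the collator comments concrete: in a packed, padding-free batch the only record of where one sequence ends and the next begins is the restart of position_ids, while an all-ones attention_mask over the packed row looks like a single long sequence, so boundaries derived from it would be wrong. A minimal standalone sketch (not transformers or flash-attention code) of recovering the boundaries from the test example above:

    import torch

    # Packed row from the test above: [1, 2, 3] + [4, 5] + [6, 7, 8]
    position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2])

    # A new sequence starts wherever the position id resets to 0
    starts = (position_ids == 0).nonzero(as_tuple=True)[0]
    cu_seqlens = torch.cat([starts, torch.tensor([position_ids.numel()])])
    print(cu_seqlens)  # tensor([0, 3, 5, 8]) -> three sequences of lengths 3, 2 and 3

    # An all-ones attention_mask over the same row carries no boundary information,
    # so boundaries derived from it would describe one 8-token sequence instead of three.
    attention_mask = torch.ones_like(position_ids)
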