Compare commits

...

6 Commits

9 changed files with 92 additions and 51 deletions

View File

@@ -1 +1 @@
0.22.0
0.22.2

View File

@@ -172,12 +172,6 @@ class DPOTrainerTester(TrlTestCase):
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.tokenizer.pad_token = self.tokenizer.eos_token
# get t5 as seq2seq example:
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
self.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
def test_train(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
@@ -255,6 +249,39 @@ class DPOTrainerTester(TrlTestCase):
if param.sum() != 0: # ignore 0 biases
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
@require_liger_kernel
def test_train_encoder_decoder_liger(self):
model_id = "trl-internal-testing/tiny-BartModel"
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
tokenizer = AutoTokenizer.from_pretrained(model_id)
training_args = DPOConfig(
output_dir="selftmp_dir",
per_device_train_batch_size=2,
learning_rate=9e-1,
report_to="none",
use_liger_loss=True,
)
trainer = DPOTrainer(
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
def test_dpo_trainer_with_weighting(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

View File

@@ -162,6 +162,18 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))
def test_pad_to_multiple_of_and_padding_free(self):
"""Test padding to multiple of specified value."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
result = collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))
def test_custom_position_ids(self):
"""Test handling of custom position IDs in examples."""
self.collator = DataCollatorForLanguageModeling(pad_token_id=0)

View File

@@ -12,18 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path
from typing import TYPE_CHECKING
from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
# Read version from VERSION file
_version_file = Path(__file__).parent.parent / "VERSION"
try:
with open(_version_file, encoding="utf-8") as f:
__version__ = f.read().strip()
except FileNotFoundError:
__version__ = version("trl")
except PackageNotFoundError:
__version__ = "unknown"
_import_structure = {

View File

@@ -663,8 +663,9 @@ def pack_dataset(
>>> dataset = Dataset.from_dict(examples)
>>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd")
>>> packed_dataset[:]
{'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]],
'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]]}
{'input_ids': [[1, 2, 3, 9], [6, 7, 8], [4, 5]],
'attention_mask': [[1, 1, 0, 1], [1, 0, 0], [1, 0]],
'seq_lengths': [[3, 1], [3], [2]]}
```
"""
if map_kwargs is None:
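Note: the `examples` dict is elided by the hunk above; reconstructed from the expected output, the corrected docstring example corresponds to something like the following sketch (the input values are inferred, not copied from the diff):

from datasets import Dataset
from trl import pack_dataset

examples = {
    "input_ids": [[1, 2, 3], [4, 5], [6, 7, 8], [9]],
    "attention_mask": [[1, 1, 0], [1, 0], [1, 0, 0], [1]],
}
dataset = Dataset.from_dict(examples)
packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd")
print(packed_dataset[:])
# Per the updated docstring, best-fit-decreasing packing at seq_length=4 yields three rows
# plus per-row sequence lengths:
# {'input_ids': [[1, 2, 3, 9], [6, 7, 8], [4, 5]],
#  'attention_mask': [[1, 1, 0, 1], [1, 0, 0], [1, 0]],
#  'seq_lengths': [[3, 1], [3], [2]]}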

View File

@@ -102,6 +102,7 @@ def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) ->
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
shifted_input_ids[:, 0] = decoder_start_token_id
return shifted_input_ids
@dataclass
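A standalone sketch to sanity-check the helper shown above (the function body is copied from the hunk; the example tensor and start-token id are made up):

import torch

def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    # Prepend the decoder start token and drop the last token of each row.
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids

print(shift_tokens_right(torch.tensor([[5, 6, 7, 8]]), decoder_start_token_id=0))
# tensor([[0, 5, 6, 7]])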

View File

@@ -605,6 +605,8 @@ class GRPOConfig(TrainingArguments):
super().__post_init__()
self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards)
num_processes = self.world_size
# The current default effective batch size
if self.generation_batch_size is None and self.steps_per_generation is None:
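The added __post_init__ line keeps backward compatibility for boolean scale_rewards values. A minimal standalone sketch of that mapping (the helper name is hypothetical, not part of TRL):

def normalize_scale_rewards(value):
    # Legacy booleans map to the new string options; strings pass through unchanged.
    return {True: "group", False: "none"}.get(value, value)

print(normalize_scale_rewards(True))     # group
print(normalize_scale_rewards(False))    # none
print(normalize_scale_rewards("batch"))  # batch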

View File

@@ -352,7 +352,7 @@ class GRPOTrainer(Trainer):
self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
self.use_liger_loss = args.use_liger_loss
self.loss_type = args.loss_type
self.scale_rewards = {True: "group", False: "none"}.get(args.scale_rewards, args.scale_rewards)
self.scale_rewards = args.scale_rewards
self.importance_sampling_level = args.importance_sampling_level
self.mask_truncated_completions = args.mask_truncated_completions
self.top_entropy_quantile = args.top_entropy_quantile
@@ -1398,11 +1398,11 @@ class GRPOTrainer(Trainer):
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = rewards - mean_grouped_rewards
if self.scale_rewards in ["batch", "none"]:
if self.scale_rewards in ["group", "none"]:
# If self.scale_rewards = "none", we'll still log group level std
std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
elif self.scale_rewards == "group":
elif self.scale_rewards == "batch":
# Compute global std
std_rewards = rewards.std().expand_as(rewards)
else:
@@ -1411,7 +1411,7 @@ class GRPOTrainer(Trainer):
)
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
if self.scale_rewards in ["batch", "none"]:
if self.scale_rewards != "none":
advantages = advantages / (std_rewards + 1e-4)
# Slice to keep only the local part of the data
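A standalone sketch of the corrected scaling logic above (plain torch, not the trainer itself; the reward values and group size are made up):

import torch

rewards = torch.tensor([1.0, 2.0, 3.0, 0.0, 0.0, 3.0])  # 2 prompts x 3 generations, flattened
num_generations = 3
scale_rewards = "group"  # one of "group", "batch", "none"

mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1)
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)
advantages = rewards - mean_grouped_rewards

if scale_rewards in ["group", "none"]:
    # Group-level std (also computed for logging when scaling is disabled)
    std_rewards = rewards.view(-1, num_generations).std(dim=1)
    std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)
elif scale_rewards == "batch":
    # Global std across the whole batch
    std_rewards = rewards.std().expand_as(rewards)

if scale_rewards != "none":
    advantages = advantages / (std_rewards + 1e-4)

print(advantages)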

View File

@@ -124,7 +124,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
that are not in the completion.
padding_free (`bool`, *optional*, defaults to `False`):
If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
generated accordingly. The attention mask will be set to 1 for all tokens.
generated accordingly.
pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
If set, the sequences will be padded to a multiple of this value.
return_tensors (`str`, *optional*, defaults to `"pt"`):
@@ -206,48 +206,48 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
        if "assistant_masks" in examples[0]:
            assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples]

-        # Pad
+        # If padding_free, flatten everything into a single sequence
        output = {}
        if self.padding_free:
-            output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0)
+            input_ids = [torch.cat(input_ids, dim=0)]
            if not has_packed_position_ids:
-                output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
+                attention_mask = [torch.cat(attention_mask, dim=0)]
            if self.return_position_ids:
-                output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0)
-            output["labels"] = torch.cat(labels, dim=0).unsqueeze(0)
+                position_ids = [torch.cat(position_ids, dim=0)]
+            labels = [torch.cat(labels, dim=0)]
            if self.completion_only_loss and "completion_mask" in examples[0]:
-                completion_mask = torch.cat(completion_mask, dim=0).unsqueeze(0)
-                output["labels"][completion_mask == 0] = -100
+                completion_mask = [torch.cat(completion_mask, dim=0)]
            if "assistant_masks" in examples[0]:
-                assistant_masks = torch.cat(assistant_masks, dim=0).unsqueeze(0)
-                output["labels"][assistant_masks == 0] = -100
-        else:
-            output["input_ids"] = pad(
-                input_ids,
-                padding_value=self.pad_token_id,
-                padding_side="right",
-                pad_to_multiple_of=self.pad_to_multiple_of,
-            )
-            if not has_packed_position_ids:
-                output["attention_mask"] = pad(
-                    attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-            if self.return_position_ids:
-                output["position_ids"] = pad(
-                    position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-            output["labels"] = pad(
-                labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-            )
-            if self.completion_only_loss and "completion_mask" in examples[0]:
-                completion_mask = pad(
-                    completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-                output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion
-            if "assistant_masks" in examples[0]:
-                assistant_masks = pad(
-                    assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-                output["labels"][assistant_masks == 0] = -100
+                assistant_masks = [torch.cat(assistant_masks, dim=0)]
+        # Pad
+        output["input_ids"] = pad(
+            input_ids,
+            padding_value=self.pad_token_id,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        if not has_packed_position_ids:
+            output["attention_mask"] = pad(
+                attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+        if self.return_position_ids:
+            output["position_ids"] = pad(
+                position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+        output["labels"] = pad(
+            labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+        )
+        if self.completion_only_loss and "completion_mask" in examples[0]:
+            completion_mask = pad(
+                completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+            output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion
+        if "assistant_masks" in examples[0]:
+            assistant_masks = pad(
+                assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+            output["labels"][assistant_masks == 0] = -100
        return output
@staticmethod
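The net effect of the refactor above: the padding-free path now rebuilds the per-field lists and falls through to the same pad() call as the regular path, which is why pad_to_multiple_of also applies when padding_free=True (the new test exercises exactly that). A minimal sketch of the two paths using TRL's pad helper; the import path is an assumption about where the utility lives:

import torch
from trl.trainer.utils import pad  # assumed location of the padding helper

input_ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# padding_free=True: flatten into a single sequence, then pad
flat = [torch.cat(input_ids, dim=0)]
print(pad(flat, padding_value=0, padding_side="right", pad_to_multiple_of=4))
# tensor([[1, 2, 3, 4, 5, 0, 0, 0]])

# padding_free=False: pad each example up to the rounded-up max length
print(pad(input_ids, padding_value=0, padding_side="right", pad_to_multiple_of=4))
# tensor([[1, 2, 3, 0],
#         [4, 5, 0, 0]])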