mirror of
https://github.com/huggingface/trl.git
synced 2025-10-21 11:33:51 +08:00
Compare commits
6 Commits
max-length
...
v0.22.2
Author | SHA1 | Date | |
---|---|---|---|
2d597e4a18 | |||
e8b1b83ad4 | |||
a436c0a2d5 | |||
32cbb072f3 | |||
1366bac011 | |||
e33c88cc49 |
@ -172,12 +172,6 @@ class DPOTrainerTester(TrlTestCase):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
# get t5 as seq2seq example:
|
||||
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration"
|
||||
self.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
self.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
self.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
def test_train(self):
|
||||
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
|
||||
@ -255,6 +249,39 @@ class DPOTrainerTester(TrlTestCase):
|
||||
if param.sum() != 0: # ignore 0 biases
|
||||
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
|
||||
|
||||
@require_liger_kernel
|
||||
def test_train_encoder_decoder_liger(self):
|
||||
model_id = "trl-internal-testing/tiny-BartModel"
|
||||
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
training_args = DPOConfig(
|
||||
output_dir="selftmp_dir",
|
||||
per_device_train_batch_size=2,
|
||||
learning_rate=9e-1,
|
||||
report_to="none",
|
||||
use_liger_loss=True,
|
||||
)
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# Check that the parameters have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
if param.sum() != 0: # ignore 0 biases
|
||||
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
|
||||
|
||||
def test_dpo_trainer_with_weighting(self):
|
||||
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
|
||||
|
||||
|
@ -162,6 +162,18 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
|
||||
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
|
||||
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))
|
||||
|
||||
def test_pad_to_multiple_of_and_padding_free(self):
|
||||
"""Test padding to multiple of specified value."""
|
||||
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4)
|
||||
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
|
||||
|
||||
result = collator(examples)
|
||||
|
||||
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
|
||||
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
|
||||
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
|
||||
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))
|
||||
|
||||
def test_custom_position_ids(self):
|
||||
"""Test handling of custom position IDs in examples."""
|
||||
self.collator = DataCollatorForLanguageModeling(pad_token_id=0)
|
||||
|
@ -12,18 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
|
||||
|
||||
|
||||
# Read version from VERSION file
|
||||
_version_file = Path(__file__).parent.parent / "VERSION"
|
||||
try:
|
||||
with open(_version_file, encoding="utf-8") as f:
|
||||
__version__ = f.read().strip()
|
||||
except FileNotFoundError:
|
||||
__version__ = version("trl")
|
||||
except PackageNotFoundError:
|
||||
__version__ = "unknown"
|
||||
|
||||
_import_structure = {
|
||||
|
@ -663,8 +663,9 @@ def pack_dataset(
|
||||
>>> dataset = Dataset.from_dict(examples)
|
||||
>>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd")
|
||||
>>> packed_dataset[:]
|
||||
{'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]],
|
||||
'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]]}
|
||||
{'input_ids': [[1, 2, 3, 9], [6, 7, 8], [4, 5]],
|
||||
'attention_mask': [[1, 1, 0, 1], [1, 0, 0], [1, 0]],
|
||||
'seq_lengths': [[3, 1], [3], [2]]}
|
||||
```
|
||||
"""
|
||||
if map_kwargs is None:
|
||||
|
@ -102,6 +102,7 @@ def shift_tokens_right(input_ids: torch.Tensor, decoder_start_token_id: int) ->
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
|
||||
shifted_input_ids[:, 0] = decoder_start_token_id
|
||||
return shifted_input_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -605,6 +605,8 @@ class GRPOConfig(TrainingArguments):
|
||||
|
||||
super().__post_init__()
|
||||
|
||||
self.scale_rewards = {True: "group", False: "none"}.get(self.scale_rewards, self.scale_rewards)
|
||||
|
||||
num_processes = self.world_size
|
||||
# The current default effective batch size
|
||||
if self.generation_batch_size is None and self.steps_per_generation is None:
|
||||
|
@ -352,7 +352,7 @@ class GRPOTrainer(Trainer):
|
||||
self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode
|
||||
self.use_liger_loss = args.use_liger_loss
|
||||
self.loss_type = args.loss_type
|
||||
self.scale_rewards = {True: "group", False: "none"}.get(args.scale_rewards, args.scale_rewards)
|
||||
self.scale_rewards = args.scale_rewards
|
||||
self.importance_sampling_level = args.importance_sampling_level
|
||||
self.mask_truncated_completions = args.mask_truncated_completions
|
||||
self.top_entropy_quantile = args.top_entropy_quantile
|
||||
@ -1398,11 +1398,11 @@ class GRPOTrainer(Trainer):
|
||||
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||
advantages = rewards - mean_grouped_rewards
|
||||
|
||||
if self.scale_rewards in ["batch", "none"]:
|
||||
if self.scale_rewards in ["group", "none"]:
|
||||
# If self.scale_rewards = "none", we'll still log group level std
|
||||
std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
||||
std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||
elif self.scale_rewards == "group":
|
||||
elif self.scale_rewards == "batch":
|
||||
# Compute global std
|
||||
std_rewards = rewards.std().expand_as(rewards)
|
||||
else:
|
||||
@ -1411,7 +1411,7 @@ class GRPOTrainer(Trainer):
|
||||
)
|
||||
|
||||
is_std_zero = torch.isclose(std_rewards, torch.zeros_like(std_rewards))
|
||||
if self.scale_rewards in ["batch", "none"]:
|
||||
if self.scale_rewards != "none":
|
||||
advantages = advantages / (std_rewards + 1e-4)
|
||||
|
||||
# Slice to keep only the local part of the data
|
||||
|
@ -124,7 +124,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
that are no in the completion.
|
||||
padding_free (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
|
||||
generated accordingly. The attention mask will be set to 1 for all tokens.
|
||||
generated accordingly.
|
||||
pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
|
||||
If set, the sequences will be padded to a multiple of this value.
|
||||
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
||||
@ -206,48 +206,48 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
|
||||
if "assistant_masks" in examples[0]:
|
||||
assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples]
|
||||
|
||||
# Pad
|
||||
# If padding_free, flatten everything into a single sequence
|
||||
output = {}
|
||||
if self.padding_free:
|
||||
output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0)
|
||||
input_ids = [torch.cat(input_ids, dim=0)]
|
||||
if not has_packed_position_ids:
|
||||
output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
|
||||
attention_mask = [torch.cat(attention_mask, dim=0)]
|
||||
if self.return_position_ids:
|
||||
output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0)
|
||||
output["labels"] = torch.cat(labels, dim=0).unsqueeze(0)
|
||||
position_ids = [torch.cat(position_ids, dim=0)]
|
||||
labels = [torch.cat(labels, dim=0)]
|
||||
if self.completion_only_loss and "completion_mask" in examples[0]:
|
||||
completion_mask = torch.cat(completion_mask, dim=0).unsqueeze(0)
|
||||
output["labels"][completion_mask == 0] = -100
|
||||
completion_mask = [torch.cat(completion_mask, dim=0)]
|
||||
if "assistant_masks" in examples[0]:
|
||||
assistant_masks = torch.cat(assistant_masks, dim=0).unsqueeze(0)
|
||||
output["labels"][assistant_masks == 0] = -100
|
||||
else:
|
||||
output["input_ids"] = pad(
|
||||
input_ids,
|
||||
padding_value=self.pad_token_id,
|
||||
padding_side="right",
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
)
|
||||
assistant_masks = [torch.cat(assistant_masks, dim=0)]
|
||||
|
||||
# Pad
|
||||
output["input_ids"] = pad(
|
||||
input_ids,
|
||||
padding_value=self.pad_token_id,
|
||||
padding_side="right",
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
)
|
||||
if not has_packed_position_ids:
|
||||
output["attention_mask"] = pad(
|
||||
attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
)
|
||||
if self.return_position_ids:
|
||||
output["position_ids"] = pad(
|
||||
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
)
|
||||
output["labels"] = pad(
|
||||
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
if self.return_position_ids:
|
||||
output["position_ids"] = pad(
|
||||
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
)
|
||||
if self.completion_only_loss and "completion_mask" in examples[0]:
|
||||
completion_mask = pad(
|
||||
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
)
|
||||
output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion
|
||||
if "assistant_masks" in examples[0]:
|
||||
assistant_masks = pad(
|
||||
assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
)
|
||||
output["labels"][assistant_masks == 0] = -100
|
||||
output["labels"] = pad(
|
||||
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
)
|
||||
if self.completion_only_loss and "completion_mask" in examples[0]:
|
||||
completion_mask = pad(
|
||||
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
)
|
||||
output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion
|
||||
if "assistant_masks" in examples[0]:
|
||||
assistant_masks = pad(
|
||||
assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
|
||||
)
|
||||
output["labels"][assistant_masks == 0] = -100
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
|
Reference in New Issue
Block a user