Mirror of https://github.com/huggingface/trl.git (synced 2025-11-06 14:24:29 +08:00)
Compare commits
4 Commits
docs/exten ... v0.16.1
| Author | SHA1 | Date |
|---|---|---|
|  | 2bc182c4fb |  |
|  | a4458afb5e |  |
|  | 40a5e9571c |  |
|  | e792f872f6 |  |
setup.py (2 changes)
@@ -69,7 +69,7 @@ To create the package for PyPI.
 from setuptools import find_packages, setup


-__version__ = "0.16.0"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+__version__ = "0.16.1"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)

 REQUIRED_PKGS = [
     "accelerate>=0.34.0",
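Note: once the release is installed, the version bump above can be confirmed directly from the package (illustrative check only):

```python
# Quick sanity check that the installed trl matches the bumped version.
import trl

print(trl.__version__)  # expected to print "0.16.1" for this release
```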
@@ -105,7 +105,6 @@ class SFTTrainerSlowTester(unittest.TestCase):

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -140,7 +139,6 @@ class SFTTrainerSlowTester(unittest.TestCase):

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -177,7 +175,6 @@ class SFTTrainerSlowTester(unittest.TestCase):

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -213,7 +210,6 @@ class SFTTrainerSlowTester(unittest.TestCase):

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -250,7 +246,6 @@ class SFTTrainerSlowTester(unittest.TestCase):

         model = AutoModelForCausalLM.from_pretrained(model_name)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -294,7 +289,6 @@ class SFTTrainerSlowTester(unittest.TestCase):

         model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -334,7 +328,6 @@ class SFTTrainerSlowTester(unittest.TestCase):

         model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -380,7 +373,6 @@ class SFTTrainerSlowTester(unittest.TestCase):

         if tokenizer.chat_template is None:
             model, tokenizer = setup_chat_format(model, tokenizer)
-        tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token

         trainer = SFTTrainer(
             model,
@@ -33,6 +33,7 @@ from transformers.utils import is_peft_available

 from trl import SFTConfig, SFTTrainer
 from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
+from trl.trainer.sft_trainer import DataCollatorForLanguageModeling


 def formatting_prompts_func(example):
@@ -59,6 +60,34 @@ if is_vision_available():
     from PIL import Image as PILImage


+class TestDataCollatorForLanguageModeling(unittest.TestCase):
+    def test_collate_padding(self):
+        collator = DataCollatorForLanguageModeling(pad_token_id=0)
+        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
+        output = collator(examples)
+
+        expected_input_ids = torch.tensor([[1, 2, 3], [4, 5, 0]])
+        expected_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
+        expected_labels = torch.tensor([[1, 2, 3], [4, 5, -100]])
+
+        self.assertEqual(output["input_ids"].tolist(), expected_input_ids.tolist())
+        self.assertEqual(output["attention_mask"].tolist(), expected_attention_mask.tolist())
+        self.assertEqual(output["labels"].tolist(), expected_labels.tolist())
+
+    def test_collate_no_padding(self):
+        collator = DataCollatorForLanguageModeling(pad_token_id=0)
+        examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5, 6]}]
+        output = collator(examples)
+
+        expected_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
+        expected_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]])
+        expected_labels = torch.tensor([[1, 2, 3], [4, 5, 6]])
+
+        self.assertEqual(output["input_ids"].tolist(), expected_input_ids.tolist())
+        self.assertEqual(output["attention_mask"].tolist(), expected_attention_mask.tolist())
+        self.assertEqual(output["labels"].tolist(), expected_labels.tolist())
+
+
 class SFTTrainerTester(unittest.TestCase):
     r""" """

@@ -66,7 +95,6 @@ class SFTTrainerTester(unittest.TestCase):
         self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
         self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.dummy_dataset = Dataset.from_dict(
             {
                 "question": [
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.16.0"
+__version__ = "0.16.1"

 from typing import TYPE_CHECKING

@@ -98,8 +98,6 @@ def main(script_args, training_args, model_args):
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
     )
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token

     ################
     # Dataset
@@ -695,7 +695,9 @@ class OnlineDPOTrainer(Trainer):

     # Same as Trainer._maybe_log_save_evaluate but log our metrics
     # start_time defaults to None to allow compatibility with transformers<=4.46
-    def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
+    def _maybe_log_save_evaluate(
+        self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, learning_rate=None
+    ):
         if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
             logs: dict[str, float] = {}

@@ -708,7 +710,10 @@ class OnlineDPOTrainer(Trainer):
             logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
             if grad_norm is not None:
                 logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
-            logs["learning_rate"] = self._get_learning_rate()
+            if learning_rate is not None:
+                logs["learning_rate"] = learning_rate
+            else:
+                logs["learning_rate"] = self._get_learning_rate()

             # Add our metrics
             for key, val in self.stats.items():
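Note: the new `learning_rate` parameter appears to track newer `transformers` releases that pass an additional `learning_rate` argument to `_maybe_log_save_evaluate`; on older releases it stays `None` and the code falls back to `self._get_learning_rate()`. A purely illustrative probe of which behaviour the installed `transformers` has:

```python
# Illustration only: check whether the installed transformers.Trainer declares
# `learning_rate` as a parameter of _maybe_log_save_evaluate (newer releases do).
import inspect

from transformers import Trainer

params = inspect.signature(Trainer._maybe_log_save_evaluate).parameters
print("learning_rate" in params)
```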
@@ -47,6 +47,9 @@ class SFTConfig(TrainingArguments):
             `skip_prepare_dataset`.
         dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of processes to use for processing the dataset.
+        pad_token (`str` or `None`, *optional*, defaults to `None`):
+            Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
+            it falls back to `processing_class.eos_token`.
         max_length (`int` or `None`, *optional*, defaults to `1024`):
             Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
             If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
@@ -92,6 +95,13 @@ class SFTConfig(TrainingArguments):
         default=None,
         metadata={"help": "Number of processes to use for processing the dataset."},
     )
+    pad_token: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
+            "is also `None`, it falls back to `processing_class.eos_token`."
+        },
+    )
     max_length: Optional[int] = field(
         default=1024,
         metadata={
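Note: with the new option above, the padding token can be set explicitly on the config instead of being patched onto the tokenizer. A minimal sketch (the token value is only an illustration and must exist in the tokenizer's vocabulary):

```python
# Minimal sketch: choosing the padding token via the new SFTConfig field.
# Leaving pad_token=None keeps the documented fallback to processing_class.pad_token
# and then processing_class.eos_token.
from trl import SFTConfig

args = SFTConfig(
    output_dir="sft-output",    # placeholder output directory
    pad_token="<|endoftext|>",  # hypothetical choice; must be in the tokenizer vocabulary
)
print(args.pad_token)
```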
@@ -16,6 +16,7 @@ import dataclasses
 import os
 import warnings
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import Any, Callable, Optional, Type, Union

 import torch
@@ -29,7 +30,6 @@ from transformers import (
     AutoTokenizer,
     BaseImageProcessor,
     DataCollator,
-    DataCollatorForLanguageModeling,
     DataCollatorWithFlattening,
     FeatureExtractionMixin,
     PreTrainedModel,
@@ -39,6 +39,7 @@ from transformers import (
     TrainingArguments,
     is_wandb_available,
 )
+from transformers.data.data_collator import DataCollatorMixin
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalPrediction
 from transformers.utils import is_peft_available
@@ -51,7 +52,13 @@ from ..data_utils import (
     truncate_dataset,
 )
 from .sft_config import SFTConfig
-from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16
+from .utils import (
+    ConstantLengthDataset,
+    generate_model_card,
+    get_comet_experiment_url,
+    pad,
+    peft_module_casting_to_bf16,
+)


 if is_peft_available():
@@ -62,6 +69,54 @@ if is_wandb_available():
     import wandb


+@dataclass
+class DataCollatorForLanguageModeling(DataCollatorMixin):
+    """
+    Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch if
+    they are not all of the same length.
+
+    Args:
+        pad_token_id (`int`):
+            Token ID to use for padding.
+        return_tensors (`str`, *optional*, defaults to `"pt"`):
+            Type of Tensor to return. Only `"pt"` is currently supported.
+
+    Examples:
+    ```python
+    >>> from trl import DataCollatorForLanguageModeling
+    >>> collator = DataCollatorForLanguageModeling(pad_token_id=0)
+    >>> examples = [
+    ...     {"input_ids": [1, 2, 3]},
+    ...     {"input_ids": [4, 5]}
+    ... ]
+    >>> collator(examples)
+    {'input_ids': tensor([[   1,   2,   3],
+                          [   4,   5,   0]]),
+     'attention_mask': tensor([[  1,  1,  1],
+                               [  1,  1,  0]]),
+     'labels': tensor([[   1,    2,    3],
+                       [   4,    5, -100]])}
+    ```
+    """
+
+    pad_token_id: int
+    return_tensors: str = "pt"
+
+    def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
+        # Convert to tensor
+        input_ids = [torch.tensor(example["input_ids"]) for example in examples]
+        attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
+        labels = [torch.tensor(example["input_ids"]) for example in examples]
+
+        # Pad
+        output = {}
+        output["input_ids"] = pad(input_ids, padding_value=self.pad_token_id, padding_side="right")
+        output["attention_mask"] = pad(attention_mask, padding_value=0, padding_side="right")
+        output["labels"] = pad(labels, padding_value=-100, padding_side="right")
+
+        return output
+
+
 class SFTTrainer(Trainer):
     """
     Trainer for Supervised Fine-Tuning (SFT) method.
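Note: `torch_call` above delegates padding to the `pad` helper imported from `.utils`. As a simplified sketch of what a right-padding helper of that shape does, restricted to 1-D inputs (illustration only, not TRL's actual implementation):

```python
# Simplified right-padding sketch (assumes 1-D integer tensors; TRL's own
# trl.trainer.utils.pad also supports a padding_side argument and more shapes).
import torch


def pad_right(tensors: list[torch.Tensor], padding_value: int) -> torch.Tensor:
    # Allocate a (batch, max_len) tensor filled with the padding value, then copy
    # each sequence into the left part of its row.
    max_len = max(t.size(0) for t in tensors)
    out = torch.full((len(tensors), max_len), padding_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        out[i, : t.size(0)] = t
    return out


print(pad_right([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], padding_value=0))
# tensor([[1, 2, 3],
#         [4, 5, 0]])
```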
@@ -159,9 +214,9 @@ class SFTTrainer(Trainer):
         formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
     ):
         # Args
+        model_id = model if isinstance(model, str) else model.config._name_or_path
         if args is None:
-            model_name = model if isinstance(model, str) else model.config._name_or_path
-            model_name = model_name.split("/")[-1]
+            model_name = model_id.split("/")[-1]
             args = SFTConfig(f"{model_name}-SFT")
         elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
             dict_args = args.to_dict()
|
||||
dict_args.pop("push_to_hub_token")
|
||||
args = SFTConfig(**dict_args)
|
||||
|
||||
# Model
|
||||
if args.model_init_kwargs is not None and not isinstance(model, str):
|
||||
warnings.warn(
|
||||
"You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
|
||||
"The `model_init_kwargs` will be ignored."
|
||||
)
|
||||
if isinstance(model, str):
|
||||
model = self._create_model_from_path(model, args)
|
||||
|
||||
# PEFT configuration and model wrapping
|
||||
if peft_config is not None:
|
||||
model = self._prepare_peft_model(model, peft_config, args)
|
||||
|
||||
# Handle the tokenizer
|
||||
if processing_class is None:
|
||||
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
|
||||
if processing_class.pad_token is None:
|
||||
processing_class.pad_token = processing_class.eos_token # required for padding when collating data
|
||||
|
||||
# Dataset
|
||||
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
|
||||
if preprocess_dataset:
|
||||
train_dataset = self._prepare_dataset(
|
||||
train_dataset, processing_class, args, args.packing, formatting_func, "train"
|
||||
)
|
||||
if eval_dataset is not None:
|
||||
packing = args.packing if args.eval_packing is None else args.eval_packing
|
||||
if isinstance(eval_dataset, dict):
|
||||
eval_dataset = {
|
||||
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
|
||||
for key, dataset in eval_dataset.items()
|
||||
}
|
||||
else:
|
||||
eval_dataset = self._prepare_dataset(
|
||||
eval_dataset, processing_class, args, packing, formatting_func, "eval"
|
||||
)
|
||||
processing_class = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Data collator
|
||||
if args.padding_free:
|
||||
@@ -233,7 +255,48 @@ class SFTTrainer(Trainer):
             data_collator = DataCollatorWithFlattening()

         if data_collator is None:
-            data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
+            # Get the pad token: if not provided, use the one from the processing class or the eos token
+            # if the processing class does not have a pad token.
+            pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
+            pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
+            if pad_token_id is None:
+                raise ValueError(
+                    f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
+                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
+                    "in the vocabulary before using it as a padding token."
+                )
+            data_collator = DataCollatorForLanguageModeling(pad_token_id)
+
+        # Model
+        if args.model_init_kwargs is not None and not isinstance(model, str):
+            warnings.warn(
+                "You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
+                "The `model_init_kwargs` will be ignored."
+            )
+        if isinstance(model, str):
+            model = self._create_model_from_path(model, args)
+
+        # PEFT configuration and model wrapping
+        if peft_config is not None:
+            model = self._prepare_peft_model(model, peft_config, args)
+
+        # Dataset
+        preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
+        if preprocess_dataset:
+            train_dataset = self._prepare_dataset(
+                train_dataset, processing_class, args, args.packing, formatting_func, "train"
+            )
+            if eval_dataset is not None:
+                packing = args.packing if args.eval_packing is None else args.eval_packing
+                if isinstance(eval_dataset, dict):
+                    eval_dataset = {
+                        key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
+                        for key, dataset in eval_dataset.items()
+                    }
+                else:
+                    eval_dataset = self._prepare_dataset(
+                        eval_dataset, processing_class, args, packing, formatting_func, "eval"
+                    )

         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
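Note: taken together, these changes mean a tokenizer without a pad token no longer needs to be patched by the caller; the trainer resolves `args.pad_token`, then `processing_class.pad_token`, then `processing_class.eos_token`. A minimal end-to-end sketch (the tiny test model and the toy dataset are placeholders):

```python
# Minimal usage sketch: no manual `tokenizer.pad_token = tokenizer.eos_token` needed.
# The model id is the tiny test checkpoint used in TRL's own tests; the dataset is a toy example.
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

train_dataset = Dataset.from_dict({"text": ["Hello world.", "TRL resolves the pad token now."]})

args = SFTConfig(output_dir="sft-demo", max_steps=1, report_to="none")
trainer = SFTTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    args=args,
    train_dataset=train_dataset,
)
trainer.train()
```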