Compare commits

...

4 Commits

8 changed files with 150 additions and 54 deletions

View File

@@ -69,7 +69,7 @@ To create the package for PyPI.
from setuptools import find_packages, setup
__version__ = "0.16.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.16.1" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"accelerate>=0.34.0",

View File

@@ -105,7 +105,6 @@ class SFTTrainerSlowTester(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
trainer = SFTTrainer(
model,
@@ -140,7 +139,6 @@ class SFTTrainerSlowTester(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
trainer = SFTTrainer(
model,
@@ -177,7 +175,6 @@ class SFTTrainerSlowTester(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
trainer = SFTTrainer(
model,
@@ -213,7 +210,6 @@ class SFTTrainerSlowTester(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
trainer = SFTTrainer(
model,
@@ -250,7 +246,6 @@ class SFTTrainerSlowTester(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
trainer = SFTTrainer(
model,
@@ -294,7 +289,6 @@ class SFTTrainerSlowTester(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
trainer = SFTTrainer(
model,
@@ -334,7 +328,6 @@ class SFTTrainerSlowTester(unittest.TestCase):
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
trainer = SFTTrainer(
model,
@@ -380,7 +373,6 @@ class SFTTrainerSlowTester(unittest.TestCase):
if tokenizer.chat_template is None:
model, tokenizer = setup_chat_format(model, tokenizer)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
trainer = SFTTrainer(
model,

View File

@@ -33,6 +33,7 @@ from transformers.utils import is_peft_available
from trl import SFTConfig, SFTTrainer
from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling
def formatting_prompts_func(example):
@@ -59,6 +60,34 @@ if is_vision_available():
from PIL import Image as PILImage
class TestDataCollatorForLanguageModeling(unittest.TestCase):
def test_collate_padding(self):
collator = DataCollatorForLanguageModeling(pad_token_id=0)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
output = collator(examples)
expected_input_ids = torch.tensor([[1, 2, 3], [4, 5, 0]])
expected_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
expected_labels = torch.tensor([[1, 2, 3], [4, 5, -100]])
self.assertEqual(output["input_ids"].tolist(), expected_input_ids.tolist())
self.assertEqual(output["attention_mask"].tolist(), expected_attention_mask.tolist())
self.assertEqual(output["labels"].tolist(), expected_labels.tolist())
def test_collate_no_padding(self):
collator = DataCollatorForLanguageModeling(pad_token_id=0)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5, 6]}]
output = collator(examples)
expected_input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
expected_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]])
expected_labels = torch.tensor([[1, 2, 3], [4, 5, 6]])
self.assertEqual(output["input_ids"].tolist(), expected_input_ids.tolist())
self.assertEqual(output["attention_mask"].tolist(), expected_attention_mask.tolist())
self.assertEqual(output["labels"].tolist(), expected_labels.tolist())
class SFTTrainerTester(unittest.TestCase):
r""" """
@@ -66,7 +95,6 @@ class SFTTrainerTester(unittest.TestCase):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.dummy_dataset = Dataset.from_dict(
{
"question": [

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.16.0"
__version__ = "0.16.1"
from typing import TYPE_CHECKING

View File

@@ -98,8 +98,6 @@ def main(script_args, training_args, model_args):
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
################
# Dataset

View File

@@ -695,7 +695,9 @@ class OnlineDPOTrainer(Trainer):
# Same as Trainer._maybe_log_save_evaluate but log our metrics
# start_time defaults to None to allow compatibility with transformers<=4.46
def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None):
def _maybe_log_save_evaluate(
self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time=None, learning_rate=None
):
if self.control.should_log and self.state.global_step > self._globalstep_last_logged:
logs: dict[str, float] = {}
@@ -708,7 +710,10 @@ class OnlineDPOTrainer(Trainer):
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
if grad_norm is not None:
logs["grad_norm"] = grad_norm.detach().item() if isinstance(grad_norm, torch.Tensor) else grad_norm
logs["learning_rate"] = self._get_learning_rate()
if learning_rate is not None:
logs["learning_rate"] = learning_rate
else:
logs["learning_rate"] = self._get_learning_rate()
# Add our metrics
for key, val in self.stats.items():

View File

@@ -47,6 +47,9 @@ class SFTConfig(TrainingArguments):
`skip_prepare_dataset`.
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
pad_token (`str` or `None`, *optional*, defaults to `None`):
Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
it falls back to `processing_class.eos_token`.
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
@@ -92,6 +95,13 @@
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
pad_token: Optional[str] = field(
default=None,
metadata={
"help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
"is also `None`, it falls back to `processing_class.eos_token`."
},
)
max_length: Optional[int] = field(
default=1024,
metadata={
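The `pad_token` option documented in the hunk above can be set directly on `SFTConfig`. A minimal usage sketch, assuming only what this compare introduces; the output directory and token string below are illustrative:

```python
from trl import SFTConfig

# Minimal sketch of the new option (illustrative values only).
# If pad_token is left as None, the trainer falls back to the tokenizer's
# pad_token and, failing that, to its eos_token.
config = SFTConfig(
    "my-model-SFT",              # output_dir, as in SFTConfig(f"{model_name}-SFT")
    pad_token="<|endoftext|>",   # must exist in the tokenizer's vocabulary
    max_length=1024,
)
```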

View File

@@ -16,6 +16,7 @@ import dataclasses
import os
import warnings
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Optional, Type, Union
import torch
@@ -29,7 +30,6 @@ from transformers import (
AutoTokenizer,
BaseImageProcessor,
DataCollator,
DataCollatorForLanguageModeling,
DataCollatorWithFlattening,
FeatureExtractionMixin,
PreTrainedModel,
@@ -39,6 +39,7 @@ from transformers import (
TrainingArguments,
is_wandb_available,
)
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
@@ -51,7 +52,13 @@ from ..data_utils import (
truncate_dataset,
)
from .sft_config import SFTConfig
from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16
from .utils import (
ConstantLengthDataset,
generate_model_card,
get_comet_experiment_url,
pad,
peft_module_casting_to_bf16,
)
if is_peft_available():
@@ -62,6 +69,54 @@ if is_wandb_available():
import wandb
@dataclass
class DataCollatorForLanguageModeling(DataCollatorMixin):
"""
Data collator used for language modeling data. Inputs are dynamically padded to the maximum length of a batch if
they are not all of the same length.
Args:
pad_token_id (`int`):
Token ID to use for padding.
return_tensors (`str`, *optional*, defaults to `"pt"`):
Type of Tensor to return. Only `"pt"` is currently supported.
Examples:
```python
>>> from trl import DataCollatorForLanguageModeling
>>> collator = DataCollatorForLanguageModeling(pad_token_id=0)
>>> examples = [
... {"input_ids": [1, 2, 3]},
... {"input_ids": [4, 5]}
... ]
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3],
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'labels': tensor([[ 1, 2, 3],
[ 4, 5, -100]])}
```
"""
pad_token_id: int
return_tensors: str = "pt"
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
# Convert to tensor
input_ids = [torch.tensor(example["input_ids"]) for example in examples]
attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
labels = [torch.tensor(example["input_ids"]) for example in examples]
# Pad
output = {}
output["input_ids"] = pad(input_ids, padding_value=self.pad_token_id, padding_side="right")
output["attention_mask"] = pad(attention_mask, padding_value=0, padding_side="right")
output["labels"] = pad(labels, padding_value=-100, padding_side="right")
return output
class SFTTrainer(Trainer):
"""
Trainer for Supervised Fine-Tuning (SFT) method.
@@ -159,9 +214,9 @@ class SFTTrainer(Trainer):
formatting_func: Optional[Union[Callable[[dict], str], Callable[[dict], list[str]]]] = None,
):
# Args
model_id = model if isinstance(model, str) else model.config._name_or_path
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
model_name = model_id.split("/")[-1]
args = SFTConfig(f"{model_name}-SFT")
elif isinstance(args, TrainingArguments) and not isinstance(args, SFTConfig):
dict_args = args.to_dict()
@@ -169,42 +224,9 @@
dict_args.pop("push_to_hub_token")
args = SFTConfig(**dict_args)
# Model
if args.model_init_kwargs is not None and not isinstance(model, str):
warnings.warn(
"You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
"The `model_init_kwargs` will be ignored."
)
if isinstance(model, str):
model = self._create_model_from_path(model, args)
# PEFT configuration and model wrapping
if peft_config is not None:
model = self._prepare_peft_model(model, peft_config, args)
# Handle the tokenizer
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
if processing_class.pad_token is None:
processing_class.pad_token = processing_class.eos_token # required for padding when collating data
# Dataset
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
if preprocess_dataset:
train_dataset = self._prepare_dataset(
train_dataset, processing_class, args, args.packing, formatting_func, "train"
)
if eval_dataset is not None:
packing = args.packing if args.eval_packing is None else args.eval_packing
if isinstance(eval_dataset, dict):
eval_dataset = {
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
for key, dataset in eval_dataset.items()
}
else:
eval_dataset = self._prepare_dataset(
eval_dataset, processing_class, args, packing, formatting_func, "eval"
)
processing_class = AutoTokenizer.from_pretrained(model_id)
# Data collator
if args.padding_free:
@@ -233,7 +255,48 @@
data_collator = DataCollatorWithFlattening()
if data_collator is None:
data_collator = DataCollatorForLanguageModeling(tokenizer=processing_class, mlm=False)
# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
if pad_token_id is None:
raise ValueError(
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)
data_collator = DataCollatorForLanguageModeling(pad_token_id)
# Model
if args.model_init_kwargs is not None and not isinstance(model, str):
warnings.warn(
"You passed model_init_kwargs to the `SFTConfig`, but your model is already instantiated. "
"The `model_init_kwargs` will be ignored."
)
if isinstance(model, str):
model = self._create_model_from_path(model, args)
# PEFT configuration and model wrapping
if peft_config is not None:
model = self._prepare_peft_model(model, peft_config, args)
# Dataset
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
if preprocess_dataset:
train_dataset = self._prepare_dataset(
train_dataset, processing_class, args, args.packing, formatting_func, "train"
)
if eval_dataset is not None:
packing = args.packing if args.eval_packing is None else args.eval_packing
if isinstance(eval_dataset, dict):
eval_dataset = {
key: self._prepare_dataset(dataset, processing_class, args, packing, formatting_func, key)
for key, dataset in eval_dataset.items()
}
else:
eval_dataset = self._prepare_dataset(
eval_dataset, processing_class, args, packing, formatting_func, "eval"
)
# Initialize the metrics
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
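For reference, a minimal end-to-end sketch of the new collator path, using the tiny test model and the `trl.trainer.sft_trainer` import that appear earlier in this compare; the pad-token resolution mirrors the `args.pad_token or processing_class.pad_token or processing_class.eos_token` logic above:

```python
from transformers import AutoTokenizer
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling

# Sketch only: resolve the pad token the same way SFTTrainer now does,
# then build the dataclass-based collator introduced in this compare.
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
pad_token = tokenizer.pad_token or tokenizer.eos_token
pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)

collator = DataCollatorForLanguageModeling(pad_token_id=pad_token_id)
batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
# input_ids are right-padded with pad_token_id, attention_mask with 0, labels with -100.
```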