Compare commits

...

21 Commits

Author SHA1 Message Date
a6263a5041 Merge branch 'main' into py3.14 2025-10-07 09:02:52 -06:00
a5ca7d4ba7 style 2025-10-06 20:18:41 +00:00
cfcec4af86 some missing 2025-10-06 20:10:00 +00:00
d66ea247dc revert video change 2025-10-06 19:55:52 +00:00
c97bb24098 rm whitespace 2025-10-06 19:54:25 +00:00
88eee87e11 rm prerelease 2025-10-06 19:53:37 +00:00
a33e642f16 style 2025-10-06 19:52:23 +00:00
cbb41f7366 Merge branch 'main' into py3.14 2025-10-06 13:51:39 -06:00
68959ad9ea Squashed commit of the following:
commit 65eb45c32bc0d8b98555d8fb713c9b39041518dc
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Mon Oct 6 13:07:18 2025 -0600

    Apply style and revert change in `sft_video_llm` example (#4214)

commit ae6837f8d4a84fac8da24924db3aa607828db0c0
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Mon Oct 6 18:40:18 2025 +0200

    Removed tokenizer/processor creation from example scripts (#4211)

commit 56a8f1128bce5ff7cc2ecb76e782199fe889ea82
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Oct 6 17:45:44 2025 +0200

    Replace setup with pyproject and fix packaging unintended modules (#4194)

commit 529101537feafa84d7b99acde715badc356b7172
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Mon Oct 6 16:04:06 2025 +0200

    Remove `Optional` from `processing_class` in `PPOTrainer` (#4212)

commit 0588b1f01db16a5e4712e9abbcedbe47b0b3d27a
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Mon Oct 6 15:57:17 2025 +0200

    Updated vLLM integration guide (#4162)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit 45ee98b05e979c817ee06c81122a82c42a352f87
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Oct 6 11:14:54 2025 +0200

    Replace unittest with pytest (#4188)

commit 3800a6ecc740b10cab3c1cf337719f27f1e5422c
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Oct 6 11:13:21 2025 +0200

    Hotfix: Exclude transformers 4.57.0 for Python 3.9 (#4209)

    Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>

commit 7ad9ce8accbab096a3f0910672649ed1e706dfa0
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Mon Oct 6 11:04:20 2025 +0200

    Remove tokenizer creation from `sft` example script (#4197)

commit 0c2dc14014d1036a82ecd223cac2c82883846768
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Oct 6 08:31:58 2025 +0200

    Remove custome_container for building the docs (#4198)

commit ced8b337ba51a3fe3dadace208eb7a1a7a48ed29
Author: burtenshaw <ben.burtenshaw@gmail.com>
Date:   Mon Oct 6 08:23:11 2025 +0200

    [DOCS/FIX] lora without regrets - fix lr (#4207)
2025-10-06 19:50:12 +00:00
b07df79a92 Squashed commit of the following:
commit ae6837f8d4a84fac8da24924db3aa607828db0c0
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Mon Oct 6 18:40:18 2025 +0200

    Removed tokenizer/processor creation from example scripts (#4211)

commit 56a8f1128bce5ff7cc2ecb76e782199fe889ea82
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Oct 6 17:45:44 2025 +0200

    Replace setup with pyproject and fix packaging unintended modules (#4194)

commit 529101537feafa84d7b99acde715badc356b7172
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Mon Oct 6 16:04:06 2025 +0200

    Remove `Optional` from `processing_class` in `PPOTrainer` (#4212)

commit 0588b1f01db16a5e4712e9abbcedbe47b0b3d27a
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Mon Oct 6 15:57:17 2025 +0200

    Updated vLLM integration guide (#4162)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit 45ee98b05e979c817ee06c81122a82c42a352f87
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Oct 6 11:14:54 2025 +0200

    Replace unittest with pytest (#4188)

commit 3800a6ecc740b10cab3c1cf337719f27f1e5422c
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Oct 6 11:13:21 2025 +0200

    Hotfix: Exclude transformers 4.57.0 for Python 3.9 (#4209)

    Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>

commit 7ad9ce8accbab096a3f0910672649ed1e706dfa0
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Mon Oct 6 11:04:20 2025 +0200

    Remove tokenizer creation from `sft` example script (#4197)

commit 0c2dc14014d1036a82ecd223cac2c82883846768
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Oct 6 08:31:58 2025 +0200

    Remove custome_container for building the docs (#4198)

commit ced8b337ba51a3fe3dadace208eb7a1a7a48ed29
Author: burtenshaw <ben.burtenshaw@gmail.com>
Date:   Mon Oct 6 08:23:11 2025 +0200

    [DOCS/FIX] lora without regrets - fix lr (#4207)
2025-10-06 19:45:43 +00:00
b691f39bef fix tools type hint 2025-10-05 19:26:45 +00:00
b41bcbdeb7 no py314 2025-10-05 19:23:58 +00:00
dd56aaad40 revert unwanted change 2025-10-05 18:26:48 +00:00
0ccfe5df9b strict=True 2025-10-05 18:14:57 +00:00
5004c95c12 apply precommit 2025-10-05 18:11:40 +00:00
d7fe889a3f apply precommit 2025-10-05 18:11:14 +00:00
4e239a6122 Merge branch 'main' into py3.14 2025-10-05 12:09:22 -06:00
b991fd4a87 target python version 3.10 for ruff 2025-10-05 18:07:03 +00:00
73107966ed allow prerelease 2025-10-05 17:53:26 +00:00
f69c919b98 style 2025-09-30 15:17:58 +00:00
0d54019980 Drop Python 3.9, add Python 3.14 2025-09-30 15:13:31 +00:00
107 changed files with 977 additions and 997 deletions
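
The recurring change across these files is the move from `typing.Optional` / `typing.Union` annotations to the PEP 604 union syntax, which only evaluates at runtime on Python 3.10+, hence the raised minimum version elsewhere in this diff. A minimal before/after sketch with illustrative names (not taken from the repo):

```python
# Before: Python 3.9-compatible annotations
from typing import Optional, Union

def load_config(path: Optional[str] = None, retries: Union[int, float] = 3) -> Optional[dict]:
    ...

# After: PEP 604 unions, the style this branch converges on (3.10+ at runtime)
def load_config(path: str | None = None, retries: int | float = 3) -> dict | None:
    ...
```

On 3.9 the second form raises `TypeError` when the annotation is evaluated at definition time (unless `from __future__ import annotations` is in effect), which is why the `requires-python` bump shown later accompanies the rewrite.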

View File

@ -36,7 +36,7 @@ jobs:
name: Tests
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
python-version: ['3.10', '3.11', '3.12', '3.13']
fail-fast: false
runs-on:
group: aws-g4dn-2xlarge

View File

@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.10
rev: v0.13.3
hooks:
- id: ruff-check
types_or: [ python, pyi ]

View File

@ -315,24 +315,6 @@ def replicate_str(string: str, n: int, sep: str = " ") -> str:
* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate")
* **Type Annotations:**
* Always include type definitions, indicating if a parameter is optional and specifying the default value.
* Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
E.g., for arguments that can't be `None` and aren't required:
```python
foo (`int`, *optional*, defaults to `4`):
```
For arguments that can be `None` and are required:
```python
foo (`Optional[int]`):
```
for arguments that can be `None` and aren't required:
```python
foo (`Optional[int]`, *optional*):
```
* **String Defaults:**
* Ensured that default string values are wrapped in double quotes:

View File

@ -143,7 +143,7 @@ For reinforcement learning, the blog uses a math reasoning task that we can repr
```python
def strip_reasoning_accuracy_reward(
completions: list[list[dict[str, str]]], solution: list[str], **kwargs
) -> list[Optional[float]]:
) -> list[float | None]:
"""Reward function that strips reasoning tags and checks mathematical accuracy.
This function:

View File

@ -14,7 +14,6 @@
import re
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -42,7 +41,7 @@ class ScriptArguments:
repo_id: str = field(
default="trl-lib/hh-rlhf-helpful-base", metadata={"help": "Hugging Face repository ID to push the dataset to."}
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None, metadata={"help": "Number of workers to use for dataset processing."}
)
@ -50,7 +49,7 @@ class ScriptArguments:
def common_start(str1: str, str2: str) -> str:
# Zip the two strings and iterate over them together
common_chars = []
for c1, c2 in zip(str1, str2):
for c1, c2 in zip(str1, str2, strict=True):
if c1 == c2:
common_chars.append(c1)
else:
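
`zip(..., strict=True)` is likewise a Python 3.10+ feature (PEP 618): it raises `ValueError` instead of silently truncating when the iterables have unequal lengths. The check fires only when iteration actually exhausts the shorter input, so in a prefix helper like `common_start` above it matters only when one string is a prefix of the other and the loop never breaks early. An illustrative sketch:

```python
pairs = list(zip("abc", "abX", strict=True))  # equal lengths -> [('a', 'a'), ('b', 'b'), ('c', 'X')]

try:
    list(zip("abc", "abcd", strict=True))     # unequal lengths now fail loudly
except ValueError as e:
    print(e)                                  # e.g. "zip() argument 2 is longer than argument 1"
```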

View File

@ -14,7 +14,6 @@
import ast
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -43,7 +42,7 @@ class ScriptArguments:
default="trl-lib/llava-instruct-mix",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -42,7 +41,7 @@ class ScriptArguments:
default="trl-lib/lm-human-preferences-descriptiveness",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -42,7 +41,7 @@ class ScriptArguments:
default="trl-lib/lm-human-preferences-sentiment",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)

View File

@ -15,7 +15,6 @@
import re
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -44,7 +43,7 @@ class ScriptArguments:
default="trl-lib/math_shepherd",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)
@ -64,7 +63,7 @@ def process_example(example):
labels = [example["label"][idx] == "+" for idx in indexes]
# Split the inputs into steps (caution, the first step is missing here, it is the prompt)
steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))]
steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]), strict=True)]
# Remove the last step (single ⶻ)
steps = steps[:-1]
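
The `chain`/`zip` pairing above turns marker positions into consecutive `(start, stop)` slice boundaries; both chained iterables have the same length, so `strict=True` is safe here. A small worked example with hypothetical values (`#` standing in for the dataset's step marker):

```python
from itertools import chain

inputs = "2+2=4 # 4*3=12 #"   # hypothetical solution text, "#" ends each step
indexes = [7, 16]             # hypothetical positions just past each marker

# chain([0], indexes) -> 0, 7, 16   and   chain(indexes, [None]) -> 7, 16, None
steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]), strict=True)]
print(steps)        # ['2+2=4 #', ' 4*3=12 #', '']
steps = steps[:-1]  # drop the trailing slice, as the script does
```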

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -42,7 +41,7 @@ class ScriptArguments:
default="trl-lib/prm800k",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from datasets import features, load_dataset
from huggingface_hub import ModelCard
@ -42,7 +41,7 @@ class ScriptArguments:
default="trl-lib/rlaif-v",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -42,7 +41,7 @@ class ScriptArguments:
default="trl-lib/tldr",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -42,7 +41,7 @@ class ScriptArguments:
default="trl-lib/tldr-preference",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -42,7 +41,7 @@ class ScriptArguments:
default="trl-lib/ultrafeedback-prompt",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from huggingface_hub import ModelCard
@ -79,7 +78,7 @@ class ScriptArguments:
default="trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness",
metadata={"help": "Hugging Face repository ID to push the dataset to."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of workers to use for dataset processing."},
)

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
import torch
from peft import PeftConfig, PeftModel
@ -27,9 +26,9 @@ class ScriptArguments:
merged model.
"""
adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})
adapter_model_name: str | None = field(default=None, metadata={"help": "the adapter name"})
base_model_name: str | None = field(default=None, metadata={"help": "the base model name"})
output_name: str | None = field(default=None, metadata={"help": "the merged model name"})
parser = HfArgumentParser(ScriptArguments)

View File

@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional, Union
from typing import Any
import evaluate
import numpy as np
@ -41,70 +41,70 @@ class ScriptArguments:
These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
"""
local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
resume_from_checkpoint: Optional[bool] = field(
local_rank: int | None = field(default=-1, metadata={"help": "Used for multi-gpu"})
resume_from_checkpoint: bool | None = field(
default=False,
metadata={"help": "If you want to resume training where it left off."},
)
deepspeed: Optional[str] = field(
deepspeed: str | None = field(
default=None,
metadata={
"help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."
},
)
per_device_train_batch_size: Optional[int] = field(default=4)
per_device_eval_batch_size: Optional[int] = field(default=1)
gradient_accumulation_steps: Optional[int] = field(default=1)
learning_rate: Optional[float] = field(default=2e-5)
weight_decay: Optional[float] = field(default=0.001)
model_name: Optional[str] = field(
per_device_train_batch_size: int | None = field(default=4)
per_device_eval_batch_size: int | None = field(default=1)
gradient_accumulation_steps: int | None = field(default=1)
learning_rate: float | None = field(default=2e-5)
weight_decay: float | None = field(default=0.001)
model_name: str | None = field(
default="gpt2",
metadata={
"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
},
)
tokenizer_name: Optional[str] = field(
tokenizer_name: str | None = field(
default=None,
metadata={
"help": "The tokenizer for your model, if left empty will use the default for your model",
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=True,
metadata={
"help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
},
)
num_train_epochs: Optional[int] = field(
num_train_epochs: int | None = field(
default=1,
metadata={"help": "The number of training epochs for the reward model."},
)
train_subset: Optional[int] = field(
train_subset: int | None = field(
default=100000,
metadata={"help": "The size of the subset of the training data to use"},
)
eval_subset: Optional[int] = field(
eval_subset: int | None = field(
default=50000,
metadata={"help": "The size of the subset of the eval data to use"},
)
gradient_checkpointing: Optional[bool] = field(
gradient_checkpointing: bool | None = field(
default=False,
metadata={"help": "Enables gradient checkpointing."},
)
optim: Optional[str] = field(
optim: str | None = field(
default="adamw_hf",
metadata={"help": "The optimizer to use."},
)
lr_scheduler_type: Optional[str] = field(
lr_scheduler_type: str | None = field(
default="linear",
metadata={"help": "The lr scheduler"},
)
max_length: Optional[int] = field(default=512)
eval_first_step: Optional[bool] = field(
max_length: int | None = field(default=512)
eval_first_step: bool | None = field(
default=False,
metadata={"help": "Whether to run eval after the first step"},
)
seed: Optional[int] = field(
seed: int | None = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)
@ -189,7 +189,9 @@ def preprocess_function(examples):
"input_ids_k": [],
"attention_mask_k": [],
}
for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]):
for question, response_j, response_k in zip(
examples["question"], examples["response_j"], examples["response_k"], strict=True
):
tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True)
tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True)
@ -229,8 +231,8 @@ eval_dataset = eval_dataset.filter(
@dataclass
class RewardDataCollatorWithPadding:
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
pad_to_multiple_of: Optional[int] = None
padding: bool | str | PaddingStrategy = True
pad_to_multiple_of: int | None = None
return_tensors: str = "pt"
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
import torch
from accelerate import Accelerator
@ -37,37 +36,37 @@ class ScriptArguments:
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
# models like gpt-neo* models are more suitable.
model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"})
reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
output_max_length: Optional[int] = field(default=128, metadata={"help": "maximum length for generation"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"})
gradient_accumulation_steps: Optional[int] = field(
model_name: str | None = field(default="", metadata={"help": "the model name"})
tokenizer_name: str | None = field(default="", metadata={"help": "the tokenizer name"})
reward_model_name: str | None = field(default="", metadata={"help": "the reward model name"})
log_with: str | None = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: float | None = field(default=1.41e-5, metadata={"help": "the learning rate"})
output_max_length: int | None = field(default=128, metadata={"help": "maximum length for generation"})
mini_batch_size: int | None = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: int | None = field(default=32, metadata={"help": "the batch size"})
ppo_epochs: int | None = field(default=4, metadata={"help": "the number of ppo epochs"})
gradient_accumulation_steps: int | None = field(
default=4, metadata={"help": "the number of gradient accumulation steps"}
)
adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
reward_baseline: Optional[float] = field(
adafactor: bool | None = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
early_stopping: bool | None = field(default=False, metadata={"help": "whether to early stop"})
target_kl: float | None = field(default=0.1, metadata={"help": "kl target for early stopping"})
reward_baseline: float | None = field(
default=0.0,
metadata={"help": "a baseline value that is subtracted from the reward"},
)
batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"})
save_freq: Optional[int] = field(default=None, metadata={"help": "n steps to save the model"})
output_dir: Optional[str] = field(default="runs/", metadata={"help": "n steps to save the model"})
seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
steps: Optional[int] = field(default=20000, metadata={"help": "number of epochs"})
init_kl_coef: Optional[float] = field(
batched_gen: bool | None = field(default=False, metadata={"help": "whether to use the batched text gen"})
save_freq: int | None = field(default=None, metadata={"help": "n steps to save the model"})
output_dir: str | None = field(default="runs/", metadata={"help": "n steps to save the model"})
seed: int | None = field(default=0, metadata={"help": "the seed"})
steps: int | None = field(default=20000, metadata={"help": "number of epochs"})
init_kl_coef: float | None = field(
default=0.2,
metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
)
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
adap_kl_ctrl: bool | None = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
load_in_8bit: bool | None = field(default=True, metadata={"help": "whether to load the model in 8bit"})
parser = HfArgumentParser(ScriptArguments)
@ -258,7 +257,7 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
# Compute reward score (using the sentiment analysis pipeline)
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
texts = [q + r for q, r in zip(batch["query"], batch["response"], strict=True)]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]

View File

@ -70,7 +70,7 @@ def chars_token_ratio(dataset, tokenizer, nb_examples=400):
Estimate the average number of characters per token in the dataset.
"""
total_characters, total_tokens = 0, 0
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
for _, example in tqdm(zip(range(nb_examples), iter(dataset), strict=True), total=nb_examples):
text = prepare_sample_text(example)
total_characters += len(text)
if tokenizer.is_fast:

View File

@ -15,7 +15,6 @@
# 0. imports
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from accelerate import Accelerator
@ -34,52 +33,52 @@ class ScriptArguments:
"""
# data parameters
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
beta: float | None = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
# training parameters
model_name_or_path: Optional[str] = field(
model_name_or_path: str | None = field(
default="../sft/results/final_checkpoint",
metadata={"help": "the location of the SFT model name or path"},
)
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
learning_rate: float | None = field(default=5e-4, metadata={"help": "optimizer learning rate"})
lr_scheduler_type: str | None = field(default="cosine", metadata={"help": "the lr scheduler type"})
warmup_steps: int | None = field(default=100, metadata={"help": "the number of warmup steps"})
weight_decay: float | None = field(default=0.05, metadata={"help": "the weight decay"})
optimizer_type: str | None = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
gradient_accumulation_steps: Optional[int] = field(
per_device_train_batch_size: int | None = field(default=4, metadata={"help": "train batch size per device"})
per_device_eval_batch_size: int | None = field(default=1, metadata={"help": "eval batch size per device"})
gradient_accumulation_steps: int | None = field(
default=4, metadata={"help": "the number of gradient accumulation steps"}
)
gradient_checkpointing: Optional[bool] = field(
gradient_checkpointing: bool | None = field(
default=True, metadata={"help": "whether to use gradient checkpointing"}
)
gradient_checkpointing_use_reentrant: Optional[bool] = field(
gradient_checkpointing_use_reentrant: bool | None = field(
default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
)
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
lora_alpha: float | None = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: float | None = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: int | None = field(default=8, metadata={"help": "the lora r parameter"})
max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
max_prompt_length: int | None = field(default=512, metadata={"help": "the maximum prompt length"})
max_length: int | None = field(default=1024, metadata={"help": "the maximum sequence length"})
max_steps: int | None = field(default=1000, metadata={"help": "max number of training steps"})
logging_steps: int | None = field(default=10, metadata={"help": "the logging frequency"})
save_steps: int | None = field(default=100, metadata={"help": "the saving frequency"})
eval_steps: int | None = field(default=100, metadata={"help": "the evaluation frequency"})
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
model_dtype: Optional[str] = field(
output_dir: str | None = field(default="./results", metadata={"help": "the output directory"})
log_freq: int | None = field(default=1, metadata={"help": "the logging frequency"})
load_in_4bit: bool | None = field(default=True, metadata={"help": "whether to load the model in 4bit"})
model_dtype: str | None = field(
default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
)
# instrumentation
report_to: Optional[str] = field(
report_to: str | None = field(
default="wandb",
metadata={
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
@ -88,21 +87,21 @@ class ScriptArguments:
},
)
# debug argument for distributed training
ignore_bias_buffers: Optional[bool] = field(
ignore_bias_buffers: bool | None = field(
default=False,
metadata={
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
seed: Optional[int] = field(
seed: int | None = field(
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
)
def get_stack_exchange_paired(
data_dir: str = "data/rl",
cache_dir: Optional[str] = None,
cache_dir: str | None = None,
num_proc=24,
) -> Dataset:
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.

View File

@ -15,7 +15,6 @@
# Fine-Tune Llama2-7b on SE paired dataset
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from accelerate import Accelerator
@ -38,21 +37,21 @@ from trl.trainer import ConstantLengthDataset
@dataclass
class ScriptArguments:
model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
model_name: str | None = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
dataset_name: str | None = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
subset: str | None = field(default="data/finetune", metadata={"help": "the subset to use"})
split: str | None = field(default="train", metadata={"help": "the split to use"})
size_valid_set: int | None = field(default=4000, metadata={"help": "the size of the validation set"})
streaming: bool | None = field(default=True, metadata={"help": "whether to stream the dataset"})
shuffle_buffer: int | None = field(default=5000, metadata={"help": "the shuffle buffer size"})
seq_length: int | None = field(default=1024, metadata={"help": "the sequence length"})
num_workers: int | None = field(default=4, metadata={"help": "the number of workers"})
use_bnb: bool | None = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
# LoraConfig
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
lora_alpha: float | None = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: float | None = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: int | None = field(default=8, metadata={"help": "the lora r parameter"})
parser = HfArgumentParser((ScriptArguments, SFTConfig))
@ -82,7 +81,7 @@ def chars_token_ratio(dataset, tokenizer, nb_examples=400):
Estimate the average number of characters per token in the dataset.
"""
total_characters, total_tokens = 0, 0
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
for _, example in tqdm(zip(range(nb_examples), iter(dataset), strict=True), total=nb_examples):
text = prepare_sample_text(example)
total_characters += len(text)
if tokenizer.is_fast:

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
@ -65,15 +64,15 @@ class ScriptArguments:
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
# models like gpt-neo* models are more suitable.
model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"})
mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
gradient_accumulation_steps: Optional[int] = field(
model_name: str | None = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"})
log_with: str | None = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: float | None = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"})
mini_batch_size: int | None = field(default=4, metadata={"help": "the PPO minibatch size"})
batch_size: int | None = field(default=16, metadata={"help": "the batch size"})
gradient_accumulation_steps: int | None = field(
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
model_save_path: Optional[str] = field(
model_save_path: str | None = field(
default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final",
metadata={"help": "the path to save the model"},
)

View File

@ -19,7 +19,6 @@
# ///
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser
@ -74,7 +73,7 @@ class ScriptArguments:
"'meta-llama/Meta-Llama-3-70B-Instruct'."
},
)
num_examples: Optional[int] = field(default=None, metadata={"help": "Number of examples to evaluate."})
num_examples: int | None = field(default=None, metadata={"help": "Number of examples to evaluate."})
if __name__ == "__main__":
@ -103,7 +102,7 @@ if __name__ == "__main__":
else:
judge = HfPairwiseJudge(script_args.judge_model)
completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions)]
completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions, strict=True)]
best_idxs = judge.judge(prompts, completions)
model_win_rate = best_idxs.count(1) / len(best_idxs)
print(f"Model win rate: {model_win_rate * 100:.2f}%")

View File

@ -159,7 +159,7 @@ if __name__ == "__main__":
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
for content, sol in zip(contents, solution, strict=True):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:

View File

@ -130,7 +130,7 @@ if __name__ == "__main__":
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
for content, sol in zip(contents, solution, strict=True):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:

View File

@ -146,7 +146,7 @@ if __name__ == "__main__":
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
for content, sol in zip(contents, solution, strict=True):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:

View File

@ -202,7 +202,7 @@ if __name__ == "__main__":
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
for content, sol in zip(contents, solution, strict=True):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:

View File

@ -75,7 +75,7 @@ def main():
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
for content, sol in zip(contents, solution, strict=True):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:

View File

@ -159,7 +159,7 @@ if __name__ == "__main__":
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
for content, sol in zip(contents, solution, strict=True):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:

View File

@ -21,13 +21,12 @@ classifiers = [
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13"
]
requires-python = ">=3.9"
requires-python = ">=3.10"
dependencies = [
"accelerate>=1.4.0",
"datasets>=3.0.0",
@ -125,7 +124,7 @@ version = { file = "VERSION" }
branch = true
[tool.ruff]
target-version = "py39"
target-version = "py310"
line-length = 119
src = ["trl"]
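
Raising `requires-python` to `">=3.10"` is what keeps Python 3.9 environments from resolving to this release at all, and `target-version = "py310"` lets Ruff's pyupgrade-family (`UP`) rules enforce the new union syntax; the exact rule codes are not shown in this diff. A quick check of the new specifier using the third-party `packaging` library (an assumption for illustration; it is not a dependency shown here):

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=3.10")   # the new requires-python value
print("3.9.18" in spec)         # False -> pip on 3.9 falls back to an older release
print("3.13.1" in spec)         # True
```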

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from transformers import AutoModelForCausalLM
@ -103,7 +102,7 @@ class TestActivationOffloading(TrlTestCase):
grads2 = [p.grad.clone() for p in model.parameters()]
# Gradients should match as NoOpManager should have prevented offloading
for g1, g2 in zip(grads1, grads2):
for g1, g2 in zip(grads1, grads2, strict=True):
assert torch.allclose(g1, g2, rtol=1e-4, atol=1e-5)
@require_torch_accelerator
@ -152,5 +151,5 @@ class TestActivationOffloading(TrlTestCase):
# Check outputs and gradients match
assert torch.allclose(out1, out2, rtol=1e-5)
for g1, g2 in zip(grads1, grads2):
for g1, g2 in zip(grads1, grads2, strict=True):
assert torch.allclose(g1, g2, rtol=1e-5)

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import AutoTokenizer, GenerationConfig

View File

@ -116,7 +116,7 @@ class TestWinRateCallback(TrlTestCase):
trainer.add_callback(win_rate_callback)
trainer.train()
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
for history_row, expected_row in zip(winrate_history, self.expected_winrates):
for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
def test_without_ref_model(self):
@ -142,7 +142,7 @@ class TestWinRateCallback(TrlTestCase):
trainer.add_callback(win_rate_callback)
trainer.train()
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
for history_row, expected_row in zip(winrate_history, self.expected_winrates):
for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
def test_soft_judge(self):
@ -185,7 +185,7 @@ class TestWinRateCallback(TrlTestCase):
for h in trainer.state.log_history
if "eval_avg_win_prob" in h
]
for history_row, expected_row in zip(winrate_history, expected_soft_winrates):
for history_row, expected_row in zip(winrate_history, expected_soft_winrates, strict=True):
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
@require_peft
@ -219,7 +219,7 @@ class TestWinRateCallback(TrlTestCase):
trainer.add_callback(win_rate_callback)
trainer.train()
winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
for history_row, expected_row in zip(winrate_history, self.expected_winrates):
for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)

View File

@ -12,23 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from io import StringIO
from unittest.mock import patch
import pytest
import yaml
from .testing_utils import TrlTestCase
@pytest.mark.skipif(
sys.version_info < (3, 10),
reason="Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests "
"to fail on Python <3.10.", # let's say it's a known issue, but not expected to be fixed, because too niche
)
class TestCLI(TrlTestCase):
def test_dpo(self):
from trl.cli import main

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from trl.trainer.dpo_trainer import DataCollatorForPreference

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from trl.core import masked_mean, masked_var, masked_whiten

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from parameterized import parameterized

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
from collections.abc import Callable
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
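
The same cleanup replaces `typing.Callable` with `collections.abc.Callable`, which has supported subscripting since Python 3.9 (PEP 585) and is the form the `typing` documentation now points to. A minimal sketch:

```python
from collections.abc import Callable

# Works at runtime on 3.9+; typing.Callable remains a deprecated alias of this class.
RewardFn = Callable[[list[str]], list[float]]

def apply_reward(fn: RewardFn, completions: list[str]) -> list[float]:
    return fn(completions)

print(apply_reward(lambda cs: [float(len(c)) for c in cs], ["ab", "abc"]))  # [2.0, 3.0]
```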

View File

@ -13,7 +13,6 @@
# limitations under the License.
import re
import sys
from unittest.mock import MagicMock
import numpy as np
@ -1300,7 +1299,6 @@ class TestDPOTrainer(TrlTestCase):
]
)
@require_liger_kernel
@pytest.mark.skipif(not (sys.version_info >= (3, 10)), reason="Liger kernel is not supported on Python 3.9")
def test_dpo_trainer_with_liger(self, beta, loss_type):
"""Test DPO trainer with Liger loss enabled across supported loss types.

View File

@ -69,7 +69,7 @@ class TestGKDTrainerGenerateOnPolicy(TrlTestCase):
generated_texts = self.tokenizer.batch_decode(new_input_ids, skip_special_tokens=True)
# Check if the generated texts start with the original prompts
for prompt, generated_text in zip(prompts, generated_texts):
for prompt, generated_text in zip(prompts, generated_texts, strict=True):
assert generated_text.startswith(prompt), (
f"Generated text '{generated_text}' does not start with prompt '{prompt}'"
)

View File

@ -602,7 +602,9 @@ class TestGRPOTrainer(TrlTestCase):
def reward_func(completions, some_values, **kwargs):
"""Reward function that rewards completions with lengths closer to the values in some_values."""
return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)]
return [
float(abs(len(completion) - value)) for completion, value in zip(completions, some_values, strict=True)
]
training_args = GRPOConfig(
output_dir=self.tmp_dir,
@ -1915,7 +1917,7 @@ class TestUpdateWithReplayBuffer:
(item[1]["prompt_ids"].tolist(), item[1]["completion_ids"].tolist())
for item in self.trainer.replay_buffer.heap
]
buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids)
buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids, strict=True)
# Check for new entry with seq len 3 in buffer
assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids # excluded no-variance group

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from datasets import load_dataset

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import AutoModelForCausalLM, GenerationConfig

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from parameterized import parameterized

View File

@ -151,7 +151,7 @@ class TestPeftModel(TrlTestCase):
model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir)
# check all the weights are the same
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters(), strict=True):
assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
def test_load_pretrained_peft(self):
@ -175,7 +175,7 @@ class TestPeftModel(TrlTestCase):
)
# check all the weights are the same
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters(), strict=True):
if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]:
assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from trl.rewards import get_soft_overlong_punishment, think_format_reward
from .testing_utils import TrlTestCase

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from datasets import Dataset

View File

@ -491,7 +491,9 @@ class TestRLOOTrainer(TrlTestCase):
def reward_func(completions, some_values, **kwargs):
"""Reward function that rewards completions with lengths closer to the values in some_values."""
return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)]
return [
float(abs(len(completion) - value)) for completion, value in zip(completions, some_values, strict=True)
]
training_args = RLOOConfig(
output_dir=self.tmp_dir,

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

View File

@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import numpy as np
import torch
from accelerate import logging
@ -22,7 +20,7 @@ from accelerate import logging
logger = logging.get_logger(__name__)
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool | None = None) -> torch.Tensor:
"""Compute mean of tensor with a masked values."""
if axis is not None:
return (values * mask).sum(axis=axis) / mask.sum(axis=axis)

View File

@ -13,9 +13,9 @@
# limitations under the License.
from collections import defaultdict, deque
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from itertools import takewhile
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Any, TypeVar
import numpy as np
import pyarrow as pa
@ -119,8 +119,8 @@ def is_conversational(example: dict[str, Any]) -> bool:
def apply_chat_template(
example: dict[str, list[dict[str, str]]],
tokenizer: Union[PreTrainedTokenizerBase, ProcessorMixin],
tools: Optional[list[Union[dict, Callable]]] = None,
tokenizer: PreTrainedTokenizerBase | ProcessorMixin,
tools: list[dict | Callable] | None = None,
**template_kwargs,
) -> dict[str, str]:
r"""
@ -174,7 +174,7 @@ def apply_chat_template(
# DeepSeek-R1 inserts a <tool_call> token when using `add_generation_prompt`, which can cause discrepancies
# between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
# common prefix between the two. In most cases, this is a no-op.
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen)))
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen, strict=False)))
chosen = prompt_chosen[len(prompt) :]
if "rejected" in example and "prompt" in example: # explicit prompt
@ -182,14 +182,18 @@ def apply_chat_template(
example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
)
# Handle DeepSeek-R1 <tool_call> token, see the above comment for details
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
prompt = "".join(
x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected, strict=False))
)
rejected = prompt_rejected[len(prompt) :]
if "completion" in example:
prompt_completion = tokenizer.apply_chat_template(
example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
)
# Handle DeepSeek-R1 <tool_call> token, see the above comment for details
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
prompt = "".join(
x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion, strict=False))
)
completion = prompt_completion[len(prompt) :]
else: # implicit prompt case
if "chosen" in example:
@ -220,7 +224,7 @@ def apply_chat_template(
def maybe_apply_chat_template(
example: dict[str, list[dict[str, str]]],
tokenizer: PreTrainedTokenizerBase,
tools: Optional[list[Union[dict, Callable]]] = None,
tools: list[dict | Callable] | None = None,
**template_kwargs: Any,
) -> dict[str, str]:
r"""
@ -242,7 +246,7 @@ def maybe_apply_chat_template(
messages, where each message is a dictionary with keys `"role"` and `"content"`.
tokenizer (`PreTrainedTokenizerBase`):
Tokenizer to apply the chat template with.
tools (`list[Union[dict, Callable]]`, *optional*):
tools (`list[dict | Callable]`, *optional*):
A list of tools (callable functions) that will be accessible to the model. If the template does not support
function calling, this argument will have no effect.
**template_kwargs (`Any`, *optional*):
@ -291,7 +295,7 @@ def _unpair_row(examples: list[dict[str, list[dict[str, str]]]]) -> list[dict[st
def unpair_preference_dataset(
dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None
dataset: DatasetType, num_proc: int | None = None, desc: str | None = None
) -> DatasetType:
r"""
Unpair a preference dataset.
@ -334,7 +338,7 @@ def unpair_preference_dataset(
def maybe_unpair_preference_dataset(
dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None
dataset: DatasetType, num_proc: int | None = None, desc: str | None = None
) -> DatasetType:
r"""
Unpair a preference dataset if it is paired.
@ -565,7 +569,7 @@ def _pack_bfd(examples: pa.Table, seq_length: int) -> pa.Table:
# Bin is represented as a dict (of example ids and sum of their lengths) to allow in-place updates
bins: list[dict] = []
for length, idx in zip(lengths.field(0).to_numpy(), lengths.field(1).to_numpy()):
for length, idx in zip(lengths.field(0).to_numpy(), lengths.field(1).to_numpy(), strict=True):
space = segment_tree.search(length)
if space < seq_length:
@ -627,7 +631,7 @@ def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:
def pack_dataset(
dataset: DatasetType, seq_length: int, strategy: str = "bfd", map_kwargs: Optional[dict[str, Any]] = None
dataset: DatasetType, seq_length: int, strategy: str = "bfd", map_kwargs: dict[str, Any] | None = None
) -> DatasetType:
r"""
Pack sequences in a dataset into chunks of size `seq_length`.
@ -682,9 +686,7 @@ def pack_dataset(
return dataset
def truncate_dataset(
dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None
) -> DatasetType:
def truncate_dataset(dataset: DatasetType, max_length: int, map_kwargs: dict[str, Any] | None = None) -> DatasetType:
r"""
Truncate sequences in a dataset to a specified `max_length`.

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch

View File

@ -13,14 +13,13 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from ...trainer.grpo_config import GRPOConfig as _GRPOConfig
@dataclass
class GFPOConfig(_GRPOConfig):
num_remains_in_group: Optional[int] = field(
num_remains_in_group: int | None = field(
default=None,
metadata={
"help": "number inputs remains after group filter function, `'num_remains_in_group'` must be >=2 if given."

View File

@ -13,7 +13,8 @@
# limitations under the License.
import logging
from typing import Any, Callable
from collections.abc import Callable
from typing import Any
import torch
from accelerate.utils import gather_object
@ -183,7 +184,7 @@ class GFPOTrainer(_GRPOTrainer):
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
for prompt, completion in zip(prompts, completions_text, strict=True):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:

View File

@ -13,7 +13,7 @@
# limitations under the License.
import heapq
from typing import Any, Optional, Union
from typing import Any
import torch
from accelerate.utils import gather_object
@ -35,7 +35,7 @@ class ReplayBuffer:
self.heap = [] # Min-heap of (score, data) tuples
def add(self, scores: list[float], data: list[dict]):
for score, datum in zip(scores, data):
for score, datum in zip(scores, data, strict=True):
if len(self.heap) < self.max_size:
heapq.heappush(self.heap, (score, datum))
else:
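
The truncated `add` method above builds a bounded min-heap: push while below `max_size`, then evict the lowest score once full so the buffer keeps the top-scoring entries. The `else` branch is not shown in this hunk, so the sketch below is an assumption about the idea, not the trainer's exact code:

```python
import heapq

max_size = 3
heap: list[tuple[float, str]] = []
for score, datum in zip([0.1, 0.9, 0.4, 0.7, 0.2], ["a", "b", "c", "d", "e"], strict=True):
    if len(heap) < max_size:
        heapq.heappush(heap, (score, datum))
    else:
        # Push the new item, then pop whichever (score, datum) is now smallest.
        heapq.heappushpop(heap, (score, datum))

print(sorted(heap))  # [(0.4, 'c'), (0.7, 'd'), (0.9, 'b')]
```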
@ -58,13 +58,13 @@ class ReplayBuffer:
class GRPOWithReplayBufferTrainer(GRPOTrainer):
def __init__(self, args: Optional[GRPOWithReplayBufferConfig] = None, **kwargs):
def __init__(self, args: GRPOWithReplayBufferConfig | None = None, **kwargs):
super().__init__(args=args, **kwargs)
self.replay_buffer = ReplayBuffer(args.replay_buffer_size) if args.replay_buffer_size > 0 else None
def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
) -> dict[str, Union[torch.Tensor, Any]]:
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
device = self.accelerator.device
mode = "train" if self.model.training else "eval"
@ -89,7 +89,9 @@ class GRPOWithReplayBufferTrainer(GRPOTrainer):
# Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
# to re-tokenize completions if the reward is computed from tokens.
completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
completion_ids_list = [
row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool(), strict=True)
]
# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
@ -162,7 +164,7 @@ class GRPOWithReplayBufferTrainer(GRPOTrainer):
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
for prompt, completion in zip(prompts, completions_text, strict=True):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
@ -339,9 +341,9 @@ class GRPOWithReplayBufferTrainer(GRPOTrainer):
completion_mask: torch.Tensor,
forward_kwargs: dict,
optional_vision_fields: list[str] = None,
old_per_token_logps: Optional[torch.Tensor] = None,
ref_per_token_logps: Optional[torch.Tensor] = None,
importance_sampling_ratio: Optional[float] = None,
old_per_token_logps: torch.Tensor | None = None,
ref_per_token_logps: torch.Tensor | None = None,
importance_sampling_ratio: float | None = None,
) -> None:
"""
Update the replay buffer with groups that have reward variance (std > 0).
@ -463,9 +465,9 @@ class GRPOWithReplayBufferTrainer(GRPOTrainer):
completion_mask: torch.Tensor,
forward_kwargs: dict,
num_items_in_batch: int,
old_per_token_logps: Optional[torch.Tensor] = None,
ref_per_token_logps: Optional[torch.Tensor] = None,
importance_sampling_ratio: Optional[float] = None,
old_per_token_logps: torch.Tensor | None = None,
ref_per_token_logps: torch.Tensor | None = None,
importance_sampling_ratio: float | None = None,
) -> None:
"""
Update current batch data with samples from replay buffer.
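
The `ReplayBuffer` above keeps only the best-scoring groups by holding `(score, data)` pairs in a bounded min-heap: push while below `max_size`, otherwise evict the smallest score. A dependency-free sketch of that policy; the eviction rule and the tie-breaking counter are my own reading of the truncated hunk, not necessarily the trainer's exact behavior.

import heapq
from itertools import count

class BoundedReplayBuffer:
    def __init__(self, max_size: int):
        self.max_size = max_size
        self.heap: list[tuple[float, int, dict]] = []  # min-heap keyed on score
        self._tie = count()  # breaks score ties so dicts are never compared

    def add(self, scores: list[float], data: list[dict]) -> None:
        for score, datum in zip(scores, data, strict=True):
            item = (score, next(self._tie), datum)
            if len(self.heap) < self.max_size:
                heapq.heappush(self.heap, item)
            elif score > self.heap[0][0]:
                heapq.heapreplace(self.heap, item)  # drop the lowest-scoring group

buffer = BoundedReplayBuffer(max_size=2)
buffer.add([0.1, 0.9, 0.5], [{"id": 0}, {"id": 1}, {"id": 2}])
print(sorted(score for score, _, _ in buffer.heap))  # [0.5, 0.9]
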

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from trl import GRPOTrainer as _GRPOTrainer

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any
import torch
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed
@ -47,13 +48,13 @@ class BestOfNSampler:
def __init__(
self,
model: PreTrainedModelWrapper,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
queries_to_scores: Callable[[list[str]], list[float]],
length_sampler: Any,
sample_size: int = 4,
seed: Optional[int] = None,
seed: int | None = None,
n_candidates: int = 1,
generation_config: Optional[GenerationConfig] = None,
generation_config: GenerationConfig | None = None,
) -> None:
if seed is not None:
set_seed(seed)
@ -78,9 +79,9 @@ class BestOfNSampler:
def generate(
self,
tokenized_query: Union[list[int], torch.Tensor, list[torch.Tensor], list[list[int]]],
tokenized_query: list[int] | torch.Tensor | list[torch.Tensor] | list[list[int]],
skip_special_tokens: bool = True,
device: Optional[Union[str, torch.device]] = None,
device: str | torch.device | None = None,
**generation_kwargs,
) -> list[list[str]]:
"""

View File

@ -13,7 +13,8 @@
# limitations under the License.
import logging
from typing import Callable, Literal, Optional
from collections.abc import Callable
from typing import Literal
import datasets
from datasets import Dataset, Value
@ -36,7 +37,7 @@ else:
def conversations_formatting_function(
tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: Optional[list] = None
tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: list | None = None
):
r"""
return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the
@ -84,8 +85,8 @@ def instructions_formatting_function(tokenizer: AutoTokenizer):
def get_formatting_func_from_dataset(
dataset: Dataset, tokenizer: AutoTokenizer, tools: Optional[list] = None
) -> Optional[Callable]:
dataset: Dataset, tokenizer: AutoTokenizer, tools: list | None = None
) -> Callable | None:
r"""
Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
- `ChatML` with [{"role": str, "content": str}]

View File

@ -18,7 +18,6 @@ import logging
import socket
import time
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse
import torch
@ -105,7 +104,7 @@ class VLLMClient:
def __init__(
self,
base_url: Optional[str] = None,
base_url: str | None = None,
host: str = "0.0.0.0",
server_port: int = 8000,
group_port: int = 51216,
@ -170,7 +169,7 @@ class VLLMClient:
def generate(
self,
prompts: list[str],
images: Optional[list] = None,
images: list | None = None,
n: int = 1,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
@ -178,8 +177,8 @@ class VLLMClient:
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 16,
guided_decoding_regex: Optional[str] = None,
generation_kwargs: Optional[dict] = None,
guided_decoding_regex: str | None = None,
generation_kwargs: dict | None = None,
) -> list[list[int]]:
"""
Generates model completions for the provided prompts.
@ -250,7 +249,7 @@ class VLLMClient:
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
def init_communicator(self, device: Union[torch.device, str, int] = 0):
def init_communicator(self, device: torch.device | str | int = 0):
"""
Initializes the weight update group in a distributed setup for model synchronization.
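
`VLLMClient` is the HTTP side of TRL's vLLM integration: `generate` posts prompts to a running `trl vllm-serve` process and returns lists of token ids, and `init_communicator` joins the weight-broadcast group. A hedged usage sketch; it assumes a server is already listening at the given `base_url` and that the import path matches the module shown above.

from trl.extras.vllm_client import VLLMClient

# Assumes `trl vllm-serve --model <model>` is already running and reachable.
client = VLLMClient(base_url="http://localhost:8000")
completion_ids = client.generate(
    ["The capital of France is"],
    n=2,
    temperature=1.0,
    max_tokens=16,
)
print(completion_ids)  # one list of token ids per sampled completion
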

View File

@ -16,7 +16,6 @@ import json
import logging
import os
from copy import deepcopy
from typing import Optional
import torch
import torch.nn as nn
@ -392,7 +391,7 @@ class PreTrainedModelWrapper(nn.Module):
object to handle corner cases when running scripts in distributed environments.
Returns:
current_device (`Union[int, str]`):
current_device (`int | str`):
The current device.
"""
state = PartialState()
@ -590,7 +589,7 @@ class PreTrainedModelWrapper(nn.Module):
def create_reference_model(
model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None
model: PreTrainedModelWrapper, num_shared_layers: int | None = None, pattern: str | None = None
) -> PreTrainedModelWrapper:
"""
Creates a static reference copy of a model. Note that model will be in `.eval()` mode.
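
`create_reference_model` keeps its contract and only re-spells the optional parameters. A short usage sketch in the usual value-head workflow; the checkpoint is just a small public example.

from trl import AutoModelForCausalLMWithValueHead, create_reference_model

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
# Full frozen copy in eval mode; pass num_shared_layers to share the first N layers instead.
ref_model = create_reference_model(model)
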

View File

@ -18,7 +18,7 @@ from collections.abc import Callable
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal
import torch
import torch.nn as nn
@ -87,8 +87,8 @@ FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
def setup_chat_format(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
format: Optional[Literal["chatml"]] = "chatml",
resize_to_multiple_of: Optional[int] = None,
format: Literal["chatml"] | None = "chatml",
resize_to_multiple_of: int | None = None,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
# docstyle-ignore
"""
@ -104,7 +104,7 @@ def setup_chat_format(
Args:
model (`~transformers.PreTrainedModel`): The model to be modified.
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
format (`Literal["chatml"] | None`): The format to be set. Defaults to "chatml".
resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None.
Returns:
@ -164,7 +164,7 @@ def clone_chat_template(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
source_tokenizer_path: str,
resize_to_multiple_of: Optional[int] = 64,
resize_to_multiple_of: int | None = 64,
) -> tuple[PreTrainedModel, PreTrainedTokenizer, list[int]]:
"""
Clones a chat template from a source tokenizer to the target tokenizer and updates the model accordingly.
@ -306,7 +306,7 @@ def add_hooks(model: "DeepSpeedEngine") -> None:
@contextmanager
def unwrap_model_for_generation(
model: Union["DistributedDataParallel", "DeepSpeedEngine"],
model: "DistributedDataParallel | DeepSpeedEngine",
accelerator: "Accelerator",
gather_deepspeed3_params: bool = True,
):
@ -314,7 +314,7 @@ def unwrap_model_for_generation(
Context manager to unwrap distributed or accelerated models for generation tasks.
Args:
model (`Union[DistributedDataParallel, DeepSpeedEngine]`):
model (`DistributedDataParallel | DeepSpeedEngine`):
Model to be unwrapped.
accelerator (`~accelerate.Accelerator`):
Accelerator instance managing the model.
@ -472,7 +472,7 @@ class _ForwardRedirection:
def enable_gradient_checkpointing(
model: PreTrainedModel, gradient_checkpointing_kwargs: Optional[dict]
model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None
) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Enable gradient checkpointing on the base model for PEFT
@ -511,7 +511,7 @@ def peft_module_casting_to_bf16(model):
def prepare_peft_model(
model: PreTrainedModel, peft_config: Optional["PeftConfig"], args: TrainingArguments
model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments
) -> PreTrainedModel:
"""Prepares a model for PEFT training."""
if not is_peft_available():
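
`setup_chat_format` and `clone_chat_template` are likewise unchanged in behavior. A minimal sketch of `setup_chat_format` with a small base model; the checkpoint choice is illustrative and the call downloads weights on first use.

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# Adds the ChatML special tokens, sets the chat template, and resizes the embeddings.
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)
print(tokenizer.apply_chat_template([{"role": "user", "content": "Hi"}], tokenize=False))
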

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import TYPE_CHECKING

View File

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable

View File

@ -62,7 +62,6 @@ python trl/scripts/dpo.py \
import argparse
import os
from typing import Optional
import torch
from accelerate import logging
@ -167,7 +166,7 @@ def main(script_args, training_args, model_args, dataset_args):
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
def make_parser(subparsers: argparse._SubParsersAction | None = None):
dataclass_types = (ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None:
parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types)

View File

@ -26,7 +26,6 @@ import importlib
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from accelerate import logging
from datasets import load_dataset
@ -73,14 +72,14 @@ class GRPOScriptArguments(ScriptArguments):
- any dotted import path (e.g., `'my_lib.rewards.custom_reward'`).
"""
reward_model_name_or_path: Optional[str] = field(
reward_model_name_or_path: str | None = field(
default=None,
metadata={
"help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
"local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
},
)
reward_funcs: Optional[list[str]] = field(
reward_funcs: list[str] | None = field(
default=None,
metadata={
"help": "Reward functions to use. Supported values are: `think_format_reward`, "
@ -153,7 +152,7 @@ def main(script_args, training_args, model_args, dataset_args):
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
def make_parser(subparsers: argparse._SubParsersAction | None = None):
dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None:
parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types)

View File

@ -66,7 +66,6 @@ python trl/scripts/kto.py \
import argparse
import os
from typing import Optional
from accelerate import logging
from datasets import load_dataset
@ -147,7 +146,7 @@ def main(script_args, training_args, model_args, dataset_args):
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
def make_parser(subparsers: argparse._SubParsersAction | None = None):
dataclass_types = (ScriptArguments, KTOConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None:
parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types)

View File

@ -23,7 +23,6 @@
import argparse
import os
from typing import Optional
from accelerate import logging
from datasets import load_dataset
@ -87,7 +86,7 @@ def main(script_args, training_args, model_args, dataset_args):
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
def make_parser(subparsers: argparse._SubParsersAction | None = None):
dataclass_types = (ScriptArguments, RewardConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None:
parser = subparsers.add_parser(

View File

@ -26,7 +26,6 @@ import importlib
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
from accelerate import logging
from datasets import load_dataset
@ -73,14 +72,14 @@ class RLOOScriptArguments(ScriptArguments):
- any dotted import path (e.g., `'my_lib.rewards.custom_reward'`).
"""
reward_model_name_or_path: Optional[str] = field(
reward_model_name_or_path: str | None = field(
default=None,
metadata={
"help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
"local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
},
)
reward_funcs: Optional[list[str]] = field(
reward_funcs: list[str] | None = field(
default=None,
metadata={
"help": "Reward functions to use. Supported values are: `think_format_reward`, "
@ -153,7 +152,7 @@ def main(script_args, training_args, model_args, dataset_args):
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
def make_parser(subparsers: argparse._SubParsersAction | None = None):
dataclass_types = (RLOOScriptArguments, RLOOConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None:
parser = subparsers.add_parser("rloo", help="Run the RLOO training script", dataclass_types=dataclass_types)

View File

@ -64,7 +64,6 @@ python trl/scripts/sft.py \
import argparse
import os
from typing import Optional
from accelerate import logging
from datasets import load_dataset
@ -158,7 +157,7 @@ def main(script_args, training_args, model_args, dataset_args):
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
def make_parser(subparsers: argparse._SubParsersAction | None = None):
dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None:
parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)

View File

@ -21,7 +21,6 @@ import subprocess
import sys
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Optional, Union
import datasets
import yaml
@ -80,11 +79,11 @@ class DatasetConfig:
"""
path: str
name: Optional[str] = None
data_dir: Optional[str] = None
data_files: Optional[Union[str, list[str], dict[str, str]]] = None
name: str | None = None
data_dir: str | None = None
data_files: str | list[str] | dict[str, str] | None = None
split: str = "train"
columns: Optional[list[str]] = None
columns: list[str] | None = None
@dataclass
@ -135,7 +134,7 @@ class DatasetMixtureConfig:
default=False,
metadata={"help": "Whether to stream the datasets. If True, the datasets will be loaded in streaming mode."},
)
test_split_size: Optional[float] = field(
test_split_size: float | None = field(
default=None,
metadata={
"help": "Size of the test split. Refer to the `test_size` parameter in the `datasets.train_test_split` "
@ -177,11 +176,11 @@ class ScriptArguments:
https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
"""
dataset_name: Optional[str] = field(
dataset_name: str | None = field(
default=None,
metadata={"help": "Path or name of the dataset to load. If `datasets` is provided, this will be ignored."},
)
dataset_config: Optional[str] = field(
dataset_config: str | None = field(
default=None,
metadata={
"help": "Dataset configuration name. Corresponds to the `name` argument of the `datasets.load_dataset` "
@ -250,7 +249,7 @@ class TrlParser(HfArgumentParser):
configurations, while also supporting configuration file loading and environment variable management.
Args:
dataclass_types (`Union[DataClassType, Iterable[DataClassType]]`, *optional*):
dataclass_types (`DataClassType | Iterable[DataClassType]`, *optional*):
Dataclass types to use for argument parsing.
**kwargs:
Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor.
@ -294,7 +293,7 @@ class TrlParser(HfArgumentParser):
def __init__(
self,
dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None,
dataclass_types: DataClassType | Iterable[DataClassType] | None = None,
**kwargs,
):
# Make sure dataclass_types is an iterable
@ -315,7 +314,7 @@ class TrlParser(HfArgumentParser):
def parse_args_and_config(
self,
args: Optional[Iterable[str]] = None,
args: Iterable[str] | None = None,
return_remaining_strings: bool = False,
fail_with_unknown_args: bool = True,
) -> tuple[DataClass, ...]:
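
`TrlParser` and the script dataclasses above (`DatasetConfig`, `DatasetMixtureConfig`, `ScriptArguments`) only change their annotations, so existing configs parse exactly as before. A hedged sketch of the usual pattern with a made-up dataclass; it assumes `HfArgumentParser` accepts the builtin `str | None` annotation, which this very branch relies on.

from dataclasses import dataclass, field
from trl import TrlParser

@dataclass
class DemoArguments:
    dataset_name: str | None = field(default=None, metadata={"help": "Dataset to load."})
    num_epochs: int = field(default=1, metadata={"help": "Training epochs."})

parser = TrlParser((DemoArguments,))
(demo_args,) = parser.parse_args_and_config(["--dataset_name", "trl-lib/tldr"])
print(demo_args.dataset_name, demo_args.num_epochs)
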

View File

@ -23,7 +23,6 @@ from io import BytesIO
from itertools import chain
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import Optional
import torch
import torch.distributed.distributed_c10d as c10d
@ -239,7 +238,7 @@ class ScriptArguments:
model: str = field(
metadata={"help": "Model name or path to load the model from."},
)
revision: Optional[str] = field(
revision: str | None = field(
default=None,
metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."},
)
@ -275,7 +274,7 @@ class ScriptArguments:
"determined based on the model configuration. Find the supported values in the vLLM documentation."
},
)
max_model_len: Optional[int] = field(
max_model_len: int | None = field(
default=None,
metadata={
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
@ -283,14 +282,14 @@ class ScriptArguments:
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
enable_prefix_caching: Optional[bool] = field(
enable_prefix_caching: bool | None = field(
default=None,
metadata={
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
"hardware support this feature."
},
)
enforce_eager: Optional[bool] = field(
enforce_eager: bool | None = field(
default=False,
metadata={
"help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
@ -487,7 +486,7 @@ def main(script_args: ScriptArguments):
class GenerateRequest(BaseModel):
prompts: list[str]
images: Optional[list[str]] = None
images: list[str] | None = None
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
@ -495,7 +494,7 @@ def main(script_args: ScriptArguments):
top_k: int = -1
min_p: float = 0.0
max_tokens: int = 16
guided_decoding_regex: Optional[str] = None
guided_decoding_regex: str | None = None
generation_kwargs: dict = field(default_factory=dict)
class GenerateResponse(BaseModel):
@ -549,7 +548,7 @@ def main(script_args: ScriptArguments):
request.images = request.images or [None] * len(request.prompts)
prompts = []
for prompt, image in zip(request.prompts, request.images):
for prompt, image in zip(request.prompts, request.images, strict=True):
row = {"prompt": prompt}
if image is not None:
row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
@ -579,7 +578,7 @@ def main(script_args: ScriptArguments):
chunked_prompts = chunk_list(prompts, script_args.data_parallel_size)
# Send the prompts to each worker
for connection, prompts in zip(connections, chunked_prompts):
for connection, prompts in zip(connections, chunked_prompts, strict=True):
# When the number of prompts is less than data_parallel_size, some workers will receive empty prompts.
# However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply
# with vLLM's requirement, and we later ignore the result.
@ -592,7 +591,7 @@ def main(script_args: ScriptArguments):
all_outputs = [connection.recv() for connection in connections]
# Handle empty prompts (see above)
all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts) if prompts]
all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts, strict=True) if prompts]
# Flatten and combine all results
all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list
@ -691,7 +690,7 @@ def main(script_args: ScriptArguments):
uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level)
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
def make_parser(subparsers: argparse._SubParsersAction | None = None):
if subparsers is not None:
parser = subparsers.add_parser("vllm-serve", help="Run the vLLM serve script", dataclass_types=ScriptArguments)
else:
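
The data-parallel path above splits incoming prompts across workers, pads empty chunks with a placeholder prompt (vLLM requires at least one), and drops the placeholder results when gathering, now with `strict=True` on both zips. A toy, dependency-free sketch of that round trip; this `chunk_list` is a simplified stand-in for the script's helper and the exact splitting rule may differ.

def chunk_list(items: list, n: int) -> list[list]:
    # Split items into n chunks as evenly as possible (later chunks may be empty).
    size, rest = divmod(len(items), n)
    chunks, start = [], 0
    for i in range(n):
        step = size + (1 if i < rest else 0)
        chunks.append(items[start : start + step])
        start += step
    return chunks

prompts = ["a", "b", "c"]
chunked = chunk_list(prompts, 4)                          # the last worker gets an empty chunk
sent = [chunk or ["<placeholder>"] for chunk in chunked]  # always send at least one prompt
outputs = [[p.upper() for p in chunk] for chunk in sent]  # pretend each worker generates
outputs = [out for out, chunk in zip(outputs, chunked, strict=True) if chunk]  # discard placeholders
print(outputs)  # [['A'], ['B'], ['C']]
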

View File

@ -13,7 +13,6 @@
# limitations under the License.
import os
from typing import Optional, Union
from transformers import Trainer, is_wandb_available
@ -32,9 +31,9 @@ class BaseTrainer(Trainer):
def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Optional[Union[str, list[str]]] = None,
model_name: str | None = None,
dataset_name: str | None = None,
tags: str | list[str] | None = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.

View File

@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import TrainingArguments
@ -93,7 +93,7 @@ class BCOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@ -102,21 +102,21 @@ class BCOConfig(TrainingArguments):
},
)
max_length: Optional[int] = field(
max_length: int | None = field(
default=1024,
metadata={
"help": "Maximum length of the sequences (prompt + completion) in the batch. "
"This argument is required if you want to use the default data collator."
},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={
"help": "Maximum length of the prompt. "
"This argument is required if you want to use the default data collator."
},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=None,
metadata={
"help": "Maximum length of the completion. This argument is required if you want to use the "
@ -136,7 +136,7 @@ class BCOConfig(TrainingArguments):
"help": "Label pad token id. This argument is required if you want to use the default data collator."
},
)
padding_value: Optional[int] = field(
padding_value: int | None = field(
default=None,
metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
)
@ -159,7 +159,7 @@ class BCOConfig(TrainingArguments):
"to W&B during evaluation."
},
)
is_encoder_decoder: Optional[bool] = field(
is_encoder_decoder: bool | None = field(
default=None,
metadata={
"help": "When using the `model_init` argument (callable) to instantiate the model instead of the "
@ -175,21 +175,21 @@ class BCOConfig(TrainingArguments):
"needed."
},
)
model_init_kwargs: Optional[dict[str, Any]] = field(
model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
"model from a string."
},
)
ref_model_init_kwargs: Optional[dict[str, Any]] = field(
ref_model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
"reference model from a string."
},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)

View File

@ -17,10 +17,11 @@ import os
import random
import textwrap
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from operator import itemgetter
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
import pandas as pd
@ -89,25 +90,27 @@ CLF_NAME = "clf.pkl"
def _tokenize(
batch: dict[str, list[Any]],
tokenizer: "PreTrainedTokenizer",
embedding_tokenizer: Optional["PreTrainedTokenizer"] = None,
embedding_tokenizer: "PreTrainedTokenizer | None" = None,
) -> dict[str, list[Any]]:
"""Tokenize a batch from a BCO specific dataset."""
prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False)
prompt_input_ids = prompt_tokenized["input_ids"]
prompt_attention_mask = prompt_tokenized["attention_mask"]
prompt_and_completion = [prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"])]
prompt_and_completion = [
prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"], strict=True)
]
full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False)
full_input_ids = full_tokenized["input_ids"]
full_attention_mask = full_tokenized["attention_mask"]
answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids)]
answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask)]
answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids, strict=True)]
answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask, strict=True)]
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids)]
full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids, strict=True)]
# Prepare input tokens for token by token comparison
full_input_ids = [np.array(f) for f in full_input_ids]
for full, concat in zip(full_input_ids, full_concat_input_ids):
for full, concat in zip(full_input_ids, full_concat_input_ids, strict=True):
if len(full) != len(concat):
raise ValueError(
"The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length."
@ -121,19 +124,19 @@ def _tokenize(
# If tokenized prompt is different than both prompt+answer, then it means the
# last token has changed due to merging.
for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx)):
for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx, strict=True)):
if not np.array_equal(p, f[:r]):
response_token_ids_start_idx[idx] -= 1
prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx)]
prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx)]
prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
for p, m in zip(prompt_input_ids, prompt_attention_mask):
for p, m in zip(prompt_input_ids, prompt_attention_mask, strict=True):
if len(p) != len(m):
raise ValueError("Prompt input ids and attention mask should have the same length.")
answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx)]
answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx)]
answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
output = dict(
prompt_input_ids=prompt_input_ids,
@ -340,25 +343,27 @@ class BCOTrainer(BaseTrainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
model: PreTrainedModel | nn.Module | str = None,
ref_model: PreTrainedModel | nn.Module | str | None = None,
args: BCOConfig = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
data_collator: Optional[DataCollator] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
train_dataset: Dataset | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
data_collator: DataCollator | None = None,
model_init: Callable[[], PreTrainedModel] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
embedding_func: Optional[Callable] = None,
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: dict | None = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
model_adapter_name: str | None = None,
ref_adapter_name: str | None = None,
embedding_func: Callable | None = None,
embedding_tokenizer: PreTrainedTokenizerBase | None = None,
):
if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()):
raise ImportError(
@ -810,7 +815,7 @@ class BCOTrainer(BaseTrainer):
return embeddings
def _get_prompt_embeddings(
self, batch: dict[str, Union[list, torch.LongTensor]]
self, batch: dict[str, list | torch.LongTensor]
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
"""Extract embeddings from frozen embedding model"""
@ -939,7 +944,7 @@ class BCOTrainer(BaseTrainer):
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""
Returns the evaluation [`~torch.utils.data.DataLoader`].
@ -1080,7 +1085,7 @@ class BCOTrainer(BaseTrainer):
return (per_token_logps * loss_mask).sum(-1)
def forward(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
model_kwargs = (
{
@ -1143,8 +1148,8 @@ class BCOTrainer(BaseTrainer):
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
chosen_embeddings: Optional[torch.FloatTensor],
rejected_embeddings: Optional[torch.FloatTensor],
chosen_embeddings: torch.FloatTensor | None,
rejected_embeddings: torch.FloatTensor | None,
do_train: bool = True,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the BCO loss for a batch of policy and reference model log probabilities.
@ -1196,7 +1201,7 @@ class BCOTrainer(BaseTrainer):
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
batch: dict[str, list | torch.LongTensor],
do_train: bool = True,
):
"""Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
@ -1289,11 +1294,11 @@ class BCOTrainer(BaseTrainer):
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
compute_loss_context_manager = (
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
)
@ -1315,7 +1320,7 @@ class BCOTrainer(BaseTrainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.utils.data.Sampler | None:
if dataset is None:
dataset = self.train_dataset
if dataset is None or not has_length(dataset):
@ -1371,10 +1376,10 @@ class BCOTrainer(BaseTrainer):
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
ignore_keys: list[str] | None = None,
):
if ignore_keys is None:
if hasattr(model, "config"):
@ -1411,8 +1416,8 @@ class BCOTrainer(BaseTrainer):
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
prediction_loss_only: bool | None = None,
ignore_keys: list[str] | None = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
@ -1446,7 +1451,9 @@ class BCOTrainer(BaseTrainer):
columns=["Prompt", "Policy", "Ref Model"],
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
for prompt, pol, ref in zip(
target_batch["prompt"], policy_output_decoded, ref_output_decoded, strict=True
)
],
)
if "wandb" in self.args.report_to:
@ -1465,7 +1472,7 @@ class BCOTrainer(BaseTrainer):
return initial_output
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

View File

@ -14,7 +14,6 @@
import logging
import os
from typing import Optional, Union
import pandas as pd
import torch
@ -66,7 +65,7 @@ def _generate_completions(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
accelerator: Accelerator,
generation_config: Optional[GenerationConfig],
generation_config: GenerationConfig | None,
batch_size: int = 1,
) -> list[str]:
"""
@ -92,7 +91,7 @@ def _generate_completions(
**tokenized_batch,
generation_config=generation_config,
)
for prompt, generation in zip(tokenized_batch.input_ids, generations):
for prompt, generation in zip(tokenized_batch.input_ids, generations, strict=True):
# Remove prompt from generation
generation = generation[len(prompt) :]
completion = tokenizer.decode(generation, skip_special_tokens=True)
@ -107,15 +106,15 @@ class SyncRefModelCallback(TrainerCallback):
def __init__(
self,
ref_model: Union[PreTrainedModel, torch.nn.Module],
accelerator: Optional[Accelerator],
ref_model: PreTrainedModel | torch.nn.Module,
accelerator: Accelerator | None,
):
self.accelerator = accelerator
self.ref_model = ref_model
@staticmethod
def _sync_target_model(model, target_model, alpha):
for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
for target_param, copy_param in zip(target_model.parameters(), model.parameters(), strict=True):
target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)
@staticmethod
@ -225,7 +224,7 @@ def _win_rate_completions_df(
state: TrainerState, prompts: list[str], completions: list[str], winner_indices: list[str]
) -> pd.DataFrame:
global_step = [str(state.global_step)] * len(prompts)
data = list(zip(global_step, prompts, completions, winner_indices))
data = list(zip(global_step, prompts, completions, winner_indices, strict=True))
# Split completions from reference model and policy
split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
return pd.DataFrame(split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"])
@ -273,8 +272,8 @@ class WinRateCallback(TrainerCallback):
self,
judge: BasePairwiseJudge,
trainer: Trainer,
generation_config: Optional[GenerationConfig] = None,
num_prompts: Optional[int] = None,
generation_config: GenerationConfig | None = None,
num_prompts: int | None = None,
shuffle_order: bool = True,
use_soft_judge: bool = False,
):
@ -319,7 +318,7 @@ class WinRateCallback(TrainerCallback):
batch_size=args.per_device_eval_batch_size,
)
# Compute initial win rate as a reference point
completions = list(zip(self.ref_completions, self.ref_completions))
completions = list(zip(self.ref_completions, self.ref_completions, strict=True))
if self.use_soft_judge:
ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
@ -379,7 +378,7 @@ class WinRateCallback(TrainerCallback):
batch_size=args.per_device_eval_batch_size,
)
completions = list(zip(self.ref_completions, completions))
completions = list(zip(self.ref_completions, completions, strict=True))
if self.use_soft_judge:
ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
@ -450,9 +449,9 @@ class LogCompletionsCallback(TrainerCallback):
def __init__(
self,
trainer: Trainer,
generation_config: Optional[GenerationConfig] = None,
num_prompts: Optional[int] = None,
freq: Optional[int] = None,
generation_config: GenerationConfig | None = None,
num_prompts: int | None = None,
freq: int | None = None,
):
self.trainer = trainer
self.generation_config = generation_config
@ -498,7 +497,7 @@ class LogCompletionsCallback(TrainerCallback):
# Build the data to log
if self.trainer.accelerator.is_main_process:
global_step = [str(state.global_step)] * len(prompts)
data = list(zip(global_step, prompts, completions))
data = list(zip(global_step, prompts, completions, strict=True))
self.table.extend(data)
table = pd.DataFrame(columns=["step", "prompt", "completion"], data=self.table)
@ -567,7 +566,7 @@ class WeaveCallback(TrainerCallback):
scorers (`dict[str, Callable]`, *optional*):
Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions
only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should
have signature: `scorer(prompt: str, completion: str) -> Union[float, int]`
have signature: `scorer(prompt: str, completion: str) -> float | int`
generation_config (`GenerationConfig`, *optional*):
Generation config to use for generating completions.
num_prompts (`int` or `None`, *optional*):
@ -582,12 +581,12 @@ class WeaveCallback(TrainerCallback):
def __init__(
self,
trainer: Trainer,
project_name: Optional[str] = None,
scorers: Optional[dict[str, callable]] = None,
generation_config: Optional[GenerationConfig] = None,
num_prompts: Optional[int] = None,
project_name: str | None = None,
scorers: dict[str, callable] | None = None,
generation_config: GenerationConfig | None = None,
num_prompts: int | None = None,
dataset_name: str = "eval_dataset",
model_name: Optional[str] = None,
model_name: str | None = None,
):
self.trainer = trainer
self.project_name = project_name
@ -696,7 +695,7 @@ class WeaveCallback(TrainerCallback):
successful_predictions = 0
total_score_values = {} # For summary statistics
for prompt, completion in zip(all_prompts, all_completions):
for prompt, completion in zip(all_prompts, all_completions, strict=True):
try:
pred_logger = eval_logger.log_prediction(inputs={"prompt": prompt}, output=completion)
@ -771,7 +770,7 @@ class MergeModelCallback(TrainerCallback):
def __init__(
self,
merge_config: Optional["MergeConfig"] = None,
merge_config: "MergeConfig | None" = None,
merge_at_every_checkpoint: bool = False,
push_to_hub: bool = False,
):
@ -954,7 +953,7 @@ class BEMACallback(TrainerCallback):
# Compute EMA + BEMA in-place and write directly to running_model
for thetat, theta0, ema, run_param in zip(
self.thetat_params, self.theta0_params, self.ema_params, self.running_model.parameters()
self.thetat_params, self.theta0_params, self.ema_params, self.running_model.parameters(), strict=True
):
thetat = thetat.detach().to(self.device)
ema.mul_(1 - beta).add_(thetat, alpha=beta) # EMA update: ema = (1 - beta) * ema + beta * θₜ
@ -972,7 +971,9 @@ class BEMACallback(TrainerCallback):
# Snapshot θ₀ and EMA at first update
if step == self.update_after:
for thetat_param, theta0_param, ema_param in zip(self.thetat_params, self.theta0_params, self.ema_params):
for thetat_param, theta0_param, ema_param in zip(
self.thetat_params, self.theta0_params, self.ema_params, strict=True
):
theta0_param.copy_(thetat_param)
ema_param.copy_(thetat_param)
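
Both `SyncRefModelCallback` and `BEMACallback` walk parallel parameter lists, which is exactly where `strict=True` pays off: a model/target mismatch now fails immediately rather than silently updating only a prefix of the parameters. A small torch sketch of the Polyak-style update used by `_sync_target_model`, with toy tensors standing in for model parameters.

import torch

def sync_target(model_params, target_params, alpha: float = 0.5) -> None:
    # target <- (1 - alpha) * target + alpha * model, pairwise over parameters.
    for target_param, copy_param in zip(target_params, model_params, strict=True):
        target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)

model_params = [torch.ones(2), torch.full((2,), 3.0)]
target_params = [torch.zeros(2), torch.zeros(2)]
sync_target(model_params, target_params, alpha=0.5)
print(target_params)  # [tensor([0.5000, 0.5000]), tensor([1.5000, 1.5000])]
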

View File

@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import TrainingArguments
@ -107,7 +107,7 @@ class CPOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@ -116,18 +116,18 @@ class CPOConfig(TrainingArguments):
},
)
max_length: Optional[int] = field(
max_length: int | None = field(
default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
"collator and your model is an encoder-decoder."
},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=None,
metadata={
"help": "Maximum length of the completion. This argument is required if you want to use the default data "
@ -176,7 +176,7 @@ class CPOConfig(TrainingArguments):
default=-100,
metadata={"help": "Label pad token id."},
)
padding_value: Optional[int] = field(
padding_value: int | None = field(
default=None,
metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
)
@ -191,18 +191,18 @@ class CPOConfig(TrainingArguments):
default=False,
metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."},
)
is_encoder_decoder: Optional[bool] = field(
is_encoder_decoder: bool | None = field(
default=None,
metadata={"help": "Whether the model is an encoder-decoder model."},
)
model_init_kwargs: Optional[dict[str, Any]] = field(
model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
"from a string."
},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)

View File

@ -16,9 +16,10 @@ import inspect
import random
import textwrap
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Literal
import numpy as np
import pandas as pd
@ -127,20 +128,22 @@ class CPOTrainer(BaseTrainer):
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
args: Optional[CPOConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
model: PreTrainedModel | nn.Module | str | None = None,
args: CPOConfig | None = None,
data_collator: DataCollator | None = None,
train_dataset: Dataset | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
model_init: Callable[[], PreTrainedModel] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: dict | None = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
):
if args.model_init_kwargs is None:
model_init_kwargs = {}
@ -439,7 +442,7 @@ class CPOTrainer(BaseTrainer):
attention_mask=answer_attention_mask,
)
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict:
"""Tokenize a single row from a CPO specific dataset.
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
@ -487,7 +490,8 @@ class CPOTrainer(BaseTrainer):
# Make sure prompts only have one different token at most
# and length only differs by 1 at most
num_diff_tokens = sum(
a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
a != b
for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True)
)
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
if num_diff_tokens > 1 or num_diff_len > 1:
@ -586,11 +590,11 @@ class CPOTrainer(BaseTrainer):
@staticmethod
def concatenated_inputs(
batch: dict[str, Union[list, torch.LongTensor]],
batch: dict[str, list | torch.LongTensor],
is_encoder_decoder: bool = False,
label_pad_token_id: int = -100,
padding_value: int = 0,
device: Optional[torch.device] = None,
device: torch.device | None = None,
) -> dict[str, torch.LongTensor]:
"""Concatenate the chosen and rejected inputs into a single tensor.
@ -769,7 +773,7 @@ class CPOTrainer(BaseTrainer):
return (per_token_logps * loss_mask).sum(-1)
def concatenated_forward(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
@ -846,7 +850,7 @@ class CPOTrainer(BaseTrainer):
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
batch: dict[str, list | torch.LongTensor],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
@ -899,11 +903,11 @@ class CPOTrainer(BaseTrainer):
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
compute_loss_context_manager = (
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
)
@ -943,10 +947,10 @@ class CPOTrainer(BaseTrainer):
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
ignore_keys: list[str] | None = None,
):
if ignore_keys is None:
if hasattr(model, "config"):
@ -986,8 +990,8 @@ class CPOTrainer(BaseTrainer):
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
prediction_loss_only: bool | None = None,
ignore_keys: list[str] | None = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
@ -1013,7 +1017,8 @@ class CPOTrainer(BaseTrainer):
table = pd.DataFrame(
columns=["Prompt", "Policy"],
data=[
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
[prompt, pol[len(prompt) :]]
for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True)
],
)
if "wandb" in self.args.report_to:
@ -1032,7 +1037,7 @@ class CPOTrainer(BaseTrainer):
return initial_output
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

View File

@ -14,7 +14,7 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional, Union
from typing import Any
from transformers import TrainingArguments
@ -123,7 +123,7 @@ class DPOConfig(TrainingArguments):
Batch size to use when precomputing reference model log probabilities. This can be set higher than the
training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for
training and `per_device_eval_batch_size` for evaluation.
tools (`Optional[list[Union[dict, Callable]]]`, *optional*):
tools (`list[dict] | None`, *optional*):
List of tools (callable functions) that will be accessible to the model. If the template does not support
function calling, this argument will have no effect.
@ -244,7 +244,7 @@ class DPOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@ -254,25 +254,25 @@ class DPOConfig(TrainingArguments):
)
# Parameters that control the model and reference model
model_init_kwargs: Optional[dict[str, Any]] = field(
model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
"the `DPOTrainer` is provided as a string."
},
)
ref_model_init_kwargs: Optional[dict[str, Any]] = field(
ref_model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument "
"of the `DPOTrainer` is provided as a string."
},
)
model_adapter_name: Optional[str] = field(
model_adapter_name: str | None = field(
default=None,
metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
)
ref_adapter_name: Optional[str] = field(
ref_adapter_name: str | None = field(
default=None,
metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
)
@ -297,11 +297,11 @@ class DPOConfig(TrainingArguments):
)
# Parameters that control the data preprocessing
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
pad_token: Optional[str] = field(
pad_token: str | None = field(
default=None,
metadata={
"help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
@ -312,15 +312,15 @@ class DPOConfig(TrainingArguments):
default=-100,
metadata={"help": "Padding value to use for labels."},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={"help": "Maximum length of the prompt."},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=None,
metadata={"help": "Maximum length of the completion."},
)
max_length: Optional[int] = field(
max_length: int | None = field(
default=1024,
metadata={"help": "Maximum length of the full sequence (prompt + completion)."},
)
@ -350,7 +350,7 @@ class DPOConfig(TrainingArguments):
"probabilities on-the-fly."
},
)
precompute_ref_batch_size: Optional[int] = field(
precompute_ref_batch_size: int | None = field(
default=None,
metadata={
"help": "Batch size to use when precomputing reference model log probabilities. This can be set higher "
@ -358,7 +358,7 @@ class DPOConfig(TrainingArguments):
"`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation."
},
)
tools: Optional[list[Union[dict, Callable]]] = field(
tools: list[dict] | None = field(
default=None,
metadata={
"help": "List of tools (callable functions) that will be accessible to the model. If the template does "
@ -396,7 +396,7 @@ class DPOConfig(TrainingArguments):
"Higher β means less deviation from the reference model."
},
)
f_divergence_type: Union[FDivergenceType, str] = field(
f_divergence_type: FDivergenceType | str = field(
default=FDivergenceType.REVERSE_KL,
metadata={
"help": "Type of f-divergence regularization function to compute divergence between policy and reference "
@ -425,7 +425,7 @@ class DPOConfig(TrainingArguments):
default=False,
metadata={"help": "Whether to weight the loss as done in the WPO paper."},
)
rpo_alpha: Optional[float] = field(
rpo_alpha: float | None = field(
default=None,
metadata={
"help": "α parameter from the RPO paper (v3), which controls the weighting of the NLL term in the loss. "
@ -433,7 +433,7 @@ class DPOConfig(TrainingArguments):
"`rpo_alpha=1.0`."
},
)
ld_alpha: Optional[float] = field(
ld_alpha: float | None = field(
default=None,
metadata={
"help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token "
@ -448,7 +448,7 @@ class DPOConfig(TrainingArguments):
"loss. The paper recommends the default value `discopop_tau=0.05`."
},
)
loss_weights: Optional[list[float]] = field(
loss_weights: list[float] | None = field(
default=None,
metadata={
"help": "List of loss weights for multi-loss combinations. Used when combining multiple loss types. "
@ -489,7 +489,7 @@ class DPOConfig(TrainingArguments):
)
# Deprecated arguments
padding_value: Optional[int] = field(
padding_value: int | None = field(
default=None,
metadata={"help": "Deprecated, use `pad_token` (str) instead."},
)
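Taken together, the `DPOConfig` fields above are passed as keyword arguments in the usual `TrainingArguments` style; a minimal sketch with illustrative values (the output directory and the numbers are placeholders, not part of this change):

from trl import DPOConfig

config = DPOConfig(
    output_dir="dpo-example",        # inherited from TrainingArguments
    max_prompt_length=512,           # int | None
    max_completion_length=None,      # int | None
    max_length=1024,                 # int | None
    dataset_num_proc=4,              # int | None
    precompute_ref_batch_size=None,  # int | None
)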

View File

@ -17,10 +17,11 @@ import random
import textwrap
import warnings
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Literal
import pandas as pd
import torch
@ -144,7 +145,7 @@ class DataCollatorForPreference(DataCollatorMixin):
pad_token_id: int
return_tensors: str = "pt"
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
def torch_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]:
# Convert to tensor
prompt_input_ids = [torch.tensor(example["prompt_input_ids"]) for example in examples]
prompt_attention_mask = [torch.ones_like(input_ids) for input_ids in prompt_input_ids]
@ -188,7 +189,7 @@ class DPOTrainer(BaseTrainer):
This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.
Args:
model (`Union[str, PreTrainedModel]`):
model (`str | PreTrainedModel`):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
@ -213,7 +214,7 @@ class DPOTrainer(BaseTrainer):
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
@ -265,21 +266,23 @@ class DPOTrainer(BaseTrainer):
def __init__(
self,
model: Union[str, nn.Module, PreTrainedModel],
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
args: Optional[DPOConfig] = None,
data_collator: Optional[DataCollator] = None, # type: ignore
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional["PeftConfig"] = None,
model: str | nn.Module | PreTrainedModel,
ref_model: PreTrainedModel | nn.Module | str | None = None,
args: DPOConfig | None = None,
data_collator: DataCollator | None = None, # type: ignore
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: "PeftConfig | None" = None,
):
# Args
if args is None:
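One detail in the rewritten signature above: `peft_config` keeps a quoted annotation, `"PeftConfig | None"`, so the union remains a string forward reference and is not evaluated when the function is defined. That matters when the referenced class comes from an optional dependency; a minimal sketch of the general pattern (the guard and function are illustrative, not taken from this file):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from peft import PeftConfig  # only resolved by type checkers

def configure(peft_config: "PeftConfig | None" = None) -> None:
    # The string annotation is never evaluated at definition time, so this
    # module still imports cleanly when `peft` is not installed.
    ...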
@ -632,11 +635,11 @@ class DPOTrainer(BaseTrainer):
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
dataset: Dataset | IterableDataset,
processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin,
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
) -> Dataset | IterableDataset:
# Build the kwargs for the `map` function
map_kwargs = {}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size
@ -679,8 +682,8 @@ class DPOTrainer(BaseTrainer):
def tokenize_row(
features: dict[str, str],
processing_class: PreTrainedTokenizerBase,
max_prompt_length: Optional[int] = None,
max_completion_length: Optional[int] = None,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
) -> dict[str, list[int]]:
"""
@ -748,8 +751,8 @@ class DPOTrainer(BaseTrainer):
def process_row(
features: dict[str, str],
processing_class: PreTrainedTokenizerBase,
max_prompt_length: Optional[int] = None,
max_completion_length: Optional[int] = None,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
) -> dict[str, list[int]]:
"""
@ -854,7 +857,7 @@ class DPOTrainer(BaseTrainer):
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""
Returns the evaluation [`~torch.utils.data.DataLoader`].
@ -934,14 +937,14 @@ class DPOTrainer(BaseTrainer):
@staticmethod
def concatenated_inputs(
batch: dict[str, Union[list, torch.LongTensor]], padding_value: int
batch: dict[str, list | torch.LongTensor], padding_value: int
) -> dict[str, torch.LongTensor]:
"""
Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt and
completion sequences.
Args:
batch (`dict[str, Union[list, torch.LongTensor]]`):
batch (`dict[str, list | torch.LongTensor]`):
A batch of input data. The batch must contain the following keys:
- `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input
@ -1233,7 +1236,7 @@ class DPOTrainer(BaseTrainer):
return losses, chosen_rewards, rejected_rewards
def _compute_loss_liger(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
) -> dict[str, torch.Tensor]:
unwrapped_model = self.accelerator.unwrap_model(model)
concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id)
@ -1465,7 +1468,7 @@ class DPOTrainer(BaseTrainer):
return output
def concatenated_forward(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
self, model: nn.Module, batch: dict[str, list | torch.LongTensor], is_ref_model: bool = False
) -> dict[str, torch.Tensor]:
"""
Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
@ -1687,8 +1690,8 @@ class DPOTrainer(BaseTrainer):
def get_batch_loss_metrics(
self,
model: Union[PreTrainedModel, nn.Module],
batch: dict[str, Union[list, torch.LongTensor]],
model: PreTrainedModel | nn.Module,
batch: dict[str, list | torch.LongTensor],
train_eval: Literal["train", "eval"] = "train",
) -> tuple[torch.Tensor, dict[str, float]]:
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
@ -1775,11 +1778,11 @@ class DPOTrainer(BaseTrainer):
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]:
) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
compute_loss_context_manager = (
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
)
@ -1846,11 +1849,11 @@ class DPOTrainer(BaseTrainer):
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
if ignore_keys is None:
if hasattr(model, "config"):
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
@ -1889,8 +1892,8 @@ class DPOTrainer(BaseTrainer):
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
prediction_loss_only: bool | None = None,
ignore_keys: list[str] | None = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
@ -1918,7 +1921,7 @@ class DPOTrainer(BaseTrainer):
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded
random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded, strict=True
)
],
)
@ -1941,7 +1944,7 @@ class DPOTrainer(BaseTrainer):
return initial_output
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
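To connect the union-typed signatures above to actual usage, here is a minimal sketch of constructing the trainer; the model id and dataset name are placeholders chosen for illustration, not part of this diff:

from datasets import load_dataset
from trl import DPOConfig, DPOTrainer

train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")  # placeholder dataset

trainer = DPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",        # accepted as `str | nn.Module | PreTrainedModel`
    args=DPOConfig(output_dir="dpo-output"),
    train_dataset=train_dataset,             # `Dataset | IterableDataset | None`
)
trainer.train()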

View File

@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import TrainingArguments
@ -77,14 +77,14 @@ class GKDConfig(SFTConfig):
default=128,
metadata={"help": "Maximum number of tokens to generate per completion."},
)
teacher_model_name_or_path: Optional[str] = field(
teacher_model_name_or_path: str | None = field(
default=None,
metadata={
"help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
"model being trained."
},
)
teacher_model_init_kwargs: Optional[dict[str, Any]] = field(
teacher_model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "

View File

@ -14,7 +14,8 @@
import random
import textwrap
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any
import torch
import torch.nn as nn
@ -111,21 +112,23 @@ class GKDTrainer(SFTTrainer):
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
args: Optional[GKDConfig] = None,
data_collator: Optional[DataCollator] = None, # type: ignore
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
model: PreTrainedModel | nn.Module | str | None = None,
teacher_model: PreTrainedModel | nn.Module | str = None,
args: GKDConfig | None = None,
data_collator: DataCollator | None = None, # type: ignore
train_dataset: Dataset | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional["PeftConfig"] = None,
formatting_func: Optional[Callable] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: "PeftConfig | None" = None,
formatting_func: Callable | None = None,
):
# Ensure Trainer does not drop non-signature columns used by the collator (e.g., "prompts")
args.remove_unused_columns = False
@ -404,7 +407,7 @@ class GKDTrainer(SFTTrainer):
return generated_tokens, new_attention_mask, new_labels
def training_step(
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None
) -> torch.Tensor:
"""
Perform a training step for the Generalized Knowledge Distillation (GKD) model.

View File

@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional, Union
from transformers import TrainingArguments
@ -262,7 +261,7 @@ class GRPOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@ -272,7 +271,7 @@ class GRPOConfig(TrainingArguments):
)
# Parameters that control the model and reference model
model_init_kwargs: Optional[Union[dict, str]] = field(
model_init_kwargs: dict | str | None = field(
default=None,
metadata={
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
@ -290,27 +289,27 @@ class GRPOConfig(TrainingArguments):
# Parameters that control the data preprocessing
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
# additional columns to compute the reward
remove_unused_columns: Optional[bool] = field(
remove_unused_columns: bool | None = field(
default=False,
metadata={
"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
},
)
num_generations: Optional[int] = field(
num_generations: int | None = field(
default=8,
metadata={
"help": "Number of generations to sample. The effective batch size (num_processes * per_device_batch_size "
"* gradient_accumulation_steps) must be evenly divisible by this value."
},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=256,
metadata={"help": "Maximum length of the generated completion."},
)
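The divisibility requirement spelled out for `num_generations` above is easy to check with concrete numbers; a sketch with illustrative values, assuming two processes and no gradient accumulation:

from trl import GRPOConfig

# Effective batch size = num_processes * per_device_train_batch_size * gradient_accumulation_steps
#                      = 2 * 4 * 1 = 8, which is evenly divisible by num_generations=8.
config = GRPOConfig(
    output_dir="grpo-example",      # placeholder
    per_device_train_batch_size=4,
    num_generations=8,              # int | None
    max_prompt_length=512,          # int | None
    max_completion_length=256,      # int | None
)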
@ -323,20 +322,20 @@ class GRPOConfig(TrainingArguments):
"is not compatible with vLLM generation."
},
)
shuffle_dataset: Optional[bool] = field(
shuffle_dataset: bool | None = field(
default=True,
metadata={"help": "Whether to shuffle the training dataset."},
)
# Parameters that control generation
generation_batch_size: Optional[int] = field(
generation_batch_size: int | None = field(
default=None,
metadata={
"help": "Batch size to use for generation. If `None`, it defaults to the effective training batch size: "
"`per_device_train_batch_size * num_processes * steps_per_generation`."
},
)
steps_per_generation: Optional[int] = field(
steps_per_generation: int | None = field(
default=None,
metadata={"help": "Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`."},
)
@ -351,21 +350,21 @@ class GRPOConfig(TrainingArguments):
"Set to 1.0 to consider all tokens."
},
)
top_k: Optional[int] = field(
top_k: int | None = field(
default=None,
metadata={
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
"top-k-filtering is disabled and all tokens are considered."
},
)
min_p: Optional[float] = field(
min_p: float | None = field(
default=None,
metadata={
"help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
"must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
},
)
generation_kwargs: Optional[dict] = field(
generation_kwargs: dict | None = field(
default=None,
metadata={
"help": "Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or "
@ -390,7 +389,7 @@ class GRPOConfig(TrainingArguments):
"implementation. This parameter is only effective when `use_vllm` is set to `False`."
},
)
cache_implementation: Optional[str] = field(
cache_implementation: str | None = field(
default=None,
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
)
@ -428,13 +427,13 @@ class GRPOConfig(TrainingArguments):
"and woken for weight sync and generation."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
vllm_guided_decoding_regex: str | None = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
vllm_server_base_url: Optional[str] = field(
vllm_server_base_url: str | None = field(
default=None,
metadata={
"help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` "
@ -491,7 +490,7 @@ class GRPOConfig(TrainingArguments):
default=0.2,
metadata={"help": "Epsilon value for clipping."},
)
delta: Optional[float] = field(
delta: float | None = field(
default=None,
metadata={
"help": "Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` "
@ -499,7 +498,7 @@ class GRPOConfig(TrainingArguments):
"method is introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)."
},
)
epsilon_high: Optional[float] = field(
epsilon_high: float | None = field(
default=None,
metadata={
"help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
@ -516,7 +515,7 @@ class GRPOConfig(TrainingArguments):
"sequence-level rewards."
},
)
reward_weights: Optional[list[float]] = field(
reward_weights: list[float] | None = field(
default=None,
metadata={
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
@ -624,11 +623,11 @@ class GRPOConfig(TrainingArguments):
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
},
)
num_completions_to_print: Optional[int] = field(
num_completions_to_print: int | None = field(
default=None,
metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."},
)
wandb_log_unique_prompts: Optional[bool] = field(
wandb_log_unique_prompts: bool | None = field(
default=False,
metadata={
"help": "Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, "

View File

@ -17,10 +17,11 @@ import os
import re
import textwrap
from collections import defaultdict, deque
from collections.abc import Callable
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, Union
from typing import Any
import datasets
import torch
@ -94,7 +95,7 @@ logger = logging.get_logger(__name__)
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]]
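A reward function in the callable form of this alias only needs to accept the prompts and completions and return one float per sample; a hypothetical example (the name and scoring rule are illustrative, not from TRL):

def length_penalty_reward(prompts: list, completions: list, **kwargs) -> list[float]:
    # One score per completion; favors shorter outputs (purely illustrative).
    return [-float(len(completion)) for completion in completions]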
class GRPOTrainer(BaseTrainer):
@ -127,7 +128,7 @@ class GRPOTrainer(BaseTrainer):
```
Args:
model (`Union[str, PreTrainedModel]`):
model (`str | PreTrainedModel`):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
@ -136,7 +137,7 @@ class GRPOTrainer(BaseTrainer):
using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
`args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
reward_funcs (`RewardFunc | list[RewardFunc]`):
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
functions with the prompts and completions and sum the rewards. Can be either:
@ -169,14 +170,14 @@ class GRPOTrainer(BaseTrainer):
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. The padding side must be set to "left". If `None`, the
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
`tokenizer.eos_token` will be used as the default.
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*):
reward_processing_classes (`PreTrainedTokenizerBase | list[PreTrainedTokenizerBase]`, *optional*):
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
- A single processing class: Used when `reward_funcs` contains only one reward function.
@ -217,16 +218,16 @@ class GRPOTrainer(BaseTrainer):
def __init__(
self,
model: Union[str, PreTrainedModel],
reward_funcs: Union[RewardFunc, list[RewardFunc]],
args: Optional[GRPOConfig] = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
peft_config: Optional["PeftConfig"] = None,
model: str | PreTrainedModel,
reward_funcs: RewardFunc | list[RewardFunc],
args: GRPOConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
peft_config: "PeftConfig | None" = None,
):
# Args
if args is None:
@ -333,7 +334,9 @@ class GRPOTrainer(BaseTrainer):
f"reward functions ({len(reward_funcs)})."
)
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
for i, (reward_processing_class, reward_func) in enumerate(
zip(reward_processing_classes, reward_funcs, strict=True)
):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
@ -661,7 +664,7 @@ class GRPOTrainer(BaseTrainer):
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler:
def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
# Returns a sampler that
# 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
# distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
@ -803,7 +806,7 @@ class GRPOTrainer(BaseTrainer):
pixel_attention_mask=None,
image_sizes=None,
token_type_ids=None,
) -> dict[str, Optional[torch.Tensor]]:
) -> dict[str, torch.Tensor | None]:
"""Compute log-probs and (optionally) entropies for each token."""
batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
all_logps = []
@ -862,7 +865,7 @@ class GRPOTrainer(BaseTrainer):
entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None
return logps, entropies
def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | None = None):
extra_prefixes = extra_prefixes or []
prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
for prefix in prefixes:
@ -986,9 +989,7 @@ class GRPOTrainer(BaseTrainer):
self.llm.reset_prefix_cache()
@profiling_decorator
def _prepare_inputs(
self, generation_batch: dict[str, Union[torch.Tensor, Any]]
) -> dict[str, Union[torch.Tensor, Any]]:
def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
# Prepares inputs for model training/evaluation by managing completion generation and batch handling.
# During training:
# - Receives the local generation batch (Per-GPU batch size × steps per generation)
@ -1033,15 +1034,15 @@ class GRPOTrainer(BaseTrainer):
reward_kwargs["trainer_state"] = self.state
for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names, strict=True)
):
with profiling_context(self, reward_func_name):
if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
texts = [p + c for p, c in zip(prompts, completions, strict=True)]
reward_inputs = reward_processing_class(
text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
@ -1075,7 +1076,7 @@ class GRPOTrainer(BaseTrainer):
rewards_per_func = gather(rewards_per_func)
return rewards_per_func
def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
def _generate_single_turn(self, prompts: list[str], images: list | None):
device = self.accelerator.device
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
@ -1084,7 +1085,7 @@ class GRPOTrainer(BaseTrainer):
kwargs = {}
if images is not None:
kwargs = {"images": images}
for prompt, image_list in zip(prompts, images):
for prompt, image_list in zip(prompts, images, strict=True):
if isinstance(prompt, list): # i.e., when using conversational data
prepare_multimodal_messages(prompt, num_images=len(image_list))
@ -1103,7 +1104,7 @@ class GRPOTrainer(BaseTrainer):
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=True)]
if self.max_prompt_length is not None:
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
@ -1245,7 +1246,7 @@ class GRPOTrainer(BaseTrainer):
if images is not None and all_images:
vllm_inputs = []
for prompt, image_list in zip(all_prompts_text, all_images):
for prompt, image_list in zip(all_prompts_text, all_images, strict=True):
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})
else:
@ -1342,13 +1343,13 @@ class GRPOTrainer(BaseTrainer):
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=True)]
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)]
logprobs = None # not used in this case
return prompt_ids, completion_ids, logprobs, forward_kwargs
def _generate(self, prompts: list[str], images: Optional[list]):
def _generate(self, prompts: list[str], images: list | None):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"
@ -1387,8 +1388,8 @@ class GRPOTrainer(BaseTrainer):
return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
) -> dict[str, Union[torch.Tensor, Any]]:
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
device = self.accelerator.device
mode = "train" if self.model.training else "eval"
@ -1507,7 +1508,7 @@ class GRPOTrainer(BaseTrainer):
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
for prompt, completion in zip(prompts, completions_text, strict=True):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
@ -1815,7 +1816,7 @@ class GRPOTrainer(BaseTrainer):
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
return loss
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None):
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
with self.compute_loss_context_manager():
@ -1823,7 +1824,7 @@ class GRPOTrainer(BaseTrainer):
loss = loss.mean().detach()
return loss, None, None
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
mode = "train" if self.model.training else "eval"
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics

View File

@ -15,7 +15,6 @@
import concurrent.futures
import logging
from abc import ABC, abstractmethod
from typing import Optional, Union
import numpy as np
from accelerate import Accelerator
@ -154,7 +153,7 @@ class BaseBinaryJudge(BaseJudge):
self,
prompts: list[str],
completions: list[str],
gold_completions: Optional[list[str]] = None,
gold_completions: list[str] | None = None,
shuffle_order: bool = True,
) -> list[int]:
"""
@ -224,7 +223,7 @@ class PairRMJudge(BasePairwiseJudge):
shuffle_order: bool = True,
return_scores: bool = False,
temperature: float = 1.0,
) -> list[Union[int, float]]:
) -> list[int | float]:
"""
Judge the completion pairs for the given prompts using the PairRM model.
@ -241,7 +240,7 @@ class PairRMJudge(BasePairwiseJudge):
Temperature for scaling logits if `return_scores` is True.
Returns:
`Union[list[int, float]]`:
`list[int | float]`:
If `return_scores` is `False`, returns a list of ranks (`0` or `1`) for each prompt, indicating which
completion is preferred. If `return_scores` is `True`, returns softmax probabilities for the first
completion.
@ -260,7 +259,7 @@ class PairRMJudge(BasePairwiseJudge):
# Shuffle the order of the completions to avoid positional bias
if shuffle_order:
flip_mask = np.random.choice([True, False], size=len(prompts))
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]
# Rank the completions
ranks = self.blender.rank(prompts, completions, return_scores=return_scores, disable_tqdm=True)
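The `strict=True` flag being added to `zip` calls across these files (available since Python 3.10) turns silent truncation of mismatched iterables into an error; a quick illustration with throwaway values:

list(zip([1, 2, 3], ["a", "b"]))               # -> [(1, 'a'), (2, 'b')]; the extra element is dropped silently
list(zip([1, 2, 3], ["a", "b"], strict=True))  # raises ValueError because the lengths differ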
@ -305,8 +304,8 @@ class HfPairwiseJudge(BasePairwiseJudge):
def __init__(
self,
model="meta-llama/Meta-Llama-3-70B-Instruct",
token: Optional[str] = None,
system_prompt: Optional[str] = None,
token: str | None = None,
system_prompt: str | None = None,
):
self.client = InferenceClient(model=model, token=token)
self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT
@ -315,7 +314,7 @@ class HfPairwiseJudge(BasePairwiseJudge):
# Shuffle the order of the completions to avoid positional bias
if shuffle_order:
flip_mask = np.random.choice([True, False], size=len(prompts))
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]
# Define a function to get the rank for a single prompt, will be called concurrently
def get_rank(prompt, candidates):
@ -359,7 +358,7 @@ class OpenAIPairwiseJudge(BasePairwiseJudge):
"""
def __init__(
self, model="gpt-4-turbo-preview", system_prompt: Optional[str] = None, max_requests: Union[int, None] = 1_000
self, model="gpt-4-turbo-preview", system_prompt: str | None = None, max_requests: int | None = 1_000
):
if not is_openai_available():
raise ValueError("OpenAI client is not installed. Please install it with 'pip install openai'.")
@ -384,7 +383,7 @@ class OpenAIPairwiseJudge(BasePairwiseJudge):
# Shuffle the order of the completions to avoid positional bias
if shuffle_order:
flip_mask = np.random.choice([True, False], size=len(prompts))
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]
# Define a function to get the rank for a single prompt, will be called concurrently
def get_rank(prompt, candidates):
@ -433,14 +432,14 @@ class AllTrueJudge(BaseBinaryJudge):
self,
prompts: list[str],
completions: list[str],
gold_completions: Optional[list[str]] = None,
gold_completions: list[str] | None = None,
shuffle_order: bool = True,
) -> list[int]:
all_binary_judgments = [
judge.judge(prompts, completions, gold_completions, shuffle_order) for judge in self.judges
]
output = []
for binary_judgments in zip(*all_binary_judgments):
for binary_judgments in zip(*all_binary_judgments, strict=True):
# Check that all values are in {0, 1, -1}
if any(binary_judgment not in {0, 1, -1} for binary_judgment in binary_judgments):
raise ValueError(

View File

@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import TrainingArguments
@ -107,7 +107,7 @@ class KTOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@ -116,18 +116,18 @@ class KTOConfig(TrainingArguments):
},
)
max_length: Optional[int] = field(
max_length: int | None = field(
default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
"collator and your model is an encoder-decoder."
},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=None,
metadata={
"help": "Maximum length of the completion. This argument is required if you want to use the default data "
@ -168,7 +168,7 @@ class KTOConfig(TrainingArguments):
"help": "Label pad token id. This argument is required if you want to use the default data collator."
},
)
padding_value: Optional[int] = field(
padding_value: int | None = field(
default=None,
metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
)
@ -186,7 +186,7 @@ class KTOConfig(TrainingArguments):
"during evaluation."
},
)
is_encoder_decoder: Optional[bool] = field(
is_encoder_decoder: bool | None = field(
default=None,
metadata={
"help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` "
@ -204,21 +204,21 @@ class KTOConfig(TrainingArguments):
"This is useful when training without the reference model to reduce the total GPU memory needed."
},
)
model_init_kwargs: Optional[dict[str, Any]] = field(
model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
"from a string."
},
)
ref_model_init_kwargs: Optional[dict[str, Any]] = field(
ref_model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
"reference model from a string."
},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)

View File

@ -16,10 +16,11 @@ import inspect
import random
import textwrap
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from operator import itemgetter
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal
import numpy as np
import pandas as pd
@ -100,19 +101,21 @@ def _tokenize(
prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False)
prompt_input_ids = prompt_tokenized["input_ids"]
prompt_attention_mask = prompt_tokenized["attention_mask"]
prompt_and_completion = [prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"])]
prompt_and_completion = [
prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"], strict=True)
]
full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False)
full_input_ids = full_tokenized["input_ids"]
full_attention_mask = full_tokenized["attention_mask"]
answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids)]
answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask)]
answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids, strict=True)]
answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask, strict=True)]
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids)]
full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids, strict=True)]
# Prepare input tokens for token by token comparison
full_input_ids = [np.array(f) for f in full_input_ids]
for full, concat in zip(full_input_ids, full_concat_input_ids):
for full, concat in zip(full_input_ids, full_concat_input_ids, strict=True):
if len(full) != len(concat):
raise ValueError(
"The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length."
@ -126,19 +129,19 @@ def _tokenize(
# If tokenized prompt is different than both prompt+answer, then it means the
# last token has changed due to merging.
for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx)):
for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx, strict=True)):
if not np.array_equal(p, f[:r]):
response_token_ids_start_idx[idx] -= 1
prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx)]
prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx)]
prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
for p, m in zip(prompt_input_ids, prompt_attention_mask):
for p, m in zip(prompt_input_ids, prompt_attention_mask, strict=True):
if len(p) != len(m):
raise ValueError("Prompt input ids and attention mask should have the same length.")
answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx)]
answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx)]
answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
output = dict(
prompt_input_ids=prompt_input_ids,
@ -335,23 +338,25 @@ class KTOTrainer(BaseTrainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
model: PreTrainedModel | nn.Module | str = None,
ref_model: PreTrainedModel | nn.Module | str | None = None,
args: KTOConfig = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
data_collator: Optional[DataCollator] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
train_dataset: Dataset | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
data_collator: DataCollator | None = None,
model_init: Callable[[], PreTrainedModel] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: dict | None = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
model_adapter_name: str | None = None,
ref_adapter_name: str | None = None,
):
if type(args) is TrainingArguments:
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
@ -870,7 +875,7 @@ class KTOTrainer(BaseTrainer):
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""
Returns the evaluation [`~torch.utils.data.DataLoader`].
@ -1057,7 +1062,7 @@ class KTOTrainer(BaseTrainer):
return (per_token_logps * loss_mask).sum(-1)
def forward(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
KL_logps = self._compute_kl_logps(model, batch)
@ -1356,7 +1361,7 @@ class KTOTrainer(BaseTrainer):
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
batch: dict[str, list | torch.LongTensor],
):
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
metrics = {}
@ -1467,11 +1472,11 @@ class KTOTrainer(BaseTrainer):
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
compute_loss_context_manager = (
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
)
@ -1493,7 +1498,7 @@ class KTOTrainer(BaseTrainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.utils.data.Sampler | None:
if dataset is None:
dataset = self.train_dataset
if dataset is None or not has_length(dataset):
@ -1550,10 +1555,10 @@ class KTOTrainer(BaseTrainer):
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
ignore_keys: list[str] | None = None,
):
if ignore_keys is None:
if hasattr(model, "config"):
@ -1590,8 +1595,8 @@ class KTOTrainer(BaseTrainer):
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
prediction_loss_only: bool | None = None,
ignore_keys: list[str] | None = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
@ -1625,7 +1630,9 @@ class KTOTrainer(BaseTrainer):
columns=["Prompt", "Policy", "Ref Model"],
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
for prompt, pol, ref in zip(
target_batch["prompt"], policy_output_decoded, ref_output_decoded, strict=True
)
],
)
if "wandb" in self.args.report_to:
@ -1644,7 +1651,7 @@ class KTOTrainer(BaseTrainer):
return initial_output
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

View File

@ -14,7 +14,6 @@
import warnings
from dataclasses import dataclass, field
from typing import Optional
@dataclass
@ -54,9 +53,9 @@ class ModelConfig:
LoRA alpha.
lora_dropout (`float`, *optional*, defaults to `0.05`):
LoRA dropout.
lora_target_modules (`Union[str, list[str]]`, *optional*):
lora_target_modules (`str | list[str]`, *optional*):
LoRA target modules.
lora_target_parameters (`Union[str, list[str]]`, *optional*):
lora_target_parameters (`str | list[str]`, *optional*):
List of target parameters for LoRA.
lora_modules_to_save (`list[str]`, *optional*):
Model layers to unfreeze & train.
@ -82,7 +81,7 @@ class ModelConfig:
Whether to use nested quantization.
"""
model_name_or_path: Optional[str] = field(
model_name_or_path: str | None = field(
default=None,
metadata={"help": "Model checkpoint for weights initialization."},
)
@ -90,7 +89,7 @@ class ModelConfig:
default="main",
metadata={"help": "Specific model version to use. It can be a branch name, a tag name, or a commit id."},
)
dtype: Optional[str] = field(
dtype: str | None = field(
default=None,
metadata={
"help": "Override the default `torch.dtype` and load the model under this dtype.",
@ -105,7 +104,7 @@ class ModelConfig:
"execute code present on the Hub on your local machine."
},
)
attn_implementation: Optional[str] = field(
attn_implementation: str | None = field(
default=None,
metadata={
"help": "Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in "
@ -128,15 +127,15 @@ class ModelConfig:
default=0.05,
metadata={"help": "LoRA dropout."},
)
lora_target_modules: Optional[list[str]] = field(
lora_target_modules: list[str] | None = field(
default=None,
metadata={"help": "LoRA target modules."},
)
lora_target_parameters: Optional[list[str]] = field(
lora_target_parameters: list[str] | None = field(
default=None,
metadata={"help": "List of target parameters for LoRA."},
)
lora_modules_to_save: Optional[list[str]] = field(
lora_modules_to_save: list[str] | None = field(
default=None,
metadata={"help": "Model layers to unfreeze & train."},
)
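Fields like these are typically set together when configuring LoRA from the command-line arguments; a minimal sketch with placeholder values (the model id and target module names are illustrative):

from trl import ModelConfig

model_args = ModelConfig(
    model_name_or_path="Qwen/Qwen2-0.5B-Instruct",  # str | None, placeholder id
    dtype="bfloat16",                               # str | None
    lora_target_modules=["q_proj", "v_proj"],       # list[str] | None, placeholder modules
    lora_modules_to_save=None,                      # list[str] | None
)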
@ -178,7 +177,7 @@ class ModelConfig:
metadata={"help": "Whether to use nested quantization."},
)
# Deprecated params
torch_dtype: Optional[str] = field(
torch_dtype: str | None = field(
default=None,
metadata={
"help": "Override the default `torch.dtype` and load the model under this dtype.",

View File

@ -13,7 +13,8 @@
# limitations under the License.
import textwrap
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any
import jinja2
import torch
@ -122,24 +123,26 @@ class NashMDTrainer(OnlineDPOTrainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
ref_model: Union[PreTrainedModel, nn.Module] = None,
reward_funcs: Union[PreTrainedModel, nn.Module, None] = None,
judge: Optional[BasePairwiseJudge] = None,
args: Optional[NashMDConfig] = None,
data_collator: Optional[Callable] = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
peft_config: Optional[dict] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
model: PreTrainedModel | nn.Module = None,
ref_model: PreTrainedModel | nn.Module = None,
reward_funcs: PreTrainedModel | nn.Module | None = None,
judge: BasePairwiseJudge | None = None,
args: NashMDConfig | None = None,
data_collator: Callable | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
peft_config: dict | None = None,
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
# Deprecated parameters
reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
reward_model: PreTrainedModel | nn.Module | None = None,
) -> None:
super().__init__(
model=model,
@ -316,7 +319,7 @@ class NashMDTrainer(OnlineDPOTrainer):
probability = self.judge.judge(
prompts,
list(zip(model_data_completions, mixture_data_completions)),
list(zip(model_data_completions, mixture_data_completions, strict=True)),
return_scores=True,
)
return torch.tensor(probability, device=model_data["input_ids"].device)
@ -426,7 +429,7 @@ class NashMDTrainer(OnlineDPOTrainer):
self.stats["mixture_coef"].append(self.mixture_coef)
def training_step(
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None
) -> torch.Tensor:
model.train()

View File

@ -14,7 +14,7 @@
import warnings
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import TrainingArguments
@ -174,7 +174,7 @@ class OnlineDPOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@ -183,13 +183,13 @@ class OnlineDPOConfig(TrainingArguments):
},
)
reward_model_path: Optional[str] = field(
reward_model_path: str | None = field(
default=None,
metadata={
"help": "Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both."
},
)
judge: Optional[str] = field(
judge: str | None = field(
default=None,
metadata={
"help": "Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both."
@ -218,14 +218,14 @@ class OnlineDPOConfig(TrainingArguments):
"Set to 1.0 to consider all tokens."
},
)
top_k: Optional[int] = field(
top_k: int | None = field(
default=None,
metadata={
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
"top-k-filtering is disabled and all tokens are considered."
},
)
min_p: Optional[float] = field(
min_p: float | None = field(
default=None,
metadata={
"help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
@ -240,7 +240,7 @@ class OnlineDPOConfig(TrainingArguments):
"to repeat tokens."
},
)
generation_kwargs: Optional[dict] = field(
generation_kwargs: dict | None = field(
default=None,
metadata={
"help": "Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or "
@ -257,11 +257,11 @@ class OnlineDPOConfig(TrainingArguments):
"implementation. This parameter is only effective when `use_vllm` is set to `False`."
},
)
cache_implementation: Optional[str] = field(
cache_implementation: str | None = field(
default=None,
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
)
missing_eos_penalty: Optional[float] = field(
missing_eos_penalty: float | None = field(
default=None,
metadata={
"help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to "
@ -304,11 +304,11 @@ class OnlineDPOConfig(TrainingArguments):
"model implementation."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
vllm_guided_decoding_regex: str | None = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
vllm_gpu_memory_utilization: Optional[float] = field(
vllm_gpu_memory_utilization: float | None = field(
default=0.55,
metadata={
"help": "Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set "
@ -326,7 +326,7 @@ class OnlineDPOConfig(TrainingArguments):
"contention with training.",
},
)
vllm_server_base_url: Optional[str] = field(
vllm_server_base_url: str | None = field(
default=None,
metadata={
"help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` "
@ -365,14 +365,14 @@ class OnlineDPOConfig(TrainingArguments):
"is not compatible with vLLM generation."
},
)
model_init_kwargs: Optional[dict[str, Any]] = field(
model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
"from a string."
},
)
reward_weights: Optional[list[float]] = field(
reward_weights: list[float] | None = field(
default=None,
metadata={
"help": "Weights for combining multiple reward functions. Must match the number of reward functions. "
@ -381,11 +381,11 @@ class OnlineDPOConfig(TrainingArguments):
)
# Deprecated parameters
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
gpu_memory_utilization: Optional[float] = field(
gpu_memory_utilization: float | None = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.25.0. Please use "

View File

@@ -16,10 +16,11 @@ import os
import re
import textwrap
import warnings
from collections.abc import Callable
from contextlib import nullcontext
from functools import wraps
from pathlib import Path
from typing import Any, Callable, Optional, Union
from typing import Any
import jinja2
import torch
@@ -96,7 +97,7 @@ logger = logging.get_logger(__name__)
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]]
class OnlineDPOTrainer(BaseTrainer):
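As a side note on the `RewardFunc` alias above: any plain callable with the `(prompts, completions) -> list[float]` shape satisfies it. A minimal, purely illustrative sketch (the function name and the length-based reward are not part of this diff):

def completion_length_reward(prompts: list, completions: list) -> list[float]:
    # Illustrative only: reward longer completions, capped at 256 characters.
    return [min(len(completion) / 256, 1.0) for completion in completions]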
@@ -104,7 +105,7 @@ class OnlineDPOTrainer(BaseTrainer):
Initialize OnlineDPOTrainer.
Args:
model (`Union[str, nn.Module, PreTrainedModel]`):
model (`str | nn.Module | PreTrainedModel`):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
@@ -118,7 +119,7 @@ class OnlineDPOTrainer(BaseTrainer):
model.
judge (`BasePairwiseJudge`):
The judge to use for pairwise comparison of model completions.
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`, *optional*):
reward_funcs (`RewardFunc | list[RewardFunc]`, *optional*):
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
functions with the prompts and completions and sum the rewards. Can be either:
@@ -135,13 +136,13 @@ class OnlineDPOTrainer(BaseTrainer):
sequences in the batch, given a dataset of paired sequences.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
The dataset to use for training.
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
The dataset to use for evaluation.
processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. If provided, will be used to automatically process the inputs
for the model, and it will be saved along with the model to make it easier to rerun an interrupted training or
reuse the fine-tuned model.
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*):
reward_processing_classes (`PreTrainedTokenizerBase | list[PreTrainedTokenizerBase]`, *optional*):
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
- A single processing class: Used when `reward_funcs` contains only one reward function.
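Putting the documented arguments together, a hedged usage sketch of this constructor; the model ID, judge, and dataset below are placeholders in the style of common TRL examples, not part of this diff:

from datasets import load_dataset
from transformers import AutoTokenizer
from trl import OnlineDPOConfig, OnlineDPOTrainer, PairRMJudge

# Placeholders: substitute your own model, judge or reward functions, and dataset.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")

trainer = OnlineDPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    judge=PairRMJudge(),
    args=OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO"),
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()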
@@ -187,24 +188,24 @@ class OnlineDPOTrainer(BaseTrainer):
def __init__(
self,
model: Union[PreTrainedModel, nn.Module, str],
ref_model: Union[PreTrainedModel, nn.Module, None] = None,
reward_funcs: Optional[Union[RewardFunc, list[RewardFunc]]] = None,
judge: Optional[BasePairwiseJudge] = None,
args: Optional[OnlineDPOConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
peft_config: Optional["PeftConfig"] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
model: PreTrainedModel | nn.Module | str,
ref_model: PreTrainedModel | nn.Module | None = None,
reward_funcs: RewardFunc | list[RewardFunc] | None = None,
judge: BasePairwiseJudge | None = None,
args: OnlineDPOConfig | None = None,
data_collator: DataCollator | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
peft_config: "PeftConfig | None" = None,
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
# Deprecated parameters
reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
reward_model: PreTrainedModel | nn.Module | None = None,
reward_processing_class: PreTrainedTokenizerBase | None = None,
) -> None:
if ref_model is model:
raise ValueError(
@@ -289,7 +290,7 @@ class OnlineDPOTrainer(BaseTrainer):
)
self.reward_processing_classes = []
for reward_processing_class_i, reward_func in zip(reward_processing_classes, reward_funcs):
for reward_processing_class_i, reward_func in zip(reward_processing_classes, reward_funcs, strict=True):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class_i is None:
reward_processing_class_i = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
@@ -653,7 +654,7 @@ class OnlineDPOTrainer(BaseTrainer):
# Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
@wraps(Trainer.get_eval_dataloader)
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader:
if eval_dataset is None and self.eval_dataset is None:
raise ValueError("Trainer: evaluation requires an eval_dataset.")
@@ -838,7 +839,7 @@ class OnlineDPOTrainer(BaseTrainer):
# Prepare vLLM inputs with images if available
if images is not None:
vllm_inputs = []
for prompt, image in zip(prompts_text, images):
for prompt, image in zip(prompts_text, images, strict=True):
if image is not None:
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
else:
@@ -968,7 +969,7 @@ class OnlineDPOTrainer(BaseTrainer):
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights([(name, param)])
def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | None = None):
"""Clean parameter names for vLLM compatibility"""
extra_prefixes = extra_prefixes or []
prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
@@ -977,7 +978,7 @@ class OnlineDPOTrainer(BaseTrainer):
return name
def process_vision_row(
self, features: dict[str, Union[list, torch.Tensor]], processing_class=None
self, features: dict[str, list | torch.Tensor], processing_class=None
) -> dict[str, list[int]]:
"""
Process a vision row for VLM models (adapted from DPO trainer)
@@ -1165,15 +1166,15 @@ class OnlineDPOTrainer(BaseTrainer):
reward_kwargs["trainer_state"] = self.state
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
zip(self.reward_funcs, self.reward_processing_classes, strict=True)
):
if isinstance(reward_func, nn.Module): # Model-based reward function
# Handle conversational vs text input
if is_conversational({"prompt": prompts[0]}):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
texts = [p + c for p, c in zip(prompts, completions, strict=True)]
# Tokenize and get reward scores
reward_inputs = reward_processing_class(
@@ -1237,7 +1238,7 @@ class OnlineDPOTrainer(BaseTrainer):
return logprobs
def training_step(
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None
) -> torch.Tensor:
model.train()
@@ -1358,7 +1359,7 @@ class OnlineDPOTrainer(BaseTrainer):
completions = [template.render(messages=completion) for completion in completions]
ranks_of_first_completion = self.judge.judge(
prompts, list(zip(completions[:batch_size], completions[batch_size:]))
prompts, list(zip(completions[:batch_size], completions[batch_size:], strict=True))
)
# convert ranks to a True/False mask:
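As a worked aside on the pairing passed to the judge (toy values, not part of this diff): the list of 2 * batch_size completions holds the first sampled completion for each prompt in the first half and the second sampled completion in the second half, so zipping the two halves yields one (first, second) pair per prompt, and `strict=True` now guards against uneven halves.

# Illustrative only: how completions are paired before judging.
completions = ["a1", "a2", "b1", "b2"]  # first half: a1, a2; second half: b1, b2
batch_size = 2
pairs = list(zip(completions[:batch_size], completions[batch_size:], strict=True))
# pairs == [("a1", "b1"), ("a2", "b2")]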

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import TrainingArguments
@@ -85,7 +85,7 @@ class ORPOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@@ -94,18 +94,18 @@ class ORPOConfig(TrainingArguments):
},
)
max_length: Optional[int] = field(
max_length: int | None = field(
default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
"collator and your model is an encoder-decoder."
},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=None,
metadata={
"help": "Maximum length of the completion. This argument is required if you want to use the default data "
@@ -129,7 +129,7 @@ class ORPOConfig(TrainingArguments):
"help": "Label pad token id. This argument is required if you want to use the default data collator."
},
)
padding_value: Optional[int] = field(
padding_value: int | None = field(
default=None,
metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
)
@@ -144,21 +144,21 @@ class ORPOConfig(TrainingArguments):
default=False,
metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."},
)
is_encoder_decoder: Optional[bool] = field(
is_encoder_decoder: bool | None = field(
default=None,
metadata={
"help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` "
"argument, you need to specify if the model returned by the callable is an encoder-decoder model."
},
)
model_init_kwargs: Optional[dict[str, Any]] = field(
model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
"from a string."
},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)

View File

@@ -16,9 +16,10 @@ import inspect
import random
import textwrap
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Literal
import numpy as np
import pandas as pd
@@ -129,20 +130,22 @@ class ORPOTrainer(BaseTrainer):
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
args: Optional[ORPOConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
model: PreTrainedModel | nn.Module | str | None = None,
args: ORPOConfig | None = None,
data_collator: DataCollator | None = None,
train_dataset: Dataset | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
model_init: Callable[[], PreTrainedModel] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: dict | None = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
):
if args.model_init_kwargs is None:
model_init_kwargs = {}
@@ -415,7 +418,7 @@ class ORPOTrainer(BaseTrainer):
attention_mask=answer_attention_mask,
)
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict:
"""Tokenize a single row from a ORPO specific dataset.
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
@@ -463,7 +466,8 @@ class ORPOTrainer(BaseTrainer):
# Make sure prompts only have one different token at most,
# and that their lengths only differ by 1 at most
num_diff_tokens = sum(
a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
a != b
for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True)
)
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
if num_diff_tokens > 1 or num_diff_len > 1:
@@ -572,11 +576,11 @@ class ORPOTrainer(BaseTrainer):
@staticmethod
def concatenated_inputs(
batch: dict[str, Union[list, torch.LongTensor]],
batch: dict[str, list | torch.LongTensor],
is_encoder_decoder: bool = False,
label_pad_token_id: int = -100,
padding_value: int = 0,
device: Optional[torch.device] = None,
device: torch.device | None = None,
) -> dict[str, torch.LongTensor]:
"""Concatenate the chosen and rejected inputs into a single tensor.
@@ -714,7 +718,7 @@ class ORPOTrainer(BaseTrainer):
return (per_token_logps * loss_mask).sum(-1)
def concatenated_forward(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
@@ -797,7 +801,7 @@ class ORPOTrainer(BaseTrainer):
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
batch: dict[str, list | torch.LongTensor],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
@@ -851,11 +855,11 @@ class ORPOTrainer(BaseTrainer):
def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
compute_loss_context_manager = (
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
)
@@ -898,10 +902,10 @@ class ORPOTrainer(BaseTrainer):
def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
ignore_keys: list[str] | None = None,
):
if not self.use_dpo_data_collator:
logger.warning(
@@ -946,8 +950,8 @@ class ORPOTrainer(BaseTrainer):
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
prediction_loss_only: bool | None = None,
ignore_keys: list[str] | None = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
@@ -973,7 +977,8 @@ class ORPOTrainer(BaseTrainer):
table = pd.DataFrame(
columns=["Prompt", "Policy"],
data=[
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
[prompt, pol[len(prompt) :]]
for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True)
],
)
if "wandb" in self.args.report_to:
@@ -992,7 +997,7 @@ class ORPOTrainer(BaseTrainer):
return initial_output
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

View File

@@ -14,7 +14,7 @@
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Literal
from ..trainer.utils import OnPolicyConfig
@@ -76,11 +76,11 @@ class PPOConfig(OnPolicyConfig):
default="EleutherAI/pythia-160m",
metadata={"help": "Path to the reward model."},
)
model_adapter_name: Optional[str] = field(
model_adapter_name: str | None = field(
default=None,
metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
)
ref_adapter_name: Optional[str] = field(
ref_adapter_name: str | None = field(
default=None,
metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
)

View File

@@ -20,7 +20,6 @@ import time
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Optional, Union
import numpy as np
import pandas as pd
@@ -146,18 +145,18 @@ class PPOTrainer(BaseTrainer):
def __init__(
self,
args: PPOConfig,
processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin,
model: nn.Module,
ref_model: Optional[nn.Module],
ref_model: nn.Module | None,
reward_model: nn.Module,
train_dataset: Dataset,
value_model: nn.Module,
data_collator: Optional[DataCollatorWithPadding] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
data_collator: DataCollatorWithPadding | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
# less commonly used
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
callbacks: Optional[list[TrainerCallback]] = None,
peft_config: Optional["PeftConfig"] = None,
callbacks: list[TrainerCallback] | None = None,
peft_config: "PeftConfig | None" = None,
) -> None:
if ref_model is model:
raise ValueError(
@@ -371,7 +370,7 @@ class PPOTrainer(BaseTrainer):
if self.ref_adapter_name:
self.model.policy.set_adapter(self.model_adapter_name or "default")
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
def save_model(self, output_dir: str | None = None, _internal_call: bool = False):
backup_model = self.model
self.model = self.model.policy # save only the policy

View File

@@ -13,7 +13,6 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
@@ -66,7 +65,7 @@ class PRMConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@@ -75,15 +74,15 @@ class PRMConfig(TrainingArguments):
},
)
max_length: Optional[int] = field(
max_length: int | None = field(
default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={"help": "Maximum length of the prompt used for truncation."},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=None,
metadata={
"help": "Maximum length of the completion used for truncation. The completion is the concatenation of the "
@@ -102,7 +101,7 @@ class PRMConfig(TrainingArguments):
default=False,
metadata={"help": "Whether to train only on the last step."},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)

View File

@@ -13,9 +13,9 @@
# limitations under the License.
import textwrap
from collections.abc import Callable
from itertools import chain
from pathlib import Path
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
@@ -99,23 +99,25 @@ class PRMTrainer(BaseTrainer):
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
args: Optional[PRMConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
model: PreTrainedModel | nn.Module | None = None,
args: PRMConfig | None = None,
data_collator: DataCollator | None = None,
train_dataset: Dataset | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
model_init: Callable[[], PreTrainedModel] | None = None,
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
None,
None,
),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: dict | None = None,
):
if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
model = prepare_peft_model(model, peft_config, args)
@@ -263,7 +265,9 @@ class PRMTrainer(BaseTrainer):
completions_ids = [completion + separator_ids for completion in completions_ids]
# Create the label
labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
labels = [
[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels, strict=True)
]
# Join the completions and labels steps
completion_ids = list(chain(*completions_ids))
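A tiny worked example of this label construction (toy ids and labels, not part of this diff): only the final token of each step keeps its step label, while every other position is masked with -100 so it is ignored by the loss.

# Illustrative only: mirrors the list comprehension above on toy data.
completions_ids = [[12, 7, 9], [4, 8]]
step_labels = [1, 0]
labels = [
    [-100] * (len(completion) - 1) + [label]
    for completion, label in zip(completions_ids, step_labels, strict=True)
]
# labels == [[-100, -100, 1], [-100, 0]]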

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any
from transformers import TrainingArguments
@@ -92,7 +92,7 @@ class RewardConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@@ -102,14 +102,14 @@ class RewardConfig(TrainingArguments):
)
# Parameters that control the model
model_init_kwargs: Optional[dict[str, Any]] = field(
model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
"the `RewardTrainer` is provided as a string."
},
)
chat_template_path: Optional[str] = field(
chat_template_path: str | None = field(
default=None,
metadata={
"help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local "
@@ -124,37 +124,37 @@ class RewardConfig(TrainingArguments):
)
# Parameters that control the data preprocessing
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
eos_token: Optional[str] = field(
eos_token: str | None = field(
default=None,
metadata={
"help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`."
},
)
pad_token: Optional[str] = field(
pad_token: str | None = field(
default=None,
metadata={
"help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
"is also `None`, it falls back to `processing_class.eos_token`."
},
)
max_length: Optional[int] = field(
max_length: int | None = field(
default=1024,
metadata={
"help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from"
"the right. If `None`, no truncation is applied."
},
)
pad_to_multiple_of: Optional[int] = field(
pad_to_multiple_of: int | None = field(
default=None,
metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
)
# Parameters that control the training
center_rewards_coefficient: Optional[float] = field(
center_rewards_coefficient: float | None = field(
default=None,
metadata={
"help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by "

Some files were not shown because too many files have changed in this diff.