Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00

Compare commits: 21 commits, comparing 7e9c6e45d5...py3.14
Commits in this comparison:
a6263a5041, a5ca7d4ba7, cfcec4af86, d66ea247dc, c97bb24098, 88eee87e11, a33e642f16, cbb41f7366, 68959ad9ea, b07df79a92, b691f39bef, b41bcbdeb7, dd56aaad40, 0ccfe5df9b, 5004c95c12, d7fe889a3f, 4e239a6122, b991fd4a87, 73107966ed, f69c919b98, 0d54019980
.github/workflows/tests.yml (vendored, 2 changed lines)
@@ -36,7 +36,7 @@ jobs:
     name: Tests
     strategy:
       matrix:
-        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
+        python-version: ['3.10', '3.11', '3.12', '3.13']
       fail-fast: false
     runs-on:
       group: aws-g4dn-2xlarge
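Dropping Python 3.9 from the test matrix goes hand in hand with the typing changes later in this diff: the `X | None` union syntax used throughout only works in runtime-evaluated annotations on Python 3.10+ (PEP 604). A minimal sketch of the failure mode this removes (illustrative, not repository code):

```python
# On Python 3.9 this module fails at import time with
# "TypeError: unsupported operand type(s) for |", because the annotation
# below is evaluated when the class body runs; on 3.10+ it is fine.
from dataclasses import dataclass, field


@dataclass
class Example:
    dataset_num_proc: int | None = field(
        default=None, metadata={"help": "Number of workers to use for dataset processing."}
    )


print(Example())  # Example(dataset_num_proc=None)
```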
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.11.10
+    rev: v0.13.3
     hooks:
       - id: ruff-check
         types_or: [ python, pyi ]
@@ -315,24 +315,6 @@ def replicate_str(string: str, n: int, sep: str = " ") -> str:
 * **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate")
 * **Type Annotations:**
   * Always include type definitions, indicating if a parameter is optional and specifying the default value.
-  * Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
-
-    E.g., for arguments that can't be `None` and aren't required:
-
-    ```python
-    foo (`int`, *optional*, defaults to `4`):
-    ```
-
-    For arguments that can be `None` and are required:
-
-    ```python
-    foo (`Optional[int]`):
-    ```
-
-    For arguments that can be `None` and aren't required:
-
-    ```python
-    foo (`Optional[int]`, *optional*):
-    ```
-
 * **String Defaults:**
   * Ensured that default string values are wrapped in double quotes:
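With the `Optional[...]` examples removed from the guide, a docstring written against the updated convention would presumably spell the union out directly. A hypothetical example (the `limit` parameter is invented here purely to show the `int | None` style):

```python
def replicate_str(string: str, n: int, sep: str = " ", limit: int | None = None) -> str:
    """
    Replicate a string `n` times, joined by `sep`.

    Args:
        string (`str`):
            String to replicate.
        n (`int`):
            Number of copies.
        sep (`str`, *optional*, defaults to `" "`):
            Separator placed between copies.
        limit (`int | None`, *optional*):
            Maximum length of the result; `None` means no limit.

    Returns:
        `str`: The replicated string.
    """
    out = sep.join([string] * n)
    return out if limit is None else out[:limit]
```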
@@ -143,7 +143,7 @@ For reinforcement learning, the blog uses a math reasoning task that we can repr
 ```python
 def strip_reasoning_accuracy_reward(
     completions: list[list[dict[str, str]]], solution: list[str], **kwargs
-) -> list[Optional[float]]:
+) -> list[float | None]:
     """Reward function that strips reasoning tags and checks mathematical accuracy.

     This function:
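A stripped-down sketch of a reward function with the same signature shape; the real tag stripping and math checking are replaced by a trivial exact-match comparison, purely to illustrate the `list[float | None]` return type:

```python
import re


def exact_match_reward(
    completions: list[list[dict[str, str]]], solution: list[str], **kwargs
) -> list[float | None]:
    """Toy reward: 1.0 if the completion (minus <think>...</think>) equals the solution, else 0.0."""
    rewards: list[float | None] = []
    for completion, sol in zip(completions, solution, strict=True):
        content = completion[0]["content"]
        stripped = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()
        rewards.append(1.0 if stripped == sol.strip() else 0.0)
    return rewards


print(exact_match_reward([[{"role": "assistant", "content": "<think>...</think>42"}]], ["42"]))  # [1.0]
```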
@@ -14,7 +14,6 @@

 import re
 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -42,7 +41,7 @@ class ScriptArguments:
     repo_id: str = field(
         default="trl-lib/hh-rlhf-helpful-base", metadata={"help": "Hugging Face repository ID to push the dataset to."}
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None, metadata={"help": "Number of workers to use for dataset processing."}
     )

@@ -50,7 +49,7 @@ class ScriptArguments:
 def common_start(str1: str, str2: str) -> str:
     # Zip the two strings and iterate over them together
     common_chars = []
-    for c1, c2 in zip(str1, str2):
+    for c1, c2 in zip(str1, str2, strict=True):
         if c1 == c2:
             common_chars.append(c1)
         else:
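Worth noting what `strict=True` changes here: `zip` now raises instead of silently stopping at the shorter iterable, so `common_start` would raise for strings of different lengths. A quick illustration of the behaviour (not trl code):

```python
# zip(..., strict=True) raises ValueError on length mismatch (Python 3.10+).
pairs = list(zip("abc", "abd", strict=True))  # fine: [('a', 'a'), ('b', 'b'), ('c', 'd')]

try:
    list(zip("abc", "abcd", strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is longer than argument 1
```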
@@ -14,7 +14,6 @@

 import ast
 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -43,7 +42,7 @@ class ScriptArguments:
         default="trl-lib/llava-instruct-mix",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -42,7 +41,7 @@ class ScriptArguments:
         default="trl-lib/lm-human-preferences-descriptiveness",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -42,7 +41,7 @@ class ScriptArguments:
         default="trl-lib/lm-human-preferences-sentiment",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )
@@ -15,7 +15,6 @@
 import re
 from dataclasses import dataclass, field
 from itertools import chain
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -44,7 +43,7 @@ class ScriptArguments:
         default="trl-lib/math_shepherd",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )

@@ -64,7 +63,7 @@ def process_example(example):
     labels = [example["label"][idx] == "+" for idx in indexes]

     # Split the inputs into steps (caution, the first step is missing here, it is the prompt)
-    steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))]
+    steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]), strict=True)]

     # Remove the last step (single ⶻ)
     steps = steps[:-1]
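The `chain([0], indexes)` / `chain(indexes, [None])` pairing slides a window over the split points, and `strict=True` is safe here because both iterators have `len(indexes) + 1` elements by construction. A small standalone illustration (toy input, not the Math-Shepherd data):

```python
from itertools import chain

inputs = "step0 | step1 | step2 | step3"
indexes = [i for i, ch in enumerate(inputs) if ch == "|"]  # split points

# Pair each start offset with the next split point (None = end of string).
steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]), strict=True)]
print(steps)  # ['step0 ', '| step1 ', '| step2 ', '| step3']
```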
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -42,7 +41,7 @@ class ScriptArguments:
         default="trl-lib/prm800k",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import features, load_dataset
 from huggingface_hub import ModelCard

@@ -42,7 +41,7 @@ class ScriptArguments:
         default="trl-lib/rlaif-v",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -42,7 +41,7 @@ class ScriptArguments:
         default="trl-lib/tldr",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -42,7 +41,7 @@ class ScriptArguments:
         default="trl-lib/tldr-preference",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -42,7 +41,7 @@ class ScriptArguments:
         default="trl-lib/ultrafeedback-prompt",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from huggingface_hub import ModelCard

@@ -79,7 +78,7 @@ class ScriptArguments:
         default="trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness",
         metadata={"help": "Hugging Face repository ID to push the dataset to."},
     )
-    dataset_num_proc: Optional[int] = field(
+    dataset_num_proc: int | None = field(
         default=None,
         metadata={"help": "Number of workers to use for dataset processing."},
     )
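All of these dataset scripts share the same argument pattern. A condensed sketch of how such a `ScriptArguments` class is typically consumed (the field names mirror the hunks above; the repo id is a placeholder, not a real default from this diff):

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    repo_id: str = field(
        default="trl-lib/some-dataset",  # placeholder, illustrative only
        metadata={"help": "Hugging Face repository ID to push the dataset to."},
    )
    dataset_num_proc: int | None = field(
        default=None, metadata={"help": "Number of workers to use for dataset processing."}
    )


if __name__ == "__main__":
    script_args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
    print(script_args.repo_id, script_args.dataset_num_proc)
```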
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 import torch
 from peft import PeftConfig, PeftModel

@@ -27,9 +26,9 @@ class ScriptArguments:
     merged model.
     """

-    adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
-    base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
-    output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})
+    adapter_model_name: str | None = field(default=None, metadata={"help": "the adapter name"})
+    base_model_name: str | None = field(default=None, metadata={"help": "the base model name"})
+    output_name: str | None = field(default=None, metadata={"help": "the merged model name"})


 parser = HfArgumentParser(ScriptArguments)
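For context, the body of this script (not shown in the hunk) presumably loads the base model, attaches the adapter named by these arguments, and merges it. A hedged sketch of that flow using the standard `peft` API, with illustrative names standing in for the parsed arguments:

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative values; the real script takes these from ScriptArguments.
base_model_name = "gpt2"
adapter_model_name = "path/to/adapter"
output_name = "merged-model"

base = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, adapter_model_name)
model = model.merge_and_unload()  # fold the LoRA weights into the base model

model.save_pretrained(output_name)
AutoTokenizer.from_pretrained(base_model_name).save_pretrained(output_name)
```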
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Any, Optional, Union
+from typing import Any

 import evaluate
 import numpy as np
@@ -41,70 +41,70 @@ class ScriptArguments:
     These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
     """

-    local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
-    resume_from_checkpoint: Optional[bool] = field(
+    local_rank: int | None = field(default=-1, metadata={"help": "Used for multi-gpu"})
+    resume_from_checkpoint: bool | None = field(
         default=False,
         metadata={"help": "If you want to resume training where it left off."},
     )
-    deepspeed: Optional[str] = field(
+    deepspeed: str | None = field(
         default=None,
         metadata={
             "help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."
         },
     )
-    per_device_train_batch_size: Optional[int] = field(default=4)
-    per_device_eval_batch_size: Optional[int] = field(default=1)
-    gradient_accumulation_steps: Optional[int] = field(default=1)
-    learning_rate: Optional[float] = field(default=2e-5)
-    weight_decay: Optional[float] = field(default=0.001)
-    model_name: Optional[str] = field(
+    per_device_train_batch_size: int | None = field(default=4)
+    per_device_eval_batch_size: int | None = field(default=1)
+    gradient_accumulation_steps: int | None = field(default=1)
+    learning_rate: float | None = field(default=2e-5)
+    weight_decay: float | None = field(default=0.001)
+    model_name: str | None = field(
         default="gpt2",
         metadata={
             "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
         },
     )
-    tokenizer_name: Optional[str] = field(
+    tokenizer_name: str | None = field(
         default=None,
         metadata={
             "help": "The tokenizer for your model, if left empty will use the default for your model",
         },
     )
-    bf16: Optional[bool] = field(
+    bf16: bool | None = field(
         default=True,
         metadata={
             "help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
         },
     )
-    num_train_epochs: Optional[int] = field(
+    num_train_epochs: int | None = field(
         default=1,
         metadata={"help": "The number of training epochs for the reward model."},
     )
-    train_subset: Optional[int] = field(
+    train_subset: int | None = field(
         default=100000,
         metadata={"help": "The size of the subset of the training data to use"},
     )
-    eval_subset: Optional[int] = field(
+    eval_subset: int | None = field(
         default=50000,
         metadata={"help": "The size of the subset of the eval data to use"},
     )
-    gradient_checkpointing: Optional[bool] = field(
+    gradient_checkpointing: bool | None = field(
         default=False,
         metadata={"help": "Enables gradient checkpointing."},
     )
-    optim: Optional[str] = field(
+    optim: str | None = field(
         default="adamw_hf",
         metadata={"help": "The optimizer to use."},
     )
-    lr_scheduler_type: Optional[str] = field(
+    lr_scheduler_type: str | None = field(
         default="linear",
         metadata={"help": "The lr scheduler"},
     )
-    max_length: Optional[int] = field(default=512)
-    eval_first_step: Optional[bool] = field(
+    max_length: int | None = field(default=512)
+    eval_first_step: bool | None = field(
         default=False,
         metadata={"help": "Whether to run eval after the first step"},
     )
-    seed: Optional[int] = field(
+    seed: int | None = field(
         default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
     )
@@ -189,7 +189,9 @@ def preprocess_function(examples):
         "input_ids_k": [],
         "attention_mask_k": [],
     }
-    for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]):
+    for question, response_j, response_k in zip(
+        examples["question"], examples["response_j"], examples["response_k"], strict=True
+    ):
         tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True)
         tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True)
@@ -229,8 +231,8 @@ eval_dataset = eval_dataset.filter(
 @dataclass
 class RewardDataCollatorWithPadding:
     tokenizer: PreTrainedTokenizerBase
-    padding: Union[bool, str, PaddingStrategy] = True
-    pad_to_multiple_of: Optional[int] = None
+    padding: bool | str | PaddingStrategy = True
+    pad_to_multiple_of: int | None = None
     return_tensors: str = "pt"

     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
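A minimal sketch of how a collator with this field layout typically pads paired examples; the real `__call__` presumably works on the `_j`/`_k` keys built in `preprocess_function` above, and this simplified stand-in uses the tokenizer's `pad` method under that assumption:

```python
from dataclasses import dataclass
from typing import Any

from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy


@dataclass
class PairwisePaddingCollator:
    tokenizer: PreTrainedTokenizerBase
    padding: bool | str | PaddingStrategy = True
    pad_to_multiple_of: int | None = None
    return_tensors: str = "pt"

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        # Pad the "chosen" (j) and "rejected" (k) sides separately, then return both batches.
        chosen = [{"input_ids": f["input_ids_j"], "attention_mask": f["attention_mask_j"]} for f in features]
        rejected = [{"input_ids": f["input_ids_k"], "attention_mask": f["attention_mask_k"]} for f in features]

        def pad(feats):
            return self.tokenizer.pad(
                feats,
                padding=self.padding,
                pad_to_multiple_of=self.pad_to_multiple_of,
                return_tensors=self.return_tensors,
            )

        return {"chosen": pad(chosen), "rejected": pad(rejected)}
```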
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 import torch
 from accelerate import Accelerator
@@ -37,37 +36,37 @@ class ScriptArguments:

     # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
     # models like gpt-neo* models are more suitable.
-    model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
-    tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"})
-    reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
-    log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
-    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
-    output_max_length: Optional[int] = field(default=128, metadata={"help": "maximum length for generation"})
-    mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
-    batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
-    ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"})
-    gradient_accumulation_steps: Optional[int] = field(
+    model_name: str | None = field(default="", metadata={"help": "the model name"})
+    tokenizer_name: str | None = field(default="", metadata={"help": "the tokenizer name"})
+    reward_model_name: str | None = field(default="", metadata={"help": "the reward model name"})
+    log_with: str | None = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
+    learning_rate: float | None = field(default=1.41e-5, metadata={"help": "the learning rate"})
+    output_max_length: int | None = field(default=128, metadata={"help": "maximum length for generation"})
+    mini_batch_size: int | None = field(default=1, metadata={"help": "the PPO minibatch size"})
+    batch_size: int | None = field(default=32, metadata={"help": "the batch size"})
+    ppo_epochs: int | None = field(default=4, metadata={"help": "the number of ppo epochs"})
+    gradient_accumulation_steps: int | None = field(
         default=4, metadata={"help": "the number of gradient accumulation steps"}
     )
-    adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
-    early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
-    target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
-    reward_baseline: Optional[float] = field(
+    adafactor: bool | None = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
+    early_stopping: bool | None = field(default=False, metadata={"help": "whether to early stop"})
+    target_kl: float | None = field(default=0.1, metadata={"help": "kl target for early stopping"})
+    reward_baseline: float | None = field(
         default=0.0,
         metadata={"help": "a baseline value that is subtracted from the reward"},
     )
-    batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"})
-    save_freq: Optional[int] = field(default=None, metadata={"help": "n steps to save the model"})
-    output_dir: Optional[str] = field(default="runs/", metadata={"help": "n steps to save the model"})
-    seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
-    steps: Optional[int] = field(default=20000, metadata={"help": "number of epochs"})
-    init_kl_coef: Optional[float] = field(
+    batched_gen: bool | None = field(default=False, metadata={"help": "whether to use the batched text gen"})
+    save_freq: int | None = field(default=None, metadata={"help": "n steps to save the model"})
+    output_dir: str | None = field(default="runs/", metadata={"help": "n steps to save the model"})
+    seed: int | None = field(default=0, metadata={"help": "the seed"})
+    steps: int | None = field(default=20000, metadata={"help": "number of epochs"})
+    init_kl_coef: float | None = field(
         default=0.2,
         metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
     )

-    adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
-    load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
+    adap_kl_ctrl: bool | None = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
+    load_in_8bit: bool | None = field(default=True, metadata={"help": "whether to load the model in 8bit"})


 parser = HfArgumentParser(ScriptArguments)
@@ -258,7 +257,7 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
     batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

     # Compute reward score (using the sentiment analysis pipeline)
-    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
+    texts = [q + r for q, r in zip(batch["query"], batch["response"], strict=True)]
     pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
     rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]
@@ -70,7 +70,7 @@ def chars_token_ratio(dataset, tokenizer, nb_examples=400):
     Estimate the average number of characters per token in the dataset.
     """
     total_characters, total_tokens = 0, 0
-    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
+    for _, example in tqdm(zip(range(nb_examples), iter(dataset), strict=True), total=nb_examples):
         text = prepare_sample_text(example)
         total_characters += len(text)
         if tokenizer.is_fast:
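The quantity being estimated here is simply total characters divided by total tokens over a sample of examples. A self-contained sketch of the same computation over plain strings, with a toy whitespace tokenizer instead of a Hugging Face tokenizer:

```python
def chars_per_token(texts: list[str], tokenize) -> float:
    """Estimate average characters per token; `tokenize` maps a string to a token list."""
    total_characters = sum(len(t) for t in texts)
    total_tokens = sum(len(tokenize(t)) for t in texts)
    return total_characters / total_tokens


print(chars_per_token(["the quick brown fox", "jumps over"], str.split))  # roughly 4.8
```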
@@ -15,7 +15,6 @@
 # 0. imports
 import os
 from dataclasses import dataclass, field
-from typing import Optional

 import torch
 from accelerate import Accelerator
@@ -34,52 +33,52 @@ class ScriptArguments:
     """

     # data parameters
-    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
+    beta: float | None = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})

     # training parameters
-    model_name_or_path: Optional[str] = field(
+    model_name_or_path: str | None = field(
         default="../sft/results/final_checkpoint",
         metadata={"help": "the location of the SFT model name or path"},
     )
-    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
-    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
-    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
-    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
-    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
+    learning_rate: float | None = field(default=5e-4, metadata={"help": "optimizer learning rate"})
+    lr_scheduler_type: str | None = field(default="cosine", metadata={"help": "the lr scheduler type"})
+    warmup_steps: int | None = field(default=100, metadata={"help": "the number of warmup steps"})
+    weight_decay: float | None = field(default=0.05, metadata={"help": "the weight decay"})
+    optimizer_type: str | None = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

-    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
-    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
-    gradient_accumulation_steps: Optional[int] = field(
+    per_device_train_batch_size: int | None = field(default=4, metadata={"help": "train batch size per device"})
+    per_device_eval_batch_size: int | None = field(default=1, metadata={"help": "eval batch size per device"})
+    gradient_accumulation_steps: int | None = field(
         default=4, metadata={"help": "the number of gradient accumulation steps"}
     )
-    gradient_checkpointing: Optional[bool] = field(
+    gradient_checkpointing: bool | None = field(
         default=True, metadata={"help": "whether to use gradient checkpointing"}
     )

-    gradient_checkpointing_use_reentrant: Optional[bool] = field(
+    gradient_checkpointing_use_reentrant: bool | None = field(
         default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
     )

-    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
-    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
-    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
+    lora_alpha: float | None = field(default=16, metadata={"help": "the lora alpha parameter"})
+    lora_dropout: float | None = field(default=0.05, metadata={"help": "the lora dropout parameter"})
+    lora_r: int | None = field(default=8, metadata={"help": "the lora r parameter"})

-    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
-    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
-    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
-    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
-    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
-    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
+    max_prompt_length: int | None = field(default=512, metadata={"help": "the maximum prompt length"})
+    max_length: int | None = field(default=1024, metadata={"help": "the maximum sequence length"})
+    max_steps: int | None = field(default=1000, metadata={"help": "max number of training steps"})
+    logging_steps: int | None = field(default=10, metadata={"help": "the logging frequency"})
+    save_steps: int | None = field(default=100, metadata={"help": "the saving frequency"})
+    eval_steps: int | None = field(default=100, metadata={"help": "the evaluation frequency"})

-    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
-    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
-    load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
-    model_dtype: Optional[str] = field(
+    output_dir: str | None = field(default="./results", metadata={"help": "the output directory"})
+    log_freq: int | None = field(default=1, metadata={"help": "the logging frequency"})
+    load_in_4bit: bool | None = field(default=True, metadata={"help": "whether to load the model in 4bit"})
+    model_dtype: str | None = field(
         default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
     )

     # instrumentation
-    report_to: Optional[str] = field(
+    report_to: str | None = field(
         default="wandb",
         metadata={
             "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
@@ -88,21 +87,21 @@ class ScriptArguments:
         },
     )
     # debug argument for distributed training
-    ignore_bias_buffers: Optional[bool] = field(
+    ignore_bias_buffers: bool | None = field(
         default=False,
         metadata={
             "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
             "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
         },
     )
-    seed: Optional[int] = field(
+    seed: int | None = field(
         default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
     )


 def get_stack_exchange_paired(
     data_dir: str = "data/rl",
-    cache_dir: Optional[str] = None,
+    cache_dir: str | None = None,
     num_proc=24,
 ) -> Dataset:
     """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
@@ -15,7 +15,6 @@
 # Fine-Tune Llama2-7b on SE paired dataset
 import os
 from dataclasses import dataclass, field
-from typing import Optional

 import torch
 from accelerate import Accelerator
@@ -38,21 +37,21 @@ from trl.trainer import ConstantLengthDataset

 @dataclass
 class ScriptArguments:
-    model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
-    dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
-    subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
-    split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
-    size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
-    streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
-    shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
-    seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
-    num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
-    use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
+    model_name: str | None = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
+    dataset_name: str | None = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
+    subset: str | None = field(default="data/finetune", metadata={"help": "the subset to use"})
+    split: str | None = field(default="train", metadata={"help": "the split to use"})
+    size_valid_set: int | None = field(default=4000, metadata={"help": "the size of the validation set"})
+    streaming: bool | None = field(default=True, metadata={"help": "whether to stream the dataset"})
+    shuffle_buffer: int | None = field(default=5000, metadata={"help": "the shuffle buffer size"})
+    seq_length: int | None = field(default=1024, metadata={"help": "the sequence length"})
+    num_workers: int | None = field(default=4, metadata={"help": "the number of workers"})
+    use_bnb: bool | None = field(default=True, metadata={"help": "whether to use BitsAndBytes"})

     # LoraConfig
-    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
-    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
-    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
+    lora_alpha: float | None = field(default=16, metadata={"help": "the lora alpha parameter"})
+    lora_dropout: float | None = field(default=0.05, metadata={"help": "the lora dropout parameter"})
+    lora_r: int | None = field(default=8, metadata={"help": "the lora r parameter"})


 parser = HfArgumentParser((ScriptArguments, SFTConfig))
@@ -82,7 +81,7 @@ def chars_token_ratio(dataset, tokenizer, nb_examples=400):
     Estimate the average number of characters per token in the dataset.
     """
     total_characters, total_tokens = 0, 0
-    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
+    for _, example in tqdm(zip(range(nb_examples), iter(dataset), strict=True), total=nb_examples):
         text = prepare_sample_text(example)
         total_characters += len(text)
         if tokenizer.is_fast:
@@ -13,7 +13,6 @@
 # limitations under the License.

 from dataclasses import dataclass, field
-from typing import Optional

 import torch
 from datasets import load_dataset
@@ -65,15 +64,15 @@ class ScriptArguments:

     # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
     # models like gpt-neo* models are more suitable.
-    model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"})
-    log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
-    learning_rate: Optional[float] = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"})
-    mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"})
-    batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
-    gradient_accumulation_steps: Optional[int] = field(
+    model_name: str | None = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"})
+    log_with: str | None = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
+    learning_rate: float | None = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"})
+    mini_batch_size: int | None = field(default=4, metadata={"help": "the PPO minibatch size"})
+    batch_size: int | None = field(default=16, metadata={"help": "the batch size"})
+    gradient_accumulation_steps: int | None = field(
         default=1, metadata={"help": "the number of gradient accumulation steps"}
     )
-    model_save_path: Optional[str] = field(
+    model_save_path: str | None = field(
         default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final",
         metadata={"help": "the path to save the model"},
     )
@@ -19,7 +19,6 @@
 # ///

 from dataclasses import dataclass, field
-from typing import Optional

 from datasets import load_dataset
 from transformers import HfArgumentParser
@@ -74,7 +73,7 @@ class ScriptArguments:
             "'meta-llama/Meta-Llama-3-70B-Instruct'."
         },
     )
-    num_examples: Optional[int] = field(default=None, metadata={"help": "Number of examples to evaluate."})
+    num_examples: int | None = field(default=None, metadata={"help": "Number of examples to evaluate."})


 if __name__ == "__main__":
@@ -103,7 +102,7 @@ if __name__ == "__main__":
     else:
         judge = HfPairwiseJudge(script_args.judge_model)

-    completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions)]
+    completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions, strict=True)]
     best_idxs = judge.judge(prompts, completions)
    model_win_rate = best_idxs.count(1) / len(best_idxs)
    print(f"Model win rate: {model_win_rate * 100:.2f}%")
@@ -159,7 +159,7 @@ if __name__ == "__main__":
     """
     rewards = []
     contents = [completion[0]["content"] for completion in completions]
-    for content, sol in zip(contents, solution):
+    for content, sol in zip(contents, solution, strict=True):
         try:
             gold_parsed = parse(sol, extraction_mode="first_match")
         except Exception:

@@ -130,7 +130,7 @@ if __name__ == "__main__":
     """
     rewards = []
     contents = [completion[0]["content"] for completion in completions]
-    for content, sol in zip(contents, solution):
+    for content, sol in zip(contents, solution, strict=True):
         try:
             gold_parsed = parse(sol, extraction_mode="first_match")
         except Exception:

@@ -146,7 +146,7 @@ if __name__ == "__main__":
     """
     rewards = []
     contents = [completion[0]["content"] for completion in completions]
-    for content, sol in zip(contents, solution):
+    for content, sol in zip(contents, solution, strict=True):
         try:
             gold_parsed = parse(sol, extraction_mode="first_match")
         except Exception:

@@ -202,7 +202,7 @@ if __name__ == "__main__":
     """
     rewards = []
     contents = [completion[0]["content"] for completion in completions]
-    for content, sol in zip(contents, solution):
+    for content, sol in zip(contents, solution, strict=True):
         try:
             gold_parsed = parse(sol, extraction_mode="first_match")
         except Exception:

@@ -75,7 +75,7 @@ def main():
     """
     rewards = []
     contents = [completion[0]["content"] for completion in completions]
-    for content, sol in zip(contents, solution):
+    for content, sol in zip(contents, solution, strict=True):
         try:
             gold_parsed = parse(sol, extraction_mode="first_match")
         except Exception:

@@ -159,7 +159,7 @@ if __name__ == "__main__":
     """
     rewards = []
     contents = [completion[0]["content"] for completion in completions]
-    for content, sol in zip(contents, solution):
+    for content, sol in zip(contents, solution, strict=True):
         try:
             gold_parsed = parse(sol, extraction_mode="first_match")
         except Exception:
@@ -21,13 +21,12 @@ classifiers = [
     "Natural Language :: English",
     "Operating System :: OS Independent",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
     "Programming Language :: Python :: 3.13"
 ]
-requires-python = ">=3.9"
+requires-python = ">=3.10"
 dependencies = [
     "accelerate>=1.4.0",
     "datasets>=3.0.0",
@@ -125,7 +124,7 @@ version = { file = "VERSION" }
 branch = true

 [tool.ruff]
-target-version = "py39"
+target-version = "py310"
 line-length = 119
 src = ["trl"]
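Raising ruff's `target-version` to `py310` is what lets its upgrade rules rewrite `Optional[X]`/`Union[...]` into `X | None`/`A | B` across the codebase, which presumably produced most of the mechanical changes in this compare. For code that still had to import on 3.9, the usual fallback would have been postponed annotation evaluation; with the floor raised to 3.10 it is no longer needed, but for reference:

```python
from __future__ import annotations  # PEP 563: annotations become lazy strings, so | parses on older interpreters


def lookup(key: str, table: dict[str, int] | None = None) -> int | None:
    return None if table is None else table.get(key)


print(lookup("a", {"a": 1}))  # 1
```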
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import torch
 from torch import nn
 from transformers import AutoModelForCausalLM
@@ -103,7 +102,7 @@ class TestActivationOffloading(TrlTestCase):
         grads2 = [p.grad.clone() for p in model.parameters()]

         # Gradients should match as NoOpManager should have prevented offloading
-        for g1, g2 in zip(grads1, grads2):
+        for g1, g2 in zip(grads1, grads2, strict=True):
             assert torch.allclose(g1, g2, rtol=1e-4, atol=1e-5)

     @require_torch_accelerator
@@ -152,5 +151,5 @@ class TestActivationOffloading(TrlTestCase):

         # Check outputs and gradients match
         assert torch.allclose(out1, out2, rtol=1e-5)
-        for g1, g2 in zip(grads1, grads2):
+        for g1, g2 in zip(grads1, grads2, strict=True):
             assert torch.allclose(g1, g2, rtol=1e-5)
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import torch
 from transformers import AutoTokenizer, GenerationConfig
@@ -116,7 +116,7 @@ class TestWinRateCallback(TrlTestCase):
         trainer.add_callback(win_rate_callback)
         trainer.train()
         winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
-        for history_row, expected_row in zip(winrate_history, self.expected_winrates):
+        for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
             assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)

     def test_without_ref_model(self):

@@ -142,7 +142,7 @@ class TestWinRateCallback(TrlTestCase):
         trainer.add_callback(win_rate_callback)
         trainer.train()
         winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
-        for history_row, expected_row in zip(winrate_history, self.expected_winrates):
+        for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
             assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)

     def test_soft_judge(self):

@@ -185,7 +185,7 @@ class TestWinRateCallback(TrlTestCase):
             for h in trainer.state.log_history
             if "eval_avg_win_prob" in h
         ]
-        for history_row, expected_row in zip(winrate_history, expected_soft_winrates):
+        for history_row, expected_row in zip(winrate_history, expected_soft_winrates, strict=True):
             assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)

     @require_peft

@@ -219,7 +219,7 @@ class TestWinRateCallback(TrlTestCase):
         trainer.add_callback(win_rate_callback)
         trainer.train()
         winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h]
-        for history_row, expected_row in zip(winrate_history, self.expected_winrates):
+        for history_row, expected_row in zip(winrate_history, self.expected_winrates, strict=True):
             assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)
@@ -12,23 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import os
-import sys
 from io import StringIO
 from unittest.mock import patch

 import pytest
 import yaml

 from .testing_utils import TrlTestCase


-@pytest.mark.skipif(
-    sys.version_info < (3, 10),
-    reason="Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests "
-    "to fail on Python <3.10.",  # let's say it's a known issue, but not expected to be fixed, because too niche
-)
 class TestCLI(TrlTestCase):
     def test_dpo(self):
         from trl.cli import main
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import torch

 from trl.trainer.dpo_trainer import DataCollatorForPreference
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import torch

 from trl.core import masked_mean, masked_var, masked_whiten
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import torch
 from datasets import load_dataset
 from parameterized import parameterized
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Callable
+from collections.abc import Callable

 from datasets import Dataset, load_dataset
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
@@ -13,7 +13,6 @@
 # limitations under the License.

 import re
-import sys
 from unittest.mock import MagicMock

 import numpy as np
@@ -1300,7 +1299,6 @@ class TestDPOTrainer(TrlTestCase):
         ]
     )
     @require_liger_kernel
-    @pytest.mark.skipif(not (sys.version_info >= (3, 10)), reason="Liger kernel is not supported on Python 3.9")
     def test_dpo_trainer_with_liger(self, beta, loss_type):
         """Test DPO trainer with Liger loss enabled across supported loss types.
@@ -69,7 +69,7 @@ class TestGKDTrainerGenerateOnPolicy(TrlTestCase):
         generated_texts = self.tokenizer.batch_decode(new_input_ids, skip_special_tokens=True)

         # Check if the generated texts start with the original prompts
-        for prompt, generated_text in zip(prompts, generated_texts):
+        for prompt, generated_text in zip(prompts, generated_texts, strict=True):
             assert generated_text.startswith(prompt), (
                 f"Generated text '{generated_text}' does not start with prompt '{prompt}'"
             )
@@ -602,7 +602,9 @@ class TestGRPOTrainer(TrlTestCase):

         def reward_func(completions, some_values, **kwargs):
             """Reward function that rewards completions with lengths closer to the values in some_values."""
-            return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)]
+            return [
+                float(abs(len(completion) - value)) for completion, value in zip(completions, some_values, strict=True)
+            ]

         training_args = GRPOConfig(
             output_dir=self.tmp_dir,
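The `some_values` argument works because GRPO-style trainers forward extra dataset columns to reward functions as keyword arguments. A standalone sketch of the reward logic being exercised here, with no trainer involved and plain strings as completions:

```python
def reward_func(completions, some_values, **kwargs):
    """Toy reward from the tests: absolute difference between completion length and a per-example target."""
    return [
        float(abs(len(completion) - value)) for completion, value in zip(completions, some_values, strict=True)
    ]


print(reward_func(["short", "a bit longer"], some_values=[5, 20]))  # [0.0, 8.0]
```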
@@ -1915,7 +1917,7 @@ class TestUpdateWithReplayBuffer:
             (item[1]["prompt_ids"].tolist(), item[1]["completion_ids"].tolist())
             for item in self.trainer.replay_buffer.heap
         ]
-        buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids)
+        buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids, strict=True)

         # Check for new entry with seq len 3 in buffer
         assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids  # excluded no-variance group
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import pytest
 import torch
 from datasets import load_dataset

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import torch
 from transformers import AutoModelForCausalLM, GenerationConfig

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 from datasets import load_dataset
 from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import torch
 from datasets import load_dataset
 from parameterized import parameterized
@@ -151,7 +151,7 @@ class TestPeftModel(TrlTestCase):
         model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir)

         # check all the weights are the same
-        for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
+        for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters(), strict=True):
             assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"

     def test_load_pretrained_peft(self):

@@ -175,7 +175,7 @@ class TestPeftModel(TrlTestCase):
         )

         # check all the weights are the same
-        for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()):
+        for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters(), strict=True):
             if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]:
                 assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}"
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import torch
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 from trl.rewards import get_soft_overlong_punishment, think_format_reward

 from .testing_utils import TrlTestCase

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 import torch
 import torch.nn as nn
 from datasets import Dataset
@@ -491,7 +491,9 @@ class TestRLOOTrainer(TrlTestCase):

         def reward_func(completions, some_values, **kwargs):
             """Reward function that rewards completions with lengths closer to the values in some_values."""
-            return [float(abs(len(completion) - value)) for completion, value in zip(completions, some_values)]
+            return [
+                float(abs(len(completion) - value)) for completion, value in zip(completions, some_values, strict=True)
+            ]

         training_args = RLOOConfig(
             output_dir=self.tmp_dir,
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from datasets import load_dataset
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
-
 import numpy as np
 import torch
 from accelerate import logging

@@ -22,7 +20,7 @@ from accelerate import logging
 logger = logging.get_logger(__name__)


-def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
+def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool | None = None) -> torch.Tensor:
     """Compute mean of tensor with a masked values."""
     if axis is not None:
         return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
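A quick usage sketch of `masked_mean` as imported in the tests above (values are averaged only where the mask is 1):

```python
import torch

from trl.core import masked_mean

values = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
print(masked_mean(values, mask))  # tensor(1.5000), only the first two entries count
```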
@@ -13,9 +13,9 @@
 # limitations under the License.

 from collections import defaultdict, deque
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from itertools import takewhile
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, TypeVar

 import numpy as np
 import pyarrow as pa
@@ -119,8 +119,8 @@ def is_conversational(example: dict[str, Any]) -> bool:

 def apply_chat_template(
     example: dict[str, list[dict[str, str]]],
-    tokenizer: Union[PreTrainedTokenizerBase, ProcessorMixin],
-    tools: Optional[list[Union[dict, Callable]]] = None,
+    tokenizer: PreTrainedTokenizerBase | ProcessorMixin,
+    tools: list[dict | Callable] | None = None,
     **template_kwargs,
 ) -> dict[str, str]:
     r"""
@@ -174,7 +174,7 @@ def apply_chat_template(
         # DeepSeek-R1 inserts a <tool_call> token when using `add_generation_prompt`, which can cause discrepancies
         # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
         # common prefix between the two. In most cases, this is a no-op.
-        prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen)))
+        prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_chosen, strict=False)))

         chosen = prompt_chosen[len(prompt) :]
     if "rejected" in example and "prompt" in example:  # explicit prompt
@@ -182,14 +182,18 @@ def apply_chat_template(
             example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
         )
         # Handle DeepSeek-R1 <tool_call> token, see the above comment for details
-        prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
+        prompt = "".join(
+            x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected, strict=False))
+        )
         rejected = prompt_rejected[len(prompt) :]
     if "completion" in example:
         prompt_completion = tokenizer.apply_chat_template(
             example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
         )
         # Handle DeepSeek-R1 <tool_call> token, see the above comment for details
-        prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
+        prompt = "".join(
+            x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion, strict=False))
+        )
         completion = prompt_completion[len(prompt) :]
     else:  # implicit prompt case
         if "chosen" in example:
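The prefix-extraction idiom changed here is easy to read in isolation: `zip(strict=False)` deliberately stops at the shorter string, and `takewhile` keeps characters only while the two strings still agree. A standalone sketch with illustrative inputs:

```python
from itertools import takewhile


def common_prefix(a: str, b: str) -> str:
    # strict=False (the default behaviour) is what we want here: stop at the shorter string.
    return "".join(x for x, _ in takewhile(lambda pair: pair[0] == pair[1], zip(a, b, strict=False)))


print(common_prefix("<prompt><tool_call>", "<prompt>The answer"))  # "<prompt>"
```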
@@ -220,7 +224,7 @@ def apply_chat_template(
 def maybe_apply_chat_template(
     example: dict[str, list[dict[str, str]]],
     tokenizer: PreTrainedTokenizerBase,
-    tools: Optional[list[Union[dict, Callable]]] = None,
+    tools: list[dict | Callable] | None = None,
     **template_kwargs: Any,
 ) -> dict[str, str]:
     r"""

@@ -242,7 +246,7 @@ def maybe_apply_chat_template(
             messages, where each message is a dictionary with keys `"role"` and `"content"`.
         tokenizer (`PreTrainedTokenizerBase`):
             Tokenizer to apply the chat template with.
-        tools (`list[Union[dict, Callable]]`, *optional*):
+        tools (`list[dict | Callable]`, *optional*):
             A list of tools (callable functions) that will be accessible to the model. If the template does not support
             function calling, this argument will have no effect.
         **template_kwargs (`Any`, *optional*):
@@ -291,7 +295,7 @@ def _unpair_row(examples: list[dict[str, list[dict[str, str]]]]) -> list[dict[st


 def unpair_preference_dataset(
-    dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None
+    dataset: DatasetType, num_proc: int | None = None, desc: str | None = None
 ) -> DatasetType:
     r"""
     Unpair a preference dataset.

@@ -334,7 +338,7 @@ def unpair_preference_dataset(


 def maybe_unpair_preference_dataset(
-    dataset: DatasetType, num_proc: Optional[int] = None, desc: Optional[str] = None
+    dataset: DatasetType, num_proc: int | None = None, desc: str | None = None
 ) -> DatasetType:
     r"""
     Unpair a preference dataset if it is paired.
@@ -565,7 +569,7 @@ def _pack_bfd(examples: pa.Table, seq_length: int) -> pa.Table:

     # Bin is represented as a dict (of example ids and sum of their lengths) to allow in-place updates
     bins: list[dict] = []
-    for length, idx in zip(lengths.field(0).to_numpy(), lengths.field(1).to_numpy()):
+    for length, idx in zip(lengths.field(0).to_numpy(), lengths.field(1).to_numpy(), strict=True):
         space = segment_tree.search(length)

         if space < seq_length:

@@ -627,7 +631,7 @@ def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:


 def pack_dataset(
-    dataset: DatasetType, seq_length: int, strategy: str = "bfd", map_kwargs: Optional[dict[str, Any]] = None
+    dataset: DatasetType, seq_length: int, strategy: str = "bfd", map_kwargs: dict[str, Any] | None = None
 ) -> DatasetType:
     r"""
     Pack sequences in a dataset into chunks of size `seq_length`.
@ -682,9 +686,7 @@ def pack_dataset(
|
||||
return dataset
|
||||
|
||||
|
||||
def truncate_dataset(
|
||||
dataset: DatasetType, max_length: int, map_kwargs: Optional[dict[str, Any]] = None
|
||||
) -> DatasetType:
|
||||
def truncate_dataset(dataset: DatasetType, max_length: int, map_kwargs: dict[str, Any] | None = None) -> DatasetType:
|
||||
r"""
|
||||
Truncate sequences in a dataset to a specified `max_length`.
|
||||
|
||||
|
@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import logging

import torch

@ -13,14 +13,13 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from ...trainer.grpo_config import GRPOConfig as _GRPOConfig


@dataclass
class GFPOConfig(_GRPOConfig):
num_remains_in_group: Optional[int] = field(
num_remains_in_group: int | None = field(
default=None,
metadata={
"help": "number inputs remains after group filter function, `'num_remains_in_group'` must be >=2 if given."

@ -13,7 +13,8 @@
# limitations under the License.

import logging
from typing import Any, Callable
from collections.abc import Callable
from typing import Any

import torch
from accelerate.utils import gather_object
@ -183,7 +184,7 @@ class GFPOTrainer(_GRPOTrainer):
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
for prompt, completion in zip(prompts, completions_text, strict=True):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:

@ -13,7 +13,7 @@
# limitations under the License.

import heapq
from typing import Any, Optional, Union
from typing import Any

import torch
from accelerate.utils import gather_object
@ -35,7 +35,7 @@ class ReplayBuffer:
self.heap = [] # Min-heap of (score, data) tuples

def add(self, scores: list[float], data: list[dict]):
for score, datum in zip(scores, data):
for score, datum in zip(scores, data, strict=True):
if len(self.heap) < self.max_size:
heapq.heappush(self.heap, (score, datum))
else:
@ -58,13 +58,13 @@ class ReplayBuffer:


class GRPOWithReplayBufferTrainer(GRPOTrainer):
def __init__(self, args: Optional[GRPOWithReplayBufferConfig] = None, **kwargs):
def __init__(self, args: GRPOWithReplayBufferConfig | None = None, **kwargs):
super().__init__(args=args, **kwargs)
self.replay_buffer = ReplayBuffer(args.replay_buffer_size) if args.replay_buffer_size > 0 else None

def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
) -> dict[str, Union[torch.Tensor, Any]]:
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

@ -89,7 +89,9 @@ class GRPOWithReplayBufferTrainer(GRPOTrainer):

# Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
# to re-tokenize completions if the reward is computed from tokens.
completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
completion_ids_list = [
row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool(), strict=True)
]

# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
@ -162,7 +164,7 @@ class GRPOWithReplayBufferTrainer(GRPOTrainer):
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
for prompt, completion in zip(prompts, completions_text, strict=True):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
@ -339,9 +341,9 @@ class GRPOWithReplayBufferTrainer(GRPOTrainer):
completion_mask: torch.Tensor,
forward_kwargs: dict,
optional_vision_fields: list[str] = None,
old_per_token_logps: Optional[torch.Tensor] = None,
ref_per_token_logps: Optional[torch.Tensor] = None,
importance_sampling_ratio: Optional[float] = None,
old_per_token_logps: torch.Tensor | None = None,
ref_per_token_logps: torch.Tensor | None = None,
importance_sampling_ratio: float | None = None,
) -> None:
"""
Update the replay buffer with groups that have reward variance (std > 0).
@ -463,9 +465,9 @@ class GRPOWithReplayBufferTrainer(GRPOTrainer):
completion_mask: torch.Tensor,
forward_kwargs: dict,
num_items_in_batch: int,
old_per_token_logps: Optional[torch.Tensor] = None,
ref_per_token_logps: Optional[torch.Tensor] = None,
importance_sampling_ratio: Optional[float] = None,
old_per_token_logps: torch.Tensor | None = None,
ref_per_token_logps: torch.Tensor | None = None,
importance_sampling_ratio: float | None = None,
) -> None:
"""
Update current batch data with samples from replay buffer.

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import torch

from trl import GRPOTrainer as _GRPOTrainer

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any

import torch
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed
@ -47,13 +48,13 @@ class BestOfNSampler:
def __init__(
self,
model: PreTrainedModelWrapper,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
queries_to_scores: Callable[[list[str]], list[float]],
length_sampler: Any,
sample_size: int = 4,
seed: Optional[int] = None,
seed: int | None = None,
n_candidates: int = 1,
generation_config: Optional[GenerationConfig] = None,
generation_config: GenerationConfig | None = None,
) -> None:
if seed is not None:
set_seed(seed)
@ -78,9 +79,9 @@ class BestOfNSampler:

def generate(
self,
tokenized_query: Union[list[int], torch.Tensor, list[torch.Tensor], list[list[int]]],
tokenized_query: list[int] | torch.Tensor | list[torch.Tensor] | list[list[int]],
skip_special_tokens: bool = True,
device: Optional[Union[str, torch.device]] = None,
device: str | torch.device | None = None,
**generation_kwargs,
) -> list[list[str]]:
"""

@ -13,7 +13,8 @@
# limitations under the License.

import logging
from typing import Callable, Literal, Optional
from collections.abc import Callable
from typing import Literal

import datasets
from datasets import Dataset, Value
@ -36,7 +37,7 @@ else:


def conversations_formatting_function(
tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: Optional[list] = None
tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: list | None = None
):
r"""
return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the
@ -84,8 +85,8 @@ def instructions_formatting_function(tokenizer: AutoTokenizer):


def get_formatting_func_from_dataset(
dataset: Dataset, tokenizer: AutoTokenizer, tools: Optional[list] = None
) -> Optional[Callable]:
dataset: Dataset, tokenizer: AutoTokenizer, tools: list | None = None
) -> Callable | None:
r"""
Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
- `ChatML` with [{"role": str, "content": str}]

@ -18,7 +18,6 @@ import logging
import socket
import time
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse

import torch
@ -105,7 +104,7 @@ class VLLMClient:

def __init__(
self,
base_url: Optional[str] = None,
base_url: str | None = None,
host: str = "0.0.0.0",
server_port: int = 8000,
group_port: int = 51216,
@ -170,7 +169,7 @@ class VLLMClient:
def generate(
self,
prompts: list[str],
images: Optional[list] = None,
images: list | None = None,
n: int = 1,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
@ -178,8 +177,8 @@
top_k: int = -1,
min_p: float = 0.0,
max_tokens: int = 16,
guided_decoding_regex: Optional[str] = None,
generation_kwargs: Optional[dict] = None,
guided_decoding_regex: str | None = None,
generation_kwargs: dict | None = None,
) -> list[list[int]]:
"""
Generates model completions for the provided prompts.
@ -250,7 +249,7 @@ class VLLMClient:
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")

def init_communicator(self, device: Union[torch.device, str, int] = 0):
def init_communicator(self, device: torch.device | str | int = 0):
"""
Initializes the weight update group in a distributed setup for model synchronization.

@ -16,7 +16,6 @@ import json
import logging
import os
from copy import deepcopy
from typing import Optional

import torch
import torch.nn as nn
@ -392,7 +391,7 @@ class PreTrainedModelWrapper(nn.Module):
object to handle corner cases when running scripts in distributed environments.

Returns:
current_device (`Union[int, str]`):
current_device (`int | str`):
The current device.
"""
state = PartialState()
@ -590,7 +589,7 @@ class PreTrainedModelWrapper(nn.Module):


def create_reference_model(
model: PreTrainedModelWrapper, num_shared_layers: Optional[int] = None, pattern: Optional[str] = None
model: PreTrainedModelWrapper, num_shared_layers: int | None = None, pattern: str | None = None
) -> PreTrainedModelWrapper:
"""
Creates a static reference copy of a model. Note that model will be in `.eval()` mode.

@ -18,7 +18,7 @@ from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -87,8 +87,8 @@ FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}
|
||||
def setup_chat_format(
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
format: Optional[Literal["chatml"]] = "chatml",
|
||||
resize_to_multiple_of: Optional[int] = None,
|
||||
format: Literal["chatml"] | None = "chatml",
|
||||
resize_to_multiple_of: int | None = None,
|
||||
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
|
||||
# docstyle-ignore
|
||||
"""
|
||||
@ -104,7 +104,7 @@ def setup_chat_format(
|
||||
Args:
|
||||
model (`~transformers.PreTrainedModel`): The model to be modified.
|
||||
tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
|
||||
format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
|
||||
format (`Literal["chatml"] | None`): The format to be set. Defaults to "chatml".
|
||||
resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None.
|
||||
|
||||
Returns:
|
||||
@ -164,7 +164,7 @@ def clone_chat_template(
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
source_tokenizer_path: str,
|
||||
resize_to_multiple_of: Optional[int] = 64,
|
||||
resize_to_multiple_of: int | None = 64,
|
||||
) -> tuple[PreTrainedModel, PreTrainedTokenizer, list[int]]:
|
||||
"""
|
||||
Clones a chat template from a source tokenizer to the target tokenizer and updates the model accordingly.
|
||||
@ -306,7 +306,7 @@ def add_hooks(model: "DeepSpeedEngine") -> None:
|
||||
|
||||
@contextmanager
|
||||
def unwrap_model_for_generation(
|
||||
model: Union["DistributedDataParallel", "DeepSpeedEngine"],
|
||||
model: "DistributedDataParallel | DeepSpeedEngine",
|
||||
accelerator: "Accelerator",
|
||||
gather_deepspeed3_params: bool = True,
|
||||
):
|
||||
@ -314,7 +314,7 @@ def unwrap_model_for_generation(
|
||||
Context manager to unwrap distributed or accelerated models for generation tasks.
|
||||
|
||||
Args:
|
||||
model (`Union[DistributedDataParallel, DeepSpeedEngine]`):
|
||||
model (`DistributedDataParallel | DeepSpeedEngine`):
|
||||
Model to be unwrapped.
|
||||
accelerator (`~accelerate.Accelerator`):
|
||||
Accelerator instance managing the model.
|
||||
@ -472,7 +472,7 @@ class _ForwardRedirection:
|
||||
|
||||
|
||||
def enable_gradient_checkpointing(
|
||||
model: PreTrainedModel, gradient_checkpointing_kwargs: Optional[dict]
|
||||
model: PreTrainedModel, gradient_checkpointing_kwargs: dict | None
|
||||
) -> PreTrainedModel:
|
||||
"""Enables gradient checkpointing for the model."""
|
||||
# Enable gradient checkpointing on the base model for PEFT
|
||||
@ -511,7 +511,7 @@ def peft_module_casting_to_bf16(model):
|
||||
|
||||
|
||||
def prepare_peft_model(
|
||||
model: PreTrainedModel, peft_config: Optional["PeftConfig"], args: TrainingArguments
|
||||
model: PreTrainedModel, peft_config: "PeftConfig | None", args: TrainingArguments
|
||||
) -> PreTrainedModel:
|
||||
"""Prepares a model for PEFT training."""
|
||||
if not is_peft_available():
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
|
||||
|
@ -62,7 +62,6 @@ python trl/scripts/dpo.py \
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import logging
|
||||
@ -167,7 +166,7 @@ def main(script_args, training_args, model_args, dataset_args):
|
||||
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
|
||||
|
||||
|
||||
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
|
||||
def make_parser(subparsers: argparse._SubParsersAction | None = None):
|
||||
dataclass_types = (ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types)
|
||||
|
@ -26,7 +26,6 @@ import importlib
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from accelerate import logging
|
||||
from datasets import load_dataset
|
||||
@ -73,14 +72,14 @@ class GRPOScriptArguments(ScriptArguments):
|
||||
- any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
|
||||
"""
|
||||
|
||||
reward_model_name_or_path: Optional[str] = field(
|
||||
reward_model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
|
||||
"local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
|
||||
},
|
||||
)
|
||||
reward_funcs: Optional[list[str]] = field(
|
||||
reward_funcs: list[str] | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Reward functions to use. Supported values are: `think_format_reward`, "
|
||||
@ -153,7 +152,7 @@ def main(script_args, training_args, model_args, dataset_args):
|
||||
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
|
||||
|
||||
|
||||
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
|
||||
def make_parser(subparsers: argparse._SubParsersAction | None = None):
|
||||
dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types)
|
||||
|
@ -66,7 +66,6 @@ python trl/scripts/kto.py \
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from accelerate import logging
|
||||
from datasets import load_dataset
|
||||
@ -147,7 +146,7 @@ def main(script_args, training_args, model_args, dataset_args):
|
||||
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
|
||||
|
||||
|
||||
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
|
||||
def make_parser(subparsers: argparse._SubParsersAction | None = None):
|
||||
dataclass_types = (ScriptArguments, KTOConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types)
|
||||
|
@ -23,7 +23,6 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from accelerate import logging
|
||||
from datasets import load_dataset
|
||||
@ -87,7 +86,7 @@ def main(script_args, training_args, model_args, dataset_args):
|
||||
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
|
||||
|
||||
|
||||
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
|
||||
def make_parser(subparsers: argparse._SubParsersAction | None = None):
|
||||
dataclass_types = (ScriptArguments, RewardConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser(
|
||||
|
@ -26,7 +26,6 @@ import importlib
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from accelerate import logging
|
||||
from datasets import load_dataset
|
||||
@ -73,14 +72,14 @@ class RLOOScriptArguments(ScriptArguments):
|
||||
- any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
|
||||
"""
|
||||
|
||||
reward_model_name_or_path: Optional[str] = field(
|
||||
reward_model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
|
||||
"local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
|
||||
},
|
||||
)
|
||||
reward_funcs: Optional[list[str]] = field(
|
||||
reward_funcs: list[str] | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Reward functions to use. Supported values are: `think_format_reward`, "
|
||||
@ -153,7 +152,7 @@ def main(script_args, training_args, model_args, dataset_args):
|
||||
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
|
||||
|
||||
|
||||
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
|
||||
def make_parser(subparsers: argparse._SubParsersAction | None = None):
|
||||
dataclass_types = (RLOOScriptArguments, RLOOConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("rloo", help="Run the RLOO training script", dataclass_types=dataclass_types)
|
||||
|
@ -64,7 +64,6 @@ python trl/scripts/sft.py \
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from accelerate import logging
|
||||
from datasets import load_dataset
|
||||
@ -158,7 +157,7 @@ def main(script_args, training_args, model_args, dataset_args):
|
||||
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
|
||||
|
||||
|
||||
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
|
||||
def make_parser(subparsers: argparse._SubParsersAction | None = None):
|
||||
dataclass_types = (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
|
||||
|
@ -21,7 +21,6 @@ import subprocess
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
import datasets
|
||||
import yaml
|
||||
@ -80,11 +79,11 @@ class DatasetConfig:
|
||||
"""
|
||||
|
||||
path: str
|
||||
name: Optional[str] = None
|
||||
data_dir: Optional[str] = None
|
||||
data_files: Optional[Union[str, list[str], dict[str, str]]] = None
|
||||
name: str | None = None
|
||||
data_dir: str | None = None
|
||||
data_files: str | list[str] | dict[str, str] | None = None
|
||||
split: str = "train"
|
||||
columns: Optional[list[str]] = None
|
||||
columns: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -135,7 +134,7 @@ class DatasetMixtureConfig:
|
||||
default=False,
|
||||
metadata={"help": "Whether to stream the datasets. If True, the datasets will be loaded in streaming mode."},
|
||||
)
|
||||
test_split_size: Optional[float] = field(
|
||||
test_split_size: float | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Size of the test split. Refer to the `test_size` parameter in the `datasets.train_test_split` "
|
||||
@ -177,11 +176,11 @@ class ScriptArguments:
|
||||
https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
|
||||
"""
|
||||
|
||||
dataset_name: Optional[str] = field(
|
||||
dataset_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Path or name of the dataset to load. If `datasets` is provided, this will be ignored."},
|
||||
)
|
||||
dataset_config: Optional[str] = field(
|
||||
dataset_config: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Dataset configuration name. Corresponds to the `name` argument of the `datasets.load_dataset` "
|
||||
@ -250,7 +249,7 @@ class TrlParser(HfArgumentParser):
|
||||
configurations, while also supporting configuration file loading and environment variable management.
|
||||
|
||||
Args:
|
||||
dataclass_types (`Union[DataClassType, Iterable[DataClassType]]`, *optional*):
|
||||
dataclass_types (`DataClassType | Iterable[DataClassType]`, *optional*):
|
||||
Dataclass types to use for argument parsing.
|
||||
**kwargs:
|
||||
Additional keyword arguments passed to the [`transformers.HfArgumentParser`] constructor.
|
||||
@ -294,7 +293,7 @@ class TrlParser(HfArgumentParser):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataclass_types: Optional[Union[DataClassType, Iterable[DataClassType]]] = None,
|
||||
dataclass_types: DataClassType | Iterable[DataClassType] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Make sure dataclass_types is an iterable
|
||||
@ -315,7 +314,7 @@ class TrlParser(HfArgumentParser):
|
||||
|
||||
def parse_args_and_config(
|
||||
self,
|
||||
args: Optional[Iterable[str]] = None,
|
||||
args: Iterable[str] | None = None,
|
||||
return_remaining_strings: bool = False,
|
||||
fail_with_unknown_args: bool = True,
|
||||
) -> tuple[DataClass, ...]:
|
||||
|
@ -23,7 +23,6 @@ from io import BytesIO
|
||||
from itertools import chain
|
||||
from multiprocessing import Pipe, Process
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed.distributed_c10d as c10d
|
||||
@ -239,7 +238,7 @@ class ScriptArguments:
|
||||
model: str = field(
|
||||
metadata={"help": "Model name or path to load the model from."},
|
||||
)
|
||||
revision: Optional[str] = field(
|
||||
revision: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Revision to use for the model. If not specified, the default branch will be used."},
|
||||
)
|
||||
@ -275,7 +274,7 @@ class ScriptArguments:
|
||||
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
||||
},
|
||||
)
|
||||
max_model_len: Optional[int] = field(
|
||||
max_model_len: int | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced "
|
||||
@ -283,14 +282,14 @@ class ScriptArguments:
|
||||
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
||||
},
|
||||
)
|
||||
enable_prefix_caching: Optional[bool] = field(
|
||||
enable_prefix_caching: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the "
|
||||
"hardware support this feature."
|
||||
},
|
||||
)
|
||||
enforce_eager: Optional[bool] = field(
|
||||
enforce_eager: bool | None = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always "
|
||||
@ -487,7 +486,7 @@ def main(script_args: ScriptArguments):
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompts: list[str]
|
||||
images: Optional[list[str]] = None
|
||||
images: list[str] | None = None
|
||||
n: int = 1
|
||||
repetition_penalty: float = 1.0
|
||||
temperature: float = 1.0
|
||||
@ -495,7 +494,7 @@ def main(script_args: ScriptArguments):
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
max_tokens: int = 16
|
||||
guided_decoding_regex: Optional[str] = None
|
||||
guided_decoding_regex: str | None = None
|
||||
generation_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
@ -549,7 +548,7 @@ def main(script_args: ScriptArguments):
|
||||
request.images = request.images or [None] * len(request.prompts)
|
||||
|
||||
prompts = []
|
||||
for prompt, image in zip(request.prompts, request.images):
|
||||
for prompt, image in zip(request.prompts, request.images, strict=True):
|
||||
row = {"prompt": prompt}
|
||||
if image is not None:
|
||||
row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
|
||||
@ -579,7 +578,7 @@ def main(script_args: ScriptArguments):
|
||||
chunked_prompts = chunk_list(prompts, script_args.data_parallel_size)
|
||||
|
||||
# Send the prompts to each worker
|
||||
for connection, prompts in zip(connections, chunked_prompts):
|
||||
for connection, prompts in zip(connections, chunked_prompts, strict=True):
|
||||
# When the number of prompts is less than data_parallel_size, some workers will receive empty prompts.
|
||||
# However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply
|
||||
# with vLLM's requirement, and we later ignore the result.
|
||||
@ -592,7 +591,7 @@ def main(script_args: ScriptArguments):
|
||||
all_outputs = [connection.recv() for connection in connections]
|
||||
|
||||
# Handle empty prompts (see above)
|
||||
all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts) if prompts]
|
||||
all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts, strict=True) if prompts]
|
||||
|
||||
# Flatten and combine all results
|
||||
all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list
|
||||
@ -691,7 +690,7 @@ def main(script_args: ScriptArguments):
|
||||
uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level)
|
||||
|
||||
|
||||
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
|
||||
def make_parser(subparsers: argparse._SubParsersAction | None = None):
|
||||
if subparsers is not None:
|
||||
parser = subparsers.add_parser("vllm-serve", help="Run the vLLM serve script", dataclass_types=ScriptArguments)
|
||||
else:
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers import Trainer, is_wandb_available
|
||||
|
||||
@ -32,9 +31,9 @@ class BaseTrainer(Trainer):
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Optional[Union[str, list[str]]] = None,
|
||||
model_name: str | None = None,
|
||||
dataset_name: str | None = None,
|
||||
tags: str | list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
@ -93,7 +93,7 @@ class BCOConfig(TrainingArguments):
|
||||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
||||
},
|
||||
)
|
||||
bf16: Optional[bool] = field(
|
||||
bf16: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
|
||||
@ -102,21 +102,21 @@ class BCOConfig(TrainingArguments):
|
||||
},
|
||||
)
|
||||
|
||||
max_length: Optional[int] = field(
|
||||
max_length: int | None = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "Maximum length of the sequences (prompt + completion) in the batch. "
|
||||
"This argument is required if you want to use the default data collator."
|
||||
},
|
||||
)
|
||||
max_prompt_length: Optional[int] = field(
|
||||
max_prompt_length: int | None = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help": "Maximum length of the prompt. "
|
||||
"This argument is required if you want to use the default data collator."
|
||||
},
|
||||
)
|
||||
max_completion_length: Optional[int] = field(
|
||||
max_completion_length: int | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Maximum length of the completion. This argument is required if you want to use the "
|
||||
@ -136,7 +136,7 @@ class BCOConfig(TrainingArguments):
|
||||
"help": "Label pad token id. This argument is required if you want to use the default data collator."
|
||||
},
|
||||
)
|
||||
padding_value: Optional[int] = field(
|
||||
padding_value: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
|
||||
)
|
||||
@ -159,7 +159,7 @@ class BCOConfig(TrainingArguments):
|
||||
"to W&B during evaluation."
|
||||
},
|
||||
)
|
||||
is_encoder_decoder: Optional[bool] = field(
|
||||
is_encoder_decoder: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "When using the `model_init` argument (callable) to instantiate the model instead of the "
|
||||
@ -175,21 +175,21 @@ class BCOConfig(TrainingArguments):
|
||||
"needed."
|
||||
},
|
||||
)
|
||||
model_init_kwargs: Optional[dict[str, Any]] = field(
|
||||
model_init_kwargs: dict[str, Any] | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
|
||||
"model from a string."
|
||||
},
|
||||
)
|
||||
ref_model_init_kwargs: Optional[dict[str, Any]] = field(
|
||||
ref_model_init_kwargs: dict[str, Any] | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
|
||||
"reference model from a string."
|
||||
},
|
||||
)
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
dataset_num_proc: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of processes to use for processing the dataset."},
|
||||
)
|
||||
|
@ -17,10 +17,11 @@ import os
|
||||
import random
|
||||
import textwrap
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -89,25 +90,27 @@ CLF_NAME = "clf.pkl"
|
||||
def _tokenize(
|
||||
batch: dict[str, list[Any]],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
embedding_tokenizer: Optional["PreTrainedTokenizer"] = None,
|
||||
embedding_tokenizer: "PreTrainedTokenizer | None" = None,
|
||||
) -> dict[str, list[Any]]:
|
||||
"""Tokenize a batch from a BCO specific dataset."""
|
||||
prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False)
|
||||
prompt_input_ids = prompt_tokenized["input_ids"]
|
||||
prompt_attention_mask = prompt_tokenized["attention_mask"]
|
||||
prompt_and_completion = [prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"])]
|
||||
prompt_and_completion = [
|
||||
prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"], strict=True)
|
||||
]
|
||||
full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False)
|
||||
full_input_ids = full_tokenized["input_ids"]
|
||||
full_attention_mask = full_tokenized["attention_mask"]
|
||||
|
||||
answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids)]
|
||||
answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask)]
|
||||
answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids, strict=True)]
|
||||
answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask, strict=True)]
|
||||
|
||||
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
||||
full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids)]
|
||||
full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids, strict=True)]
|
||||
# Prepare input tokens for token by token comparison
|
||||
full_input_ids = [np.array(f) for f in full_input_ids]
|
||||
for full, concat in zip(full_input_ids, full_concat_input_ids):
|
||||
for full, concat in zip(full_input_ids, full_concat_input_ids, strict=True):
|
||||
if len(full) != len(concat):
|
||||
raise ValueError(
|
||||
"The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length."
|
||||
@ -121,19 +124,19 @@ def _tokenize(
|
||||
|
||||
# If tokenized prompt is different than both prompt+answer, then it means the
|
||||
# last token has changed due to merging.
|
||||
for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx)):
|
||||
for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx, strict=True)):
|
||||
if not np.array_equal(p, f[:r]):
|
||||
response_token_ids_start_idx[idx] -= 1
|
||||
|
||||
prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx)]
|
||||
prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx)]
|
||||
prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
|
||||
prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
|
||||
|
||||
for p, m in zip(prompt_input_ids, prompt_attention_mask):
|
||||
for p, m in zip(prompt_input_ids, prompt_attention_mask, strict=True):
|
||||
if len(p) != len(m):
|
||||
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
||||
|
||||
answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx)]
|
||||
answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx)]
|
||||
answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
|
||||
answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
|
||||
|
||||
output = dict(
|
||||
prompt_input_ids=prompt_input_ids,
|
||||
@ -340,25 +343,27 @@ class BCOTrainer(BaseTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module, str] = None,
|
||||
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
model: PreTrainedModel | nn.Module | str = None,
|
||||
ref_model: PreTrainedModel | nn.Module | str | None = None,
|
||||
args: BCOConfig = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
train_dataset: Dataset | None = None,
|
||||
eval_dataset: Dataset | dict[str, Dataset] | None = None,
|
||||
processing_class: PreTrainedTokenizerBase
|
||||
| BaseImageProcessor
|
||||
| FeatureExtractionMixin
|
||||
| ProcessorMixin
|
||||
| None = None,
|
||||
data_collator: DataCollator | None = None,
|
||||
model_init: Callable[[], PreTrainedModel] | None = None,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
model_adapter_name: Optional[str] = None,
|
||||
ref_adapter_name: Optional[str] = None,
|
||||
embedding_func: Optional[Callable] = None,
|
||||
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
peft_config: dict | None = None,
|
||||
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
|
||||
model_adapter_name: str | None = None,
|
||||
ref_adapter_name: str | None = None,
|
||||
embedding_func: Callable | None = None,
|
||||
embedding_tokenizer: PreTrainedTokenizerBase | None = None,
|
||||
):
|
||||
if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()):
|
||||
raise ImportError(
|
||||
@ -810,7 +815,7 @@ class BCOTrainer(BaseTrainer):
|
||||
return embeddings
|
||||
|
||||
def _get_prompt_embeddings(
|
||||
self, batch: dict[str, Union[list, torch.LongTensor]]
|
||||
self, batch: dict[str, list | torch.LongTensor]
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Extract embeddings from frozen embedding model"""
|
||||
|
||||
@ -939,7 +944,7 @@ class BCOTrainer(BaseTrainer):
|
||||
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
||||
"""
|
||||
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
||||
|
||||
@ -1080,7 +1085,7 @@ class BCOTrainer(BaseTrainer):
|
||||
return (per_token_logps * loss_mask).sum(-1)
|
||||
|
||||
def forward(
|
||||
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
||||
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
model_kwargs = (
|
||||
{
|
||||
@ -1143,8 +1148,8 @@ class BCOTrainer(BaseTrainer):
|
||||
policy_rejected_logps: torch.FloatTensor,
|
||||
reference_chosen_logps: torch.FloatTensor,
|
||||
reference_rejected_logps: torch.FloatTensor,
|
||||
chosen_embeddings: Optional[torch.FloatTensor],
|
||||
rejected_embeddings: Optional[torch.FloatTensor],
|
||||
chosen_embeddings: torch.FloatTensor | None,
|
||||
rejected_embeddings: torch.FloatTensor | None,
|
||||
do_train: bool = True,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Compute the BCO loss for a batch of policy and reference model log probabilities.
|
||||
@ -1196,7 +1201,7 @@ class BCOTrainer(BaseTrainer):
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
batch: dict[str, list | torch.LongTensor],
|
||||
do_train: bool = True,
|
||||
):
|
||||
"""Compute the BCO loss and other metrics for the given batch of inputs for train or test."""
|
||||
@ -1289,11 +1294,11 @@ class BCOTrainer(BaseTrainer):
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
model: PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
compute_loss_context_manager = (
|
||||
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
||||
)
|
||||
@ -1315,7 +1320,7 @@ class BCOTrainer(BaseTrainer):
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
|
||||
def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.utils.data.Sampler | None:
|
||||
if dataset is None:
|
||||
dataset = self.train_dataset
|
||||
if dataset is None or not has_length(dataset):
|
||||
@ -1371,10 +1376,10 @@ class BCOTrainer(BaseTrainer):
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
model: PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
ignore_keys: list[str] | None = None,
|
||||
):
|
||||
if ignore_keys is None:
|
||||
if hasattr(model, "config"):
|
||||
@ -1411,8 +1416,8 @@ class BCOTrainer(BaseTrainer):
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
description: str,
|
||||
prediction_loss_only: Optional[bool] = None,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
prediction_loss_only: bool | None = None,
|
||||
ignore_keys: list[str] | None = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
) -> EvalLoopOutput:
|
||||
"""
|
||||
@ -1446,7 +1451,9 @@ class BCOTrainer(BaseTrainer):
|
||||
columns=["Prompt", "Policy", "Ref Model"],
|
||||
data=[
|
||||
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
||||
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
||||
for prompt, pol, ref in zip(
|
||||
target_batch["prompt"], policy_output_decoded, ref_output_decoded, strict=True
|
||||
)
|
||||
],
|
||||
)
|
||||
if "wandb" in self.args.report_to:
|
||||
@ -1465,7 +1472,7 @@ class BCOTrainer(BaseTrainer):
|
||||
|
||||
return initial_output
|
||||
|
||||
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
@ -66,7 +65,7 @@ def _generate_completions(
|
||||
model: PreTrainedModel,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
accelerator: Accelerator,
|
||||
generation_config: Optional[GenerationConfig],
|
||||
generation_config: GenerationConfig | None,
|
||||
batch_size: int = 1,
|
||||
) -> list[str]:
|
||||
"""
|
||||
@ -92,7 +91,7 @@ def _generate_completions(
|
||||
**tokenized_batch,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
for prompt, generation in zip(tokenized_batch.input_ids, generations):
|
||||
for prompt, generation in zip(tokenized_batch.input_ids, generations, strict=True):
|
||||
# Remove prompt from generation
|
||||
generation = generation[len(prompt) :]
|
||||
completion = tokenizer.decode(generation, skip_special_tokens=True)
|
||||
@ -107,15 +106,15 @@ class SyncRefModelCallback(TrainerCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ref_model: Union[PreTrainedModel, torch.nn.Module],
|
||||
accelerator: Optional[Accelerator],
|
||||
ref_model: PreTrainedModel | torch.nn.Module,
|
||||
accelerator: Accelerator | None,
|
||||
):
|
||||
self.accelerator = accelerator
|
||||
self.ref_model = ref_model
|
||||
|
||||
@staticmethod
|
||||
def _sync_target_model(model, target_model, alpha):
|
||||
for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
|
||||
for target_param, copy_param in zip(target_model.parameters(), model.parameters(), strict=True):
|
||||
target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)
|
||||
|
||||
@staticmethod
|
||||
@ -225,7 +224,7 @@ def _win_rate_completions_df(
|
||||
state: TrainerState, prompts: list[str], completions: list[str], winner_indices: list[str]
|
||||
) -> pd.DataFrame:
|
||||
global_step = [str(state.global_step)] * len(prompts)
|
||||
data = list(zip(global_step, prompts, completions, winner_indices))
|
||||
data = list(zip(global_step, prompts, completions, winner_indices, strict=True))
|
||||
# Split completions from reference model and policy
|
||||
split_data = [(item[0], item[1], item[2][0], item[2][1], item[3]) for item in data]
|
||||
return pd.DataFrame(split_data, columns=["step", "prompt", "reference_model", "policy", "winner_index"])
|
||||
@ -273,8 +272,8 @@ class WinRateCallback(TrainerCallback):
|
||||
self,
|
||||
judge: BasePairwiseJudge,
|
||||
trainer: Trainer,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
num_prompts: Optional[int] = None,
|
||||
generation_config: GenerationConfig | None = None,
|
||||
num_prompts: int | None = None,
|
||||
shuffle_order: bool = True,
|
||||
use_soft_judge: bool = False,
|
||||
):
|
||||
@ -319,7 +318,7 @@ class WinRateCallback(TrainerCallback):
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
)
|
||||
# Compute initial win rate as a reference point
|
||||
completions = list(zip(self.ref_completions, self.ref_completions))
|
||||
completions = list(zip(self.ref_completions, self.ref_completions, strict=True))
|
||||
if self.use_soft_judge:
|
||||
ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
|
||||
winner_indices = [0 if score > 0.5 else 1 for score in ref_win_probs]
|
||||
@ -379,7 +378,7 @@ class WinRateCallback(TrainerCallback):
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
)
|
||||
|
||||
completions = list(zip(self.ref_completions, completions))
|
||||
completions = list(zip(self.ref_completions, completions, strict=True))
|
||||
|
||||
if self.use_soft_judge:
|
||||
ref_win_probs = self.judge.judge(prompts, completions, self.shuffle_order, return_scores=True)
|
||||
@ -450,9 +449,9 @@ class LogCompletionsCallback(TrainerCallback):
|
||||
def __init__(
|
||||
self,
|
||||
trainer: Trainer,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
num_prompts: Optional[int] = None,
|
||||
freq: Optional[int] = None,
|
||||
generation_config: GenerationConfig | None = None,
|
||||
num_prompts: int | None = None,
|
||||
freq: int | None = None,
|
||||
):
|
||||
self.trainer = trainer
|
||||
self.generation_config = generation_config
|
||||
@ -498,7 +497,7 @@ class LogCompletionsCallback(TrainerCallback):
|
||||
# Build the data to log
|
||||
if self.trainer.accelerator.is_main_process:
|
||||
global_step = [str(state.global_step)] * len(prompts)
|
||||
data = list(zip(global_step, prompts, completions))
|
||||
data = list(zip(global_step, prompts, completions, strict=True))
|
||||
self.table.extend(data)
|
||||
table = pd.DataFrame(columns=["step", "prompt", "completion"], data=self.table)
|
||||
|
||||
@ -567,7 +566,7 @@ class WeaveCallback(TrainerCallback):
|
||||
scorers (`dict[str, Callable]`, *optional*):
|
||||
Dictionary mapping scorer names to scorer functions. If `None`, operates in tracing mode (predictions
|
||||
only). If provided, operates in evaluation mode (predictions + scores + summary). Scorer functions should
|
||||
have signature: `scorer(prompt: str, completion: str) -> Union[float, int]`
|
||||
have signature: `scorer(prompt: str, completion: str) -> float | int`
|
||||
generation_config (`GenerationConfig`, *optional*):
|
||||
Generation config to use for generating completions.
|
||||
num_prompts (`int` or `None`, *optional*):
|
||||
@ -582,12 +581,12 @@ class WeaveCallback(TrainerCallback):
|
||||
def __init__(
|
||||
self,
|
||||
trainer: Trainer,
|
||||
project_name: Optional[str] = None,
|
||||
scorers: Optional[dict[str, callable]] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
num_prompts: Optional[int] = None,
|
||||
project_name: str | None = None,
|
||||
scorers: dict[str, callable] | None = None,
|
||||
generation_config: GenerationConfig | None = None,
|
||||
num_prompts: int | None = None,
|
||||
dataset_name: str = "eval_dataset",
|
||||
model_name: Optional[str] = None,
|
||||
model_name: str | None = None,
|
||||
):
|
||||
self.trainer = trainer
|
||||
self.project_name = project_name
|
||||
@ -696,7 +695,7 @@ class WeaveCallback(TrainerCallback):
|
||||
successful_predictions = 0
|
||||
total_score_values = {} # For summary statistics
|
||||
|
||||
for prompt, completion in zip(all_prompts, all_completions):
|
||||
for prompt, completion in zip(all_prompts, all_completions, strict=True):
|
||||
try:
|
||||
pred_logger = eval_logger.log_prediction(inputs={"prompt": prompt}, output=completion)
|
||||
|
||||
@ -771,7 +770,7 @@ class MergeModelCallback(TrainerCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
merge_config: Optional["MergeConfig"] = None,
|
||||
merge_config: "MergeConfig | None" = None,
|
||||
merge_at_every_checkpoint: bool = False,
|
||||
push_to_hub: bool = False,
|
||||
):
|
||||
@ -954,7 +953,7 @@ class BEMACallback(TrainerCallback):
|
||||
|
||||
# Compute EMA + BEMA in-place and write directly to running_model
|
||||
for thetat, theta0, ema, run_param in zip(
|
||||
self.thetat_params, self.theta0_params, self.ema_params, self.running_model.parameters()
|
||||
self.thetat_params, self.theta0_params, self.ema_params, self.running_model.parameters(), strict=True
|
||||
):
|
||||
thetat = thetat.detach().to(self.device)
|
||||
ema.mul_(1 - beta).add_(thetat, alpha=beta) # EMA update: ema = (1 - beta) * ema + beta * θₜ
|
||||
@ -972,7 +971,9 @@ class BEMACallback(TrainerCallback):
|
||||
|
||||
# Snapshot θ₀ and EMA at first update
|
||||
if step == self.update_after:
|
||||
for thetat_param, theta0_param, ema_param in zip(self.thetat_params, self.theta0_params, self.ema_params):
|
||||
for thetat_param, theta0_param, ema_param in zip(
|
||||
self.thetat_params, self.theta0_params, self.ema_params, strict=True
|
||||
):
|
||||
theta0_param.copy_(thetat_param)
|
||||
ema_param.copy_(thetat_param)
|
||||
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
@ -107,7 +107,7 @@ class CPOConfig(TrainingArguments):
|
||||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
||||
},
|
||||
)
|
||||
bf16: Optional[bool] = field(
|
||||
bf16: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
|
||||
@ -116,18 +116,18 @@ class CPOConfig(TrainingArguments):
|
||||
},
|
||||
)
|
||||
|
||||
max_length: Optional[int] = field(
|
||||
max_length: int | None = field(
|
||||
default=1024,
|
||||
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
|
||||
)
|
||||
max_prompt_length: Optional[int] = field(
|
||||
max_prompt_length: int | None = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
|
||||
"collator and your model is an encoder-decoder."
|
||||
},
|
||||
)
|
||||
max_completion_length: Optional[int] = field(
|
||||
max_completion_length: int | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Maximum length of the completion. This argument is required if you want to use the default data "
|
||||
@ -176,7 +176,7 @@ class CPOConfig(TrainingArguments):
|
||||
default=-100,
|
||||
metadata={"help": "Label pad token id."},
|
||||
)
|
||||
padding_value: Optional[int] = field(
|
||||
padding_value: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
|
||||
)
|
||||
@ -191,18 +191,18 @@ class CPOConfig(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."},
|
||||
)
|
||||
is_encoder_decoder: Optional[bool] = field(
|
||||
is_encoder_decoder: bool | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Whether the model is an encoder-decoder model."},
|
||||
)
|
||||
model_init_kwargs: Optional[dict[str, Any]] = field(
|
||||
model_init_kwargs: dict[str, Any] | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
|
||||
"from a string."
|
||||
},
|
||||
)
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
dataset_num_proc: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of processes to use for processing the dataset."},
|
||||
)
|
||||
|
@ -16,9 +16,10 @@ import inspect
import random
import textwrap
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Literal

import numpy as np
import pandas as pd
@ -127,20 +128,22 @@ class CPOTrainer(BaseTrainer):

def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
args: Optional[CPOConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
model: PreTrainedModel | nn.Module | str | None = None,
args: CPOConfig | None = None,
data_collator: DataCollator | None = None,
train_dataset: Dataset | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
model_init: Callable[[], PreTrainedModel] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: dict | None = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
):
if args.model_init_kwargs is None:
model_init_kwargs = {}
@ -439,7 +442,7 @@ class CPOTrainer(BaseTrainer):
attention_mask=answer_attention_mask,
)

def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict:
"""Tokenize a single row from a CPO specific dataset.

At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
@ -487,7 +490,8 @@ class CPOTrainer(BaseTrainer):
# Make sure prompts only have one different token at most an
# and length only differs by 1 at most
num_diff_tokens = sum(
a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
a != b
for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True)
)
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
if num_diff_tokens > 1 or num_diff_len > 1:
@ -586,11 +590,11 @@ class CPOTrainer(BaseTrainer):

@staticmethod
def concatenated_inputs(
batch: dict[str, Union[list, torch.LongTensor]],
batch: dict[str, list | torch.LongTensor],
is_encoder_decoder: bool = False,
label_pad_token_id: int = -100,
padding_value: int = 0,
device: Optional[torch.device] = None,
device: torch.device | None = None,
) -> dict[str, torch.LongTensor]:
"""Concatenate the chosen and rejected inputs into a single tensor.

@ -769,7 +773,7 @@ class CPOTrainer(BaseTrainer):
return (per_token_logps * loss_mask).sum(-1)

def concatenated_forward(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

@ -846,7 +850,7 @@ class CPOTrainer(BaseTrainer):
def get_batch_loss_metrics(
self,
model,
batch: dict[str, Union[list, torch.LongTensor]],
batch: dict[str, list | torch.LongTensor],
train_eval: Literal["train", "eval"] = "train",
):
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
@ -899,11 +903,11 @@ class CPOTrainer(BaseTrainer):

def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
compute_loss_context_manager = (
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
)
@ -943,10 +947,10 @@ class CPOTrainer(BaseTrainer):

def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
ignore_keys: list[str] | None = None,
):
if ignore_keys is None:
if hasattr(model, "config"):
@ -986,8 +990,8 @@ class CPOTrainer(BaseTrainer):
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
prediction_loss_only: bool | None = None,
ignore_keys: list[str] | None = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
@ -1013,7 +1017,8 @@ class CPOTrainer(BaseTrainer):
table = pd.DataFrame(
columns=["Prompt", "Policy"],
data=[
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
[prompt, pol[len(prompt) :]]
for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True)
],
)
if "wandb" in self.args.report_to:
@ -1032,7 +1037,7 @@ class CPOTrainer(BaseTrainer):

return initial_output

def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

@ -14,7 +14,7 @@

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional, Union
from typing import Any

from transformers import TrainingArguments

@ -123,7 +123,7 @@ class DPOConfig(TrainingArguments):
Batch size to use when precomputing reference model log probabilities. This can be set higher than the
training batch size to speed up preprocessing. If `None`, defaults to `per_device_train_batch_size` for
training and `per_device_eval_batch_size` for evaluation.
tools (`Optional[list[Union[dict, Callable]]]`, *optional*):
tools (`list[dict] | None`, *optional*):
List of tools (callable functions) that will be accessible to the model. If the template does not support
function calling, this argument will have no effect.

@ -244,7 +244,7 @@ class DPOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@ -254,25 +254,25 @@ class DPOConfig(TrainingArguments):
)

# Parameters that control the model and reference model
model_init_kwargs: Optional[dict[str, Any]] = field(
model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
"the `DPOTrainer` is provided as a string."
},
)
ref_model_init_kwargs: Optional[dict[str, Any]] = field(
ref_model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `ref_model` argument "
"of the `DPOTrainer` is provided as a string."
},
)
model_adapter_name: Optional[str] = field(
model_adapter_name: str | None = field(
default=None,
metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
)
ref_adapter_name: Optional[str] = field(
ref_adapter_name: str | None = field(
default=None,
metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
)
@ -297,11 +297,11 @@ class DPOConfig(TrainingArguments):
)

# Parameters that control the data preprocessing
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
pad_token: Optional[str] = field(
pad_token: str | None = field(
default=None,
metadata={
"help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
@ -312,15 +312,15 @@ class DPOConfig(TrainingArguments):
default=-100,
metadata={"help": "Padding value to use for labels."},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={"help": "Maximum length of the prompt."},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=None,
metadata={"help": "Maximum length of the completion."},
)
max_length: Optional[int] = field(
max_length: int | None = field(
default=1024,
metadata={"help": "Maximum length of the full sequence (prompt + completion)."},
)
@ -350,7 +350,7 @@ class DPOConfig(TrainingArguments):
"probabilities on-the-fly."
},
)
precompute_ref_batch_size: Optional[int] = field(
precompute_ref_batch_size: int | None = field(
default=None,
metadata={
"help": "Batch size to use when precomputing reference model log probabilities. This can be set higher "
@ -358,7 +358,7 @@ class DPOConfig(TrainingArguments):
"`per_device_train_batch_size` for training and `per_device_eval_batch_size` for evaluation."
},
)
tools: Optional[list[Union[dict, Callable]]] = field(
tools: list[dict] | None = field(
default=None,
metadata={
"help": "List of tools (callable functions) that will be accessible to the model. If the template does "
@ -396,7 +396,7 @@ class DPOConfig(TrainingArguments):
"Higher β means less deviation from the reference model."
},
)
f_divergence_type: Union[FDivergenceType, str] = field(
f_divergence_type: FDivergenceType | str = field(
default=FDivergenceType.REVERSE_KL,
metadata={
"help": "Type of f-divergence regularization function to compute divergence between policy and reference "
@ -425,7 +425,7 @@ class DPOConfig(TrainingArguments):
default=False,
metadata={"help": "Whether to weight the loss as done in the WPO paper."},
)
rpo_alpha: Optional[float] = field(
rpo_alpha: float | None = field(
default=None,
metadata={
"help": "α parameter from the RPO paper (v3), which controls the weighting of the NLL term in the loss. "
@ -433,7 +433,7 @@ class DPOConfig(TrainingArguments):
"`rpo_alpha=1.0`."
},
)
ld_alpha: Optional[float] = field(
ld_alpha: float | None = field(
default=None,
metadata={
"help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token "
@ -448,7 +448,7 @@ class DPOConfig(TrainingArguments):
"loss. The paper recommends the default value `discopop_tau=0.05`."
},
)
loss_weights: Optional[list[float]] = field(
loss_weights: list[float] | None = field(
default=None,
metadata={
"help": "List of loss weights for multi-loss combinations. Used when combining multiple loss types. "
@ -489,7 +489,7 @@ class DPOConfig(TrainingArguments):
)

# Deprecated arguments
padding_value: Optional[int] = field(
padding_value: int | None = field(
default=None,
metadata={"help": "Deprecated, use `pad_token` (str) instead."},
)

@ -17,10 +17,11 @@ import random
import textwrap
import warnings
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Literal

import pandas as pd
import torch
@ -144,7 +145,7 @@ class DataCollatorForPreference(DataCollatorMixin):
pad_token_id: int
return_tensors: str = "pt"

def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
def torch_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]:
# Convert to tensor
prompt_input_ids = [torch.tensor(example["prompt_input_ids"]) for example in examples]
prompt_attention_mask = [torch.ones_like(input_ids) for input_ids in prompt_input_ids]
@ -188,7 +189,7 @@ class DPOTrainer(BaseTrainer):
This class is a wrapper around the [`transformers.Trainer`] class and inherits all of its attributes and methods.

Args:
model (`Union[str, PreTrainedModel]`):
model (`str | PreTrainedModel`):
Model to be trained. Can be either:

- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
@ -213,7 +214,7 @@ class DPOTrainer(BaseTrainer):
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
@ -265,21 +266,23 @@ class DPOTrainer(BaseTrainer):

def __init__(
self,
model: Union[str, nn.Module, PreTrainedModel],
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
args: Optional[DPOConfig] = None,
data_collator: Optional[DataCollator] = None, # type: ignore
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional["PeftConfig"] = None,
model: str | nn.Module | PreTrainedModel,
ref_model: PreTrainedModel | nn.Module | str | None = None,
args: DPOConfig | None = None,
data_collator: DataCollator | None = None, # type: ignore
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
optimizer_cls_and_kwargs: tuple[type[torch.optim.Optimizer], dict[str, Any]] | None = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: "PeftConfig | None" = None,
):
# Args
if args is None:
@ -632,11 +635,11 @@ class DPOTrainer(BaseTrainer):

def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
dataset: Dataset | IterableDataset,
processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin,
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
) -> Dataset | IterableDataset:
# Build the kwargs for the `map` function
map_kwargs = {}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc nor writer_batch_size
@ -679,8 +682,8 @@ class DPOTrainer(BaseTrainer):
def tokenize_row(
features: dict[str, str],
processing_class: PreTrainedTokenizerBase,
max_prompt_length: Optional[int] = None,
max_completion_length: Optional[int] = None,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
) -> dict[str, list[int]]:
"""
@ -748,8 +751,8 @@ class DPOTrainer(BaseTrainer):
def process_row(
features: dict[str, str],
processing_class: PreTrainedTokenizerBase,
max_prompt_length: Optional[int] = None,
max_completion_length: Optional[int] = None,
max_prompt_length: int | None = None,
max_completion_length: int | None = None,
add_special_tokens: bool = True,
) -> dict[str, list[int]]:
"""
@ -854,7 +857,7 @@ class DPOTrainer(BaseTrainer):

return super().get_train_dataloader()

def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
"""
Returns the evaluation [`~torch.utils.data.DataLoader`].

@ -934,14 +937,14 @@ class DPOTrainer(BaseTrainer):

@staticmethod
def concatenated_inputs(
batch: dict[str, Union[list, torch.LongTensor]], padding_value: int
batch: dict[str, list | torch.LongTensor], padding_value: int
) -> dict[str, torch.LongTensor]:
"""
Concatenate the `chosen` and `rejected` inputs from the batch into a single tensor for both the prompt and
completion sequences.

Args:
batch (`dict[str, Union[list, torch.LongTensor]]`):
batch (`dict[str, list | torch.LongTensor]`):
A batch of input data. The batch must contain the following keys:

- `"prompt_input_ids"`: Tensor of shape `(batch_size, prompt_length)` representing the prompt input
@ -1233,7 +1236,7 @@ class DPOTrainer(BaseTrainer):
return losses, chosen_rewards, rejected_rewards

def _compute_loss_liger(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
) -> dict[str, torch.Tensor]:
unwrapped_model = self.accelerator.unwrap_model(model)
concatenated_batch = self.concatenated_inputs(batch, padding_value=self.pad_token_id)
@ -1465,7 +1468,7 @@ class DPOTrainer(BaseTrainer):
return output

def concatenated_forward(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
self, model: nn.Module, batch: dict[str, list | torch.LongTensor], is_ref_model: bool = False
) -> dict[str, torch.Tensor]:
"""
Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
@ -1687,8 +1690,8 @@ class DPOTrainer(BaseTrainer):

def get_batch_loss_metrics(
self,
model: Union[PreTrainedModel, nn.Module],
batch: dict[str, Union[list, torch.LongTensor]],
model: PreTrainedModel | nn.Module,
batch: dict[str, list | torch.LongTensor],
train_eval: Literal["train", "eval"] = "train",
) -> tuple[torch.Tensor, dict[str, float]]:
"""Compute the DPO loss and other metrics for the given batch of inputs for train or test."""
@ -1775,11 +1778,11 @@ class DPOTrainer(BaseTrainer):

def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, float]]]:
) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
compute_loss_context_manager = (
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
)
@ -1846,11 +1849,11 @@ class DPOTrainer(BaseTrainer):

def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
model: PreTrainedModel | nn.Module,
inputs: dict[str, torch.Tensor | Any],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
ignore_keys: list[str] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
if ignore_keys is None:
if hasattr(model, "config"):
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
@ -1889,8 +1892,8 @@ class DPOTrainer(BaseTrainer):
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[list[str]] = None,
prediction_loss_only: bool | None = None,
ignore_keys: list[str] | None = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput:
"""
@ -1918,7 +1921,7 @@ class DPOTrainer(BaseTrainer):
data=[
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
for prompt, pol, ref in zip(
random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded
random_batch_dataset["prompt"], policy_output_decoded, ref_output_decoded, strict=True
)
],
)
@ -1941,7 +1944,7 @@ class DPOTrainer(BaseTrainer):

return initial_output

def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.

@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any

from transformers import TrainingArguments

@ -77,14 +77,14 @@ class GKDConfig(SFTConfig):
default=128,
metadata={"help": "Maximum number of tokens to generate per completion."},
)
teacher_model_name_or_path: Optional[str] = field(
teacher_model_name_or_path: str | None = field(
default=None,
metadata={
"help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
"model being trained."
},
)
teacher_model_init_kwargs: Optional[dict[str, Any]] = field(
teacher_model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "

@ -14,7 +14,8 @@

import random
import textwrap
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any

import torch
import torch.nn as nn
@ -111,21 +112,23 @@ class GKDTrainer(SFTTrainer):

def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
teacher_model: Union[PreTrainedModel, nn.Module, str] = None,
args: Optional[GKDConfig] = None,
data_collator: Optional[DataCollator] = None, # type: ignore
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
model: PreTrainedModel | nn.Module | str | None = None,
teacher_model: PreTrainedModel | nn.Module | str = None,
args: GKDConfig | None = None,
data_collator: DataCollator | None = None, # type: ignore
train_dataset: Dataset | None = None,
eval_dataset: Dataset | dict[str, Dataset] | None = None,
processing_class: PreTrainedTokenizerBase
| BaseImageProcessor
| FeatureExtractionMixin
| ProcessorMixin
| None = None,
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional["PeftConfig"] = None,
formatting_func: Optional[Callable] = None,
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
peft_config: "PeftConfig | None" = None,
formatting_func: Callable | None = None,
):
# Ensure Trainer does not drop non-signature columns used by the collator (e.g., "prompts")
args.remove_unused_columns = False
@ -404,7 +407,7 @@ class GKDTrainer(SFTTrainer):
return generated_tokens, new_attention_mask, new_labels

def training_step(
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None
) -> torch.Tensor:
"""
Perform a training step for the Generalized Knowledge Distillation (GKD) model.

@ -13,7 +13,6 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional, Union

from transformers import TrainingArguments

@ -262,7 +261,7 @@ class GRPOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@ -272,7 +271,7 @@ class GRPOConfig(TrainingArguments):
)

# Parameters that control the model and reference model
model_init_kwargs: Optional[Union[dict, str]] = field(
model_init_kwargs: dict | str | None = field(
default=None,
metadata={
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
@ -290,27 +289,27 @@ class GRPOConfig(TrainingArguments):
# Parameters that control the data preprocessing
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
# additional columns to compute the reward
remove_unused_columns: Optional[bool] = field(
remove_unused_columns: bool | None = field(
default=False,
metadata={
"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
},
)
num_generations: Optional[int] = field(
num_generations: int | None = field(
default=8,
metadata={
"help": "Number of generations to sample. The effective batch size (num_processes * per_device_batch_size "
"* gradient_accumulation_steps) must be evenly divisible by this value."
},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=256,
metadata={"help": "Maximum length of the generated completion."},
)
@ -323,20 +322,20 @@ class GRPOConfig(TrainingArguments):
"is not compatible with vLLM generation."
},
)
shuffle_dataset: Optional[bool] = field(
shuffle_dataset: bool | None = field(
default=True,
metadata={"help": "Whether to shuffle the training dataset."},
)

# Parameters that control generation
generation_batch_size: Optional[int] = field(
generation_batch_size: int | None = field(
default=None,
metadata={
"help": "Batch size to use for generation. If `None`, it defaults to the effective training batch size: "
"`per_device_train_batch_size * num_processes * steps_per_generation`."
},
)
steps_per_generation: Optional[int] = field(
steps_per_generation: int | None = field(
default=None,
metadata={"help": "Number of steps per generation. If `None`, it defaults to `gradient_accumulation_steps`."},
)
@ -351,21 +350,21 @@ class GRPOConfig(TrainingArguments):
"Set to 1.0 to consider all tokens."
},
)
top_k: Optional[int] = field(
top_k: int | None = field(
default=None,
metadata={
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
"top-k-filtering is disabled and all tokens are considered."
},
)
min_p: Optional[float] = field(
min_p: float | None = field(
default=None,
metadata={
"help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
"must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
},
)
generation_kwargs: Optional[dict] = field(
generation_kwargs: dict | None = field(
default=None,
metadata={
"help": "Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or "
@ -390,7 +389,7 @@ class GRPOConfig(TrainingArguments):
"implementation. This parameter is only effective when `use_vllm` is set to `False`."
},
)
cache_implementation: Optional[str] = field(
cache_implementation: str | None = field(
default=None,
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
)
@ -428,13 +427,13 @@ class GRPOConfig(TrainingArguments):
"and woken for weight sync and generation."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
vllm_guided_decoding_regex: str | None = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)

# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
vllm_server_base_url: Optional[str] = field(
vllm_server_base_url: str | None = field(
default=None,
metadata={
"help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` "
@ -491,7 +490,7 @@ class GRPOConfig(TrainingArguments):
default=0.2,
metadata={"help": "Epsilon value for clipping."},
)
delta: Optional[float] = field(
delta: float | None = field(
default=None,
metadata={
"help": "Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` "
@ -499,7 +498,7 @@ class GRPOConfig(TrainingArguments):
"method is introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)."
},
)
epsilon_high: Optional[float] = field(
epsilon_high: float | None = field(
default=None,
metadata={
"help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
@ -516,7 +515,7 @@ class GRPOConfig(TrainingArguments):
"sequence-level rewards."
},
)
reward_weights: Optional[list[float]] = field(
reward_weights: list[float] | None = field(
default=None,
metadata={
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
@ -624,11 +623,11 @@ class GRPOConfig(TrainingArguments):
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
},
)
num_completions_to_print: Optional[int] = field(
num_completions_to_print: int | None = field(
default=None,
metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."},
)
wandb_log_unique_prompts: Optional[bool] = field(
wandb_log_unique_prompts: bool | None = field(
default=False,
metadata={
"help": "Whether to log unique prompts in wandb. If `True`, only unique prompts are logged. If `False`, "

@ -17,10 +17,11 @@ import os
import re
import textwrap
from collections import defaultdict, deque
from collections.abc import Callable
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from typing import Any, Callable, Optional, Union
from typing import Any

import datasets
import torch
@ -94,7 +95,7 @@ logger = logging.get_logger(__name__)

# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]]


class GRPOTrainer(BaseTrainer):
@ -127,7 +128,7 @@ class GRPOTrainer(BaseTrainer):
```

Args:
model (`Union[str, PreTrainedModel]`):
model (`str | PreTrainedModel`):
Model to be trained. Can be either:

- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
@ -136,7 +137,7 @@ class GRPOTrainer(BaseTrainer):
using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments in
`args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
reward_funcs (`RewardFunc | list[RewardFunc]`):
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
functions with the prompts and completions and sum the rewards. Can be either:

@ -169,14 +170,14 @@ class GRPOTrainer(BaseTrainer):
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.ProcessorMixin`], *optional*):
Processing class used to process the data. The padding side must be set to "left". If `None`, the
processing class is loaded from the model's name with [`~transformers.AutoProcessor.from_pretrained`]. A
padding token, `tokenizer.pad_token`, must be set. If the processing class has not set a padding token,
`tokenizer.eos_token` will be used as the default.
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*):
reward_processing_classes (`PreTrainedTokenizerBase | list[PreTrainedTokenizerBase]`, *optional*):
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:

- A single processing class: Used when `reward_funcs` contains only one reward function.
@ -217,16 +218,16 @@ class GRPOTrainer(BaseTrainer):

def __init__(
self,
model: Union[str, PreTrainedModel],
reward_funcs: Union[RewardFunc, list[RewardFunc]],
args: Optional[GRPOConfig] = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
peft_config: Optional["PeftConfig"] = None,
model: str | PreTrainedModel,
reward_funcs: RewardFunc | list[RewardFunc],
args: GRPOConfig | None = None,
train_dataset: Dataset | IterableDataset | None = None,
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
callbacks: list[TrainerCallback] | None = None,
optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
peft_config: "PeftConfig | None" = None,
):
# Args
if args is None:
@ -333,7 +334,9 @@ class GRPOTrainer(BaseTrainer):
f"reward functions ({len(reward_funcs)})."
)

for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
for i, (reward_processing_class, reward_func) in enumerate(
zip(reward_processing_classes, reward_funcs, strict=True)
):
if isinstance(reward_func, PreTrainedModel):
if reward_processing_class is None:
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
@ -661,7 +664,7 @@ class GRPOTrainer(BaseTrainer):

return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Sampler:
def _get_train_sampler(self, dataset: Dataset | None = None) -> Sampler:
# Returns a sampler that
# 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
# distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
@ -803,7 +806,7 @@ class GRPOTrainer(BaseTrainer):
pixel_attention_mask=None,
image_sizes=None,
token_type_ids=None,
) -> dict[str, Optional[torch.Tensor]]:
) -> dict[str, torch.Tensor | None]:
"""Compute log-probs and (optionally) entropies for each token."""
batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak
all_logps = []
@ -862,7 +865,7 @@ class GRPOTrainer(BaseTrainer):
entropies = torch.cat(all_entropies, dim=0) if compute_entropy else None
return logps, entropies

def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | None = None):
extra_prefixes = extra_prefixes or []
prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
for prefix in prefixes:
@ -986,9 +989,7 @@ class GRPOTrainer(BaseTrainer):
self.llm.reset_prefix_cache()

@profiling_decorator
def _prepare_inputs(
self, generation_batch: dict[str, Union[torch.Tensor, Any]]
) -> dict[str, Union[torch.Tensor, Any]]:
def _prepare_inputs(self, generation_batch: dict[str, torch.Tensor | Any]) -> dict[str, torch.Tensor | Any]:
# Prepares inputs for model training/evaluation by managing completion generation and batch handling.
# During training:
# - Receives the local generation batch (Per-GPU batch size × steps per generation)
@ -1033,15 +1034,15 @@ class GRPOTrainer(BaseTrainer):
reward_kwargs["trainer_state"] = self.state

for i, (reward_func, reward_processing_class, reward_func_name) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names)
zip(self.reward_funcs, self.reward_processing_classes, self.reward_func_names, strict=True)
):
with profiling_context(self, reward_func_name):
if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
texts = [p + c for p, c in zip(prompts, completions, strict=True)]
reward_inputs = reward_processing_class(
text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
@ -1075,7 +1076,7 @@ class GRPOTrainer(BaseTrainer):
rewards_per_func = gather(rewards_per_func)
return rewards_per_func

def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
def _generate_single_turn(self, prompts: list[str], images: list | None):
device = self.accelerator.device

# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
@ -1084,7 +1085,7 @@ class GRPOTrainer(BaseTrainer):
kwargs = {}
if images is not None:
kwargs = {"images": images}
for prompt, image_list in zip(prompts, images):
for prompt, image_list in zip(prompts, images, strict=True):
if isinstance(prompt, list): # i.e., when using conversational data
prepare_multimodal_messages(prompt, num_images=len(image_list))

@ -1103,7 +1104,7 @@ class GRPOTrainer(BaseTrainer):
prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=True)]

if self.max_prompt_length is not None:
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
@ -1245,7 +1246,7 @@ class GRPOTrainer(BaseTrainer):

if images is not None and all_images:
vllm_inputs = []
for prompt, image_list in zip(all_prompts_text, all_images):
for prompt, image_list in zip(all_prompts_text, all_images, strict=True):
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}})

else:
@ -1342,13 +1343,13 @@ class GRPOTrainer(BaseTrainer):
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool(), strict=True)]
completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool(), strict=True)]
logprobs = None # not used in this case

return prompt_ids, completion_ids, logprobs, forward_kwargs

def _generate(self, prompts: list[str], images: Optional[list]):
def _generate(self, prompts: list[str], images: list | None):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

@ -1387,8 +1388,8 @@ class GRPOTrainer(BaseTrainer):
return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs

def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
) -> dict[str, Union[torch.Tensor, Any]]:
self, inputs: list[dict[str, torch.Tensor | Any]]
) -> dict[str, torch.Tensor | Any]:
device = self.accelerator.device
mode = "train" if self.model.training else "eval"

@ -1507,7 +1508,7 @@ class GRPOTrainer(BaseTrainer):
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
for prompt, completion in zip(prompts, completions_text, strict=True):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
completions.append([{"role": "assistant", "content": bootstrap + completion}])
else:
@ -1815,7 +1816,7 @@ class GRPOTrainer(BaseTrainer):
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
return loss

def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: list[str] | None = None):
inputs = self._prepare_inputs(inputs)
with torch.no_grad():
with self.compute_loss_context_manager():
@ -1823,7 +1824,7 @@ class GRPOTrainer(BaseTrainer):
loss = loss.mean().detach()
return loss, None, None

def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
mode = "train" if self.model.training else "eval"
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics

@ -15,7 +15,6 @@
import concurrent.futures
import logging
from abc import ABC, abstractmethod
from typing import Optional, Union

import numpy as np
from accelerate import Accelerator
@ -154,7 +153,7 @@ class BaseBinaryJudge(BaseJudge):
self,
prompts: list[str],
completions: list[str],
gold_completions: Optional[list[str]] = None,
gold_completions: list[str] | None = None,
shuffle_order: bool = True,
) -> list[int]:
"""
@ -224,7 +223,7 @@ class PairRMJudge(BasePairwiseJudge):
shuffle_order: bool = True,
return_scores: bool = False,
temperature: float = 1.0,
) -> list[Union[int, float]]:
) -> list[int | float]:
"""
Judge the completion pairs for the given prompts using the PairRM model.

@ -241,7 +240,7 @@ class PairRMJudge(BasePairwiseJudge):
Temperature for scaling logits if `return_scores` is True.

Returns:
`Union[list[int, float]]`:
`list[int | float]`:
If `return_scores` is `False`, returns a list of ranks (`0` or `1`) for each prompt, indicating which
completion is preferred. If `return_scores` is `True`, returns softmax probabilities for the first
completion.
@ -260,7 +259,7 @@ class PairRMJudge(BasePairwiseJudge):
# Shuffle the order of the completions to avoid positional bias
if shuffle_order:
flip_mask = np.random.choice([True, False], size=len(prompts))
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]

# Rank the completions
ranks = self.blender.rank(prompts, completions, return_scores=return_scores, disable_tqdm=True)
@ -305,8 +304,8 @@ class HfPairwiseJudge(BasePairwiseJudge):
def __init__(
self,
model="meta-llama/Meta-Llama-3-70B-Instruct",
token: Optional[str] = None,
system_prompt: Optional[str] = None,
token: str | None = None,
system_prompt: str | None = None,
):
self.client = InferenceClient(model=model, token=token)
self.system_prompt = system_prompt or DEFAULT_PAIRWISE_SYSTEM_PROMPT
@ -315,7 +314,7 @@ class HfPairwiseJudge(BasePairwiseJudge):
# Shuffle the order of the completions to avoid positional bias
if shuffle_order:
flip_mask = np.random.choice([True, False], size=len(prompts))
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]

# Define a function to get the rank for a single prompt, will be called concurrently
def get_rank(prompt, candidates):
@ -359,7 +358,7 @@ class OpenAIPairwiseJudge(BasePairwiseJudge):
"""

def __init__(
self, model="gpt-4-turbo-preview", system_prompt: Optional[str] = None, max_requests: Union[int, None] = 1_000
self, model="gpt-4-turbo-preview", system_prompt: str | None = None, max_requests: int | None = 1_000
):
if not is_openai_available():
raise ValueError("OpenAI client is not installed. Please install it with 'pip install openai'.")
@ -384,7 +383,7 @@ class OpenAIPairwiseJudge(BasePairwiseJudge):
# Shuffle the order of the completions to avoid positional bias
if shuffle_order:
flip_mask = np.random.choice([True, False], size=len(prompts))
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions)]
completions = [pair[::-1] if flip else pair for flip, pair in zip(flip_mask, completions, strict=True)]

# Define a function to get the rank for a single prompt, will be called concurrently
def get_rank(prompt, candidates):
@ -433,14 +432,14 @@ class AllTrueJudge(BaseBinaryJudge):
self,
prompts: list[str],
completions: list[str],
gold_completions: Optional[list[str]] = None,
gold_completions: list[str] | None = None,
shuffle_order: bool = True,
) -> list[int]:
all_binary_judgments = [
judge.judge(prompts, completions, gold_completions, shuffle_order) for judge in self.judges
]
output = []
for binary_judgments in zip(*all_binary_judgments):
for binary_judgments in zip(*all_binary_judgments, strict=True):
# Check that all values are in {0, 1, -1}
if any(binary_judgment not in {0, 1, -1} for binary_judgment in binary_judgments):
raise ValueError(

@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, Optional
from typing import Any

from transformers import TrainingArguments

@ -107,7 +107,7 @@ class KTOConfig(TrainingArguments):
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
},
)
bf16: Optional[bool] = field(
bf16: bool | None = field(
default=None,
metadata={
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
@ -116,18 +116,18 @@ class KTOConfig(TrainingArguments):
},
)

max_length: Optional[int] = field(
max_length: int | None = field(
default=1024,
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
)
max_prompt_length: Optional[int] = field(
max_prompt_length: int | None = field(
default=512,
metadata={
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
"collator and your model is an encoder-decoder."
},
)
max_completion_length: Optional[int] = field(
max_completion_length: int | None = field(
default=None,
metadata={
"help": "Maximum length of the completion. This argument is required if you want to use the default data "
@ -168,7 +168,7 @@ class KTOConfig(TrainingArguments):
"help": "Label pad token id. This argument is required if you want to use the default data collator."
},
)
padding_value: Optional[int] = field(
padding_value: int | None = field(
default=None,
metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
)
@ -186,7 +186,7 @@ class KTOConfig(TrainingArguments):
"during evaluation."
},
)
is_encoder_decoder: Optional[bool] = field(
is_encoder_decoder: bool | None = field(
default=None,
metadata={
"help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` "
@ -204,21 +204,21 @@ class KTOConfig(TrainingArguments):
"This is useful when training without the reference model to reduce the total GPU memory needed."
},
)
model_init_kwargs: Optional[dict[str, Any]] = field(
model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
"from a string."
},
)
ref_model_init_kwargs: Optional[dict[str, Any]] = field(
ref_model_init_kwargs: dict[str, Any] | None = field(
default=None,
metadata={
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
"reference model from a string."
},
)
dataset_num_proc: Optional[int] = field(
dataset_num_proc: int | None = field(
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)

@ -16,10 +16,11 @@ import inspect
|
||||
import random
|
||||
import textwrap
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from operator import itemgetter
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -100,19 +101,21 @@ def _tokenize(
|
||||
prompt_tokenized = tokenizer(batch["prompt"], add_special_tokens=False)
|
||||
prompt_input_ids = prompt_tokenized["input_ids"]
|
||||
prompt_attention_mask = prompt_tokenized["attention_mask"]
|
||||
prompt_and_completion = [prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"])]
|
||||
prompt_and_completion = [
|
||||
prompt + completion for prompt, completion in zip(batch["prompt"], batch["completion"], strict=True)
|
||||
]
|
||||
full_tokenized = tokenizer(prompt_and_completion, add_special_tokens=False)
|
||||
full_input_ids = full_tokenized["input_ids"]
|
||||
full_attention_mask = full_tokenized["attention_mask"]
|
||||
|
||||
answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids)]
|
||||
answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask)]
|
||||
answer_input_ids = [f[len(p) :] for f, p in zip(full_input_ids, prompt_input_ids, strict=True)]
|
||||
answer_attention_mask = [f[len(p) :] for f, p in zip(full_attention_mask, prompt_attention_mask, strict=True)]
|
||||
|
||||
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
||||
full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids)]
|
||||
full_concat_input_ids = [np.concatenate([p, a]) for p, a in zip(prompt_input_ids, answer_input_ids, strict=True)]
|
||||
# Prepare input tokens for token by token comparison
|
||||
full_input_ids = [np.array(f) for f in full_input_ids]
|
||||
for full, concat in zip(full_input_ids, full_concat_input_ids):
|
||||
for full, concat in zip(full_input_ids, full_concat_input_ids, strict=True):
|
||||
if len(full) != len(concat):
|
||||
raise ValueError(
|
||||
"The elements in 'full_input_ids' and 'full_concat_input_ids' must have the same pairwise length."
|
||||
@ -126,19 +129,19 @@ def _tokenize(
|
||||
|
||||
# If tokenized prompt is different than both prompt+answer, then it means the
|
||||
# last token has changed due to merging.
|
||||
for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx)):
|
||||
for idx, (p, f, r) in enumerate(zip(prompt_input_ids, full_input_ids, response_token_ids_start_idx, strict=True)):
|
||||
if not np.array_equal(p, f[:r]):
|
||||
response_token_ids_start_idx[idx] -= 1
|
||||
|
||||
prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx)]
|
||||
prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx)]
|
||||
prompt_input_ids = [f[:r] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
|
||||
prompt_attention_mask = [f[:r] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
|
||||
|
||||
for p, m in zip(prompt_input_ids, prompt_attention_mask):
|
||||
for p, m in zip(prompt_input_ids, prompt_attention_mask, strict=True):
|
||||
if len(p) != len(m):
|
||||
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
||||
|
||||
answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx)]
|
||||
answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx)]
|
||||
answer_input_ids = [f[r:] for f, r in zip(full_input_ids, response_token_ids_start_idx, strict=True)]
|
||||
answer_attention_mask = [f[r:] for f, r in zip(full_attention_mask, response_token_ids_start_idx, strict=True)]
|
||||
|
||||
output = dict(
|
||||
prompt_input_ids=prompt_input_ids,
|
||||
@ -335,23 +338,25 @@ class KTOTrainer(BaseTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module, str] = None,
|
||||
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
model: PreTrainedModel | nn.Module | str = None,
|
||||
ref_model: PreTrainedModel | nn.Module | str | None = None,
|
||||
args: KTOConfig = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
train_dataset: Dataset | None = None,
|
||||
eval_dataset: Dataset | dict[str, Dataset] | None = None,
|
||||
processing_class: PreTrainedTokenizerBase
|
||||
| BaseImageProcessor
|
||||
| FeatureExtractionMixin
|
||||
| ProcessorMixin
|
||||
| None = None,
|
||||
data_collator: DataCollator | None = None,
|
||||
model_init: Callable[[], PreTrainedModel] | None = None,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
model_adapter_name: Optional[str] = None,
|
||||
ref_adapter_name: Optional[str] = None,
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
peft_config: dict | None = None,
|
||||
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
|
||||
model_adapter_name: str | None = None,
|
||||
ref_adapter_name: str | None = None,
|
||||
):
|
||||
if type(args) is TrainingArguments:
|
||||
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
||||
@ -870,7 +875,7 @@ class KTOTrainer(BaseTrainer):
|
||||
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
|
||||
"""
|
||||
Returns the evaluation [`~torch.utils.data.DataLoader`].
|
||||
|
||||
@ -1057,7 +1062,7 @@ class KTOTrainer(BaseTrainer):
|
||||
return (per_token_logps * loss_mask).sum(-1)
|
||||
|
||||
def forward(
|
||||
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
||||
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
KL_logps = self._compute_kl_logps(model, batch)
|
||||
|
||||
@ -1356,7 +1361,7 @@ class KTOTrainer(BaseTrainer):
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
batch: dict[str, list | torch.LongTensor],
|
||||
):
|
||||
"""Compute the KTO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
@ -1467,11 +1472,11 @@ class KTOTrainer(BaseTrainer):
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
model: PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
compute_loss_context_manager = (
|
||||
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
||||
)
|
||||
@ -1493,7 +1498,7 @@ class KTOTrainer(BaseTrainer):
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
|
||||
def _get_train_sampler(self, dataset: Dataset | None = None) -> torch.utils.data.Sampler | None:
|
||||
if dataset is None:
|
||||
dataset = self.train_dataset
|
||||
if dataset is None or not has_length(dataset):
|
||||
@ -1550,10 +1555,10 @@ class KTOTrainer(BaseTrainer):
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
model: PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
ignore_keys: list[str] | None = None,
|
||||
):
|
||||
if ignore_keys is None:
|
||||
if hasattr(model, "config"):
|
||||
@ -1590,8 +1595,8 @@ class KTOTrainer(BaseTrainer):
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
description: str,
|
||||
prediction_loss_only: Optional[bool] = None,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
prediction_loss_only: bool | None = None,
|
||||
ignore_keys: list[str] | None = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
) -> EvalLoopOutput:
|
||||
"""
|
||||
@ -1625,7 +1630,9 @@ class KTOTrainer(BaseTrainer):
|
||||
columns=["Prompt", "Policy", "Ref Model"],
|
||||
data=[
|
||||
[prompt, pol[len(prompt) :], ref[len(prompt) :]]
|
||||
for prompt, pol, ref in zip(target_batch["prompt"], policy_output_decoded, ref_output_decoded)
|
||||
for prompt, pol, ref in zip(
|
||||
target_batch["prompt"], policy_output_decoded, ref_output_decoded, strict=True
|
||||
)
|
||||
],
|
||||
)
|
||||
if "wandb" in self.args.report_to:
|
||||
@ -1644,7 +1651,7 @@ class KTOTrainer(BaseTrainer):
|
||||
|
||||
return initial_output
|
||||
|
||||
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -54,9 +53,9 @@ class ModelConfig:
|
||||
LoRA alpha.
|
||||
lora_dropout (`float`, *optional*, defaults to `0.05`):
|
||||
LoRA dropout.
|
||||
lora_target_modules (`Union[str, list[str]]`, *optional*):
|
||||
lora_target_modules (`str | list[str]`, *optional*):
|
||||
LoRA target modules.
|
||||
lora_target_parameters (`Union[str, list[str]]`, *optional*):
|
||||
lora_target_parameters (`str | list[str]`, *optional*):
|
||||
List of target parameters for LoRA.
|
||||
lora_modules_to_save (`list[str]`, *optional*):
|
||||
Model layers to unfreeze & train.
|
||||
@ -82,7 +81,7 @@ class ModelConfig:
|
||||
Whether to use nested quantization.
|
||||
"""
|
||||
|
||||
model_name_or_path: Optional[str] = field(
|
||||
model_name_or_path: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Model checkpoint for weights initialization."},
|
||||
)
|
||||
@ -90,7 +89,7 @@ class ModelConfig:
|
||||
default="main",
|
||||
metadata={"help": "Specific model version to use. It can be a branch name, a tag name, or a commit id."},
|
||||
)
|
||||
dtype: Optional[str] = field(
|
||||
dtype: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Override the default `torch.dtype` and load the model under this dtype.",
|
||||
@ -105,7 +104,7 @@ class ModelConfig:
|
||||
"execute code present on the Hub on your local machine."
|
||||
},
|
||||
)
|
||||
attn_implementation: Optional[str] = field(
|
||||
attn_implementation: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in "
|
||||
@ -128,15 +127,15 @@ class ModelConfig:
|
||||
default=0.05,
|
||||
metadata={"help": "LoRA dropout."},
|
||||
)
|
||||
lora_target_modules: Optional[list[str]] = field(
|
||||
lora_target_modules: list[str] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "LoRA target modules."},
|
||||
)
|
||||
lora_target_parameters: Optional[list[str]] = field(
|
||||
lora_target_parameters: list[str] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "List of target parameters for LoRA."},
|
||||
)
|
||||
lora_modules_to_save: Optional[list[str]] = field(
|
||||
lora_modules_to_save: list[str] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Model layers to unfreeze & train."},
|
||||
)
|
||||
@ -178,7 +177,7 @@ class ModelConfig:
|
||||
metadata={"help": "Whether to use nested quantization."},
|
||||
)
|
||||
# Deprecated params
|
||||
torch_dtype: Optional[str] = field(
|
||||
torch_dtype: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Override the default `torch.dtype` and load the model under this dtype.",
|
||||
|
@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import textwrap
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import jinja2
|
||||
import torch
|
||||
@ -122,24 +123,26 @@ class NashMDTrainer(OnlineDPOTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
||||
reward_funcs: Union[PreTrainedModel, nn.Module, None] = None,
|
||||
judge: Optional[BasePairwiseJudge] = None,
|
||||
args: Optional[NashMDConfig] = None,
|
||||
data_collator: Optional[Callable] = None,
|
||||
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
model: PreTrainedModel | nn.Module = None,
|
||||
ref_model: PreTrainedModel | nn.Module = None,
|
||||
reward_funcs: PreTrainedModel | nn.Module | None = None,
|
||||
judge: BasePairwiseJudge | None = None,
|
||||
args: NashMDConfig | None = None,
|
||||
data_collator: Callable | None = None,
|
||||
train_dataset: Dataset | IterableDataset | None = None,
|
||||
eval_dataset: Dataset | dict[str, Dataset] | None = None,
|
||||
processing_class: PreTrainedTokenizerBase
|
||||
| BaseImageProcessor
|
||||
| FeatureExtractionMixin
|
||||
| ProcessorMixin
|
||||
| None = None,
|
||||
peft_config: dict | None = None,
|
||||
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
# Deprecated parameters
|
||||
reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
reward_model: PreTrainedModel | nn.Module | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
model=model,
|
||||
@ -316,7 +319,7 @@ class NashMDTrainer(OnlineDPOTrainer):
|
||||
|
||||
probability = self.judge.judge(
|
||||
prompts,
|
||||
list(zip(model_data_completions, mixture_data_completions)),
|
||||
list(zip(model_data_completions, mixture_data_completions, strict=True)),
|
||||
return_scores=True,
|
||||
)
|
||||
return torch.tensor(probability, device=model_data["input_ids"].device)
|
||||
@ -426,7 +429,7 @@ class NashMDTrainer(OnlineDPOTrainer):
|
||||
self.stats["mixture_coef"].append(self.mixture_coef)
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
||||
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None
|
||||
) -> torch.Tensor:
|
||||
model.train()
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
@ -174,7 +174,7 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
||||
},
|
||||
)
|
||||
bf16: Optional[bool] = field(
|
||||
bf16: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
|
||||
@ -183,13 +183,13 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
},
|
||||
)
|
||||
|
||||
reward_model_path: Optional[str] = field(
|
||||
reward_model_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both."
|
||||
},
|
||||
)
|
||||
judge: Optional[str] = field(
|
||||
judge: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both."
|
||||
@ -218,14 +218,14 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
"Set to 1.0 to consider all tokens."
|
||||
},
|
||||
)
|
||||
top_k: Optional[int] = field(
|
||||
top_k: int | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
|
||||
"top-k-filtering is disabled and all tokens are considered."
|
||||
},
|
||||
)
|
||||
min_p: Optional[float] = field(
|
||||
min_p: float | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
|
||||
@ -240,7 +240,7 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
"to repeat tokens."
|
||||
},
|
||||
)
|
||||
generation_kwargs: Optional[dict] = field(
|
||||
generation_kwargs: dict | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Additional keyword arguments to pass to `GenerationConfig` (if using transformers) or "
|
||||
@ -257,11 +257,11 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
"implementation. This parameter is only effective when `use_vllm` is set to `False`."
|
||||
},
|
||||
)
|
||||
cache_implementation: Optional[str] = field(
|
||||
cache_implementation: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
|
||||
)
|
||||
missing_eos_penalty: Optional[float] = field(
|
||||
missing_eos_penalty: float | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to "
|
||||
@ -304,11 +304,11 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
"model implementation."
|
||||
},
|
||||
)
|
||||
vllm_guided_decoding_regex: Optional[str] = field(
|
||||
vllm_guided_decoding_regex: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
|
||||
)
|
||||
vllm_gpu_memory_utilization: Optional[float] = field(
|
||||
vllm_gpu_memory_utilization: float | None = field(
|
||||
default=0.55,
|
||||
metadata={
|
||||
"help": "Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set "
|
||||
@ -326,7 +326,7 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
"contention with training.",
|
||||
},
|
||||
)
|
||||
vllm_server_base_url: Optional[str] = field(
|
||||
vllm_server_base_url: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` "
|
||||
@ -365,14 +365,14 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
"is not compatible with vLLM generation."
|
||||
},
|
||||
)
|
||||
model_init_kwargs: Optional[dict[str, Any]] = field(
|
||||
model_init_kwargs: dict[str, Any] | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
|
||||
"from a string."
|
||||
},
|
||||
)
|
||||
reward_weights: Optional[list[float]] = field(
|
||||
reward_weights: list[float] | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Weights for combining multiple reward functions. Must match the number of reward functions. "
|
||||
@ -381,11 +381,11 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
)
|
||||
|
||||
# Deprecated parameters
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
dataset_num_proc: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of processes to use for processing the dataset."},
|
||||
)
|
||||
gpu_memory_utilization: Optional[float] = field(
|
||||
gpu_memory_utilization: float | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "This parameter is deprecated and will be removed in version 0.25.0. Please use "
|
||||
|
@ -16,10 +16,11 @@ import os
|
||||
import re
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any
|
||||
|
||||
import jinja2
|
||||
import torch
|
||||
@ -96,7 +97,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
||||
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
||||
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
||||
RewardFunc = str | PreTrainedModel | Callable[[list, list], list[float]]
|
||||
|
||||
|
||||
class OnlineDPOTrainer(BaseTrainer):
|
||||
@ -104,7 +105,7 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
Initialize OnlineDPOTrainer.
|
||||
|
||||
Args:
|
||||
model (`Union[str, nn.Module, PreTrainedModel]`):
|
||||
model (`str | nn.Module | PreTrainedModel`):
|
||||
Model to be trained. Can be either:
|
||||
|
||||
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
||||
@ -118,7 +119,7 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
model.
|
||||
judge (`BasePairwiseJudge`):
|
||||
The judge to use for pairwise comparison of model completions.
|
||||
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`, *optional*):
|
||||
reward_funcs (`RewardFunc | list[RewardFunc]`, *optional*):
|
||||
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
||||
functions with the prompts and completions and sum the rewards. Can be either:
|
||||
|
||||
@ -135,13 +136,13 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
sequences in the batch, given a dataset of paired sequences.
|
||||
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
||||
The dataset to use for training.
|
||||
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
||||
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Dataset | IterableDataset]`):
|
||||
The dataset to use for evaluation.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`] or [`~transformers.ProcessorMixin`], *optional*):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
||||
reuse the fine-tuned model.
|
||||
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*):
|
||||
reward_processing_classes (`PreTrainedTokenizerBase | list[PreTrainedTokenizerBase]`, *optional*):
|
||||
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
||||
|
||||
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
||||
@ -187,24 +188,24 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module, str],
|
||||
ref_model: Union[PreTrainedModel, nn.Module, None] = None,
|
||||
reward_funcs: Optional[Union[RewardFunc, list[RewardFunc]]] = None,
|
||||
judge: Optional[BasePairwiseJudge] = None,
|
||||
args: Optional[OnlineDPOConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
||||
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
||||
processing_class: Optional[Union[PreTrainedTokenizerBase, ProcessorMixin]] = None,
|
||||
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
model: PreTrainedModel | nn.Module | str,
|
||||
ref_model: PreTrainedModel | nn.Module | None = None,
|
||||
reward_funcs: RewardFunc | list[RewardFunc] | None = None,
|
||||
judge: BasePairwiseJudge | None = None,
|
||||
args: OnlineDPOConfig | None = None,
|
||||
data_collator: DataCollator | None = None,
|
||||
train_dataset: Dataset | IterableDataset | None = None,
|
||||
eval_dataset: Dataset | IterableDataset | dict[str, Dataset | IterableDataset] | None = None,
|
||||
processing_class: PreTrainedTokenizerBase | ProcessorMixin | None = None,
|
||||
reward_processing_classes: PreTrainedTokenizerBase | list[PreTrainedTokenizerBase] | None = None,
|
||||
peft_config: "PeftConfig | None" = None,
|
||||
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
# Deprecated parameters
|
||||
reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
|
||||
reward_model: PreTrainedModel | nn.Module | None = None,
|
||||
reward_processing_class: PreTrainedTokenizerBase | None = None,
|
||||
) -> None:
|
||||
if ref_model is model:
|
||||
raise ValueError(
|
||||
@ -289,7 +290,7 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
)
|
||||
|
||||
self.reward_processing_classes = []
|
||||
for reward_processing_class_i, reward_func in zip(reward_processing_classes, reward_funcs):
|
||||
for reward_processing_class_i, reward_func in zip(reward_processing_classes, reward_funcs, strict=True):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if reward_processing_class_i is None:
|
||||
reward_processing_class_i = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
||||
@ -653,7 +654,7 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
|
||||
# Same as Trainer.get_eval_dataloader but skip the "remove_unused_columns".
|
||||
@wraps(Trainer.get_eval_dataloader)
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None) -> DataLoader:
|
||||
def get_eval_dataloader(self, eval_dataset: str | Dataset | None = None) -> DataLoader:
|
||||
if eval_dataset is None and self.eval_dataset is None:
|
||||
raise ValueError("Trainer: evaluation requires an eval_dataset.")
|
||||
|
||||
@ -838,7 +839,7 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
# Prepare vLLM inputs with images if available
|
||||
if images is not None:
|
||||
vllm_inputs = []
|
||||
for prompt, image in zip(prompts_text, images):
|
||||
for prompt, image in zip(prompts_text, images, strict=True):
|
||||
if image is not None:
|
||||
vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}})
|
||||
else:
|
||||
@ -968,7 +969,7 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
||||
llm_model.load_weights([(name, param)])
|
||||
|
||||
def _fix_param_name_to_vllm(self, name, extra_prefixes: Optional[list[str]] = None):
|
||||
def _fix_param_name_to_vllm(self, name, extra_prefixes: list[str] | None = None):
|
||||
"""Clean parameter names for vLLM compatibility"""
|
||||
extra_prefixes = extra_prefixes or []
|
||||
prefixes = ["_checkpoint_wrapped_module."] + extra_prefixes
|
||||
@ -977,7 +978,7 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
return name
|
||||
|
||||
def process_vision_row(
|
||||
self, features: dict[str, Union[list, torch.Tensor]], processing_class=None
|
||||
self, features: dict[str, list | torch.Tensor], processing_class=None
|
||||
) -> dict[str, list[int]]:
|
||||
"""
|
||||
Process a vision row for VLM models (adapted from DPO trainer)
|
||||
@ -1165,15 +1166,15 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
reward_kwargs["trainer_state"] = self.state
|
||||
|
||||
for i, (reward_func, reward_processing_class) in enumerate(
|
||||
zip(self.reward_funcs, self.reward_processing_classes)
|
||||
zip(self.reward_funcs, self.reward_processing_classes, strict=True)
|
||||
):
|
||||
if isinstance(reward_func, nn.Module): # Model-based reward function
|
||||
# Handle conversational vs text input
|
||||
if is_conversational({"prompt": prompts[0]}):
|
||||
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
||||
messages = [{"messages": p + c} for p, c in zip(prompts, completions, strict=True)]
|
||||
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
||||
else:
|
||||
texts = [p + c for p, c in zip(prompts, completions)]
|
||||
texts = [p + c for p, c in zip(prompts, completions, strict=True)]
|
||||
|
||||
# Tokenize and get reward scores
|
||||
reward_inputs = reward_processing_class(
|
||||
@ -1237,7 +1238,7 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
return logprobs
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
||||
self, model: nn.Module, inputs: dict[str, torch.Tensor | Any], num_items_in_batch: int | None = None
|
||||
) -> torch.Tensor:
|
||||
model.train()
|
||||
|
||||
@ -1358,7 +1359,7 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
completions = [template.render(messages=completion) for completion in completions]
|
||||
|
||||
ranks_of_first_completion = self.judge.judge(
|
||||
prompts, list(zip(completions[:batch_size], completions[batch_size:]))
|
||||
prompts, list(zip(completions[:batch_size], completions[batch_size:], strict=True))
|
||||
)
|
||||
|
||||
# convert ranks to a True/False mask:
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
@ -85,7 +85,7 @@ class ORPOConfig(TrainingArguments):
|
||||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
||||
},
|
||||
)
|
||||
bf16: Optional[bool] = field(
|
||||
bf16: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
|
||||
@ -94,18 +94,18 @@ class ORPOConfig(TrainingArguments):
|
||||
},
|
||||
)
|
||||
|
||||
max_length: Optional[int] = field(
|
||||
max_length: int | None = field(
|
||||
default=1024,
|
||||
metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
|
||||
)
|
||||
max_prompt_length: Optional[int] = field(
|
||||
max_prompt_length: int | None = field(
|
||||
default=512,
|
||||
metadata={
|
||||
"help": "Maximum length of the prompt. This argument is required if you want to use the default data "
|
||||
"collator and your model is an encoder-decoder."
|
||||
},
|
||||
)
|
||||
max_completion_length: Optional[int] = field(
|
||||
max_completion_length: int | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Maximum length of the completion. This argument is required if you want to use the default data "
|
||||
@ -129,7 +129,7 @@ class ORPOConfig(TrainingArguments):
|
||||
"help": "Label pad token id. This argument is required if you want to use the default data collator."
|
||||
},
|
||||
)
|
||||
padding_value: Optional[int] = field(
|
||||
padding_value: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
|
||||
)
|
||||
@ -144,21 +144,21 @@ class ORPOConfig(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."},
|
||||
)
|
||||
is_encoder_decoder: Optional[bool] = field(
|
||||
is_encoder_decoder: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` "
|
||||
"argument, you need to specify if the model returned by the callable is an encoder-decoder model."
|
||||
},
|
||||
)
|
||||
model_init_kwargs: Optional[dict[str, Any]] = field(
|
||||
model_init_kwargs: dict[str, Any] | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
|
||||
"from a string."
|
||||
},
|
||||
)
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
dataset_num_proc: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of processes to use for processing the dataset."},
|
||||
)
|
||||
|
@ -16,9 +16,10 @@ import inspect
|
||||
import random
|
||||
import textwrap
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
from typing import Any, Literal
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -129,20 +130,22 @@ class ORPOTrainer(BaseTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
args: Optional[ORPOConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
model: PreTrainedModel | nn.Module | str | None = None,
|
||||
args: ORPOConfig | None = None,
|
||||
data_collator: DataCollator | None = None,
|
||||
train_dataset: Dataset | None = None,
|
||||
eval_dataset: Dataset | dict[str, Dataset] | None = None,
|
||||
processing_class: PreTrainedTokenizerBase
|
||||
| BaseImageProcessor
|
||||
| FeatureExtractionMixin
|
||||
| ProcessorMixin
|
||||
| None = None,
|
||||
model_init: Callable[[], PreTrainedModel] | None = None,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
peft_config: dict | None = None,
|
||||
compute_metrics: Callable[[EvalLoopOutput], dict] | None = None,
|
||||
):
|
||||
if args.model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
@ -415,7 +418,7 @@ class ORPOTrainer(BaseTrainer):
|
||||
attention_mask=answer_attention_mask,
|
||||
)
|
||||
|
||||
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> dict:
|
||||
def tokenize_row(self, feature, model: PreTrainedModel | nn.Module | None = None) -> dict:
|
||||
"""Tokenize a single row from a ORPO specific dataset.
|
||||
|
||||
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation in case the prompt +
|
||||
@ -463,7 +466,8 @@ class ORPOTrainer(BaseTrainer):
|
||||
# Make sure prompts only have one different token at most an
|
||||
# and length only differs by 1 at most
|
||||
num_diff_tokens = sum(
|
||||
a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])
|
||||
a != b
|
||||
for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"], strict=True)
|
||||
)
|
||||
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
||||
if num_diff_tokens > 1 or num_diff_len > 1:
|
||||
@ -572,11 +576,11 @@ class ORPOTrainer(BaseTrainer):
|
||||
|
||||
@staticmethod
|
||||
def concatenated_inputs(
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
batch: dict[str, list | torch.LongTensor],
|
||||
is_encoder_decoder: bool = False,
|
||||
label_pad_token_id: int = -100,
|
||||
padding_value: int = 0,
|
||||
device: Optional[torch.device] = None,
|
||||
device: torch.device | None = None,
|
||||
) -> dict[str, torch.LongTensor]:
|
||||
"""Concatenate the chosen and rejected inputs into a single tensor.
|
||||
|
||||
@ -714,7 +718,7 @@ class ORPOTrainer(BaseTrainer):
|
||||
return (per_token_logps * loss_mask).sum(-1)
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]
|
||||
self, model: nn.Module, batch: dict[str, list | torch.LongTensor]
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
||||
|
||||
@ -797,7 +801,7 @@ class ORPOTrainer(BaseTrainer):
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: dict[str, Union[list, torch.LongTensor]],
|
||||
batch: dict[str, list | torch.LongTensor],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
@ -851,11 +855,11 @@ class ORPOTrainer(BaseTrainer):
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
model: PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
return_outputs=False,
|
||||
num_items_in_batch=None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
|
||||
) -> torch.Tensor | tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
compute_loss_context_manager = (
|
||||
autocast(self.accelerator.device.type) if self._peft_has_been_casted_to_bf16 else nullcontext()
|
||||
)
|
||||
@ -898,10 +902,10 @@ class ORPOTrainer(BaseTrainer):
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: dict[str, Union[torch.Tensor, Any]],
|
||||
model: PreTrainedModel | nn.Module,
|
||||
inputs: dict[str, torch.Tensor | Any],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
ignore_keys: list[str] | None = None,
|
||||
):
|
||||
if not self.use_dpo_data_collator:
|
||||
logger.warning(
|
||||
@ -946,8 +950,8 @@ class ORPOTrainer(BaseTrainer):
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
description: str,
|
||||
prediction_loss_only: Optional[bool] = None,
|
||||
ignore_keys: Optional[list[str]] = None,
|
||||
prediction_loss_only: bool | None = None,
|
||||
ignore_keys: list[str] | None = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
) -> EvalLoopOutput:
|
||||
"""
|
||||
@ -973,7 +977,8 @@ class ORPOTrainer(BaseTrainer):
|
||||
table = pd.DataFrame(
|
||||
columns=["Prompt", "Policy"],
|
||||
data=[
|
||||
[prompt, pol[len(prompt) :]] for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
|
||||
[prompt, pol[len(prompt) :]]
|
||||
for prompt, pol in zip(random_batch["prompt"], policy_output_decoded, strict=True)
|
||||
],
|
||||
)
|
||||
if "wandb" in self.args.report_to:
|
||||
@ -992,7 +997,7 @@ class ORPOTrainer(BaseTrainer):
|
||||
|
||||
return initial_output
|
||||
|
||||
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from ..trainer.utils import OnPolicyConfig
|
||||
|
||||
@ -76,11 +76,11 @@ class PPOConfig(OnPolicyConfig):
|
||||
default="EleutherAI/pythia-160m",
|
||||
metadata={"help": "Path to the reward model."},
|
||||
)
|
||||
model_adapter_name: Optional[str] = field(
|
||||
model_adapter_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
|
||||
)
|
||||
ref_adapter_name: Optional[str] = field(
|
||||
ref_adapter_name: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
|
||||
)
|
||||
|
@ -20,7 +20,6 @@ import time
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@ -146,18 +145,18 @@ class PPOTrainer(BaseTrainer):
|
||||
def __init__(
|
||||
self,
|
||||
args: PPOConfig,
|
||||
processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
|
||||
processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin,
|
||||
model: nn.Module,
|
||||
ref_model: Optional[nn.Module],
|
||||
ref_model: nn.Module | None,
|
||||
reward_model: nn.Module,
|
||||
train_dataset: Dataset,
|
||||
value_model: nn.Module,
|
||||
data_collator: Optional[DataCollatorWithPadding] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
data_collator: DataCollatorWithPadding | None = None,
|
||||
eval_dataset: Dataset | dict[str, Dataset] | None = None,
|
||||
# less commonly used
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
peft_config: "PeftConfig | None" = None,
|
||||
) -> None:
|
||||
if ref_model is model:
|
||||
raise ValueError(
|
||||
@ -371,7 +370,7 @@ class PPOTrainer(BaseTrainer):
|
||||
if self.ref_adapter_name:
|
||||
self.model.policy.set_adapter(self.model_adapter_name or "default")
|
||||
|
||||
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
||||
def save_model(self, output_dir: str | None = None, _internal_call: bool = False):
|
||||
backup_model = self.model
|
||||
self.model = self.model.policy # save only the policy
|
||||
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
@ -66,7 +65,7 @@ class PRMConfig(TrainingArguments):
|
||||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
||||
},
|
||||
)
|
||||
bf16: Optional[bool] = field(
|
||||
bf16: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
|
||||
@ -75,15 +74,15 @@ class PRMConfig(TrainingArguments):
|
||||
},
|
||||
)
|
||||
|
||||
max_length: Optional[int] = field(
|
||||
max_length: int | None = field(
|
||||
default=1024,
|
||||
metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."},
|
||||
)
|
||||
max_prompt_length: Optional[int] = field(
|
||||
max_prompt_length: int | None = field(
|
||||
default=512,
|
||||
metadata={"help": "Maximum length of the prompt used for truncation."},
|
||||
)
|
||||
max_completion_length: Optional[int] = field(
|
||||
max_completion_length: int | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Maximum length of the completion used for truncation. The completion is the concatenation of the "
|
||||
@ -102,7 +101,7 @@ class PRMConfig(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "Whether to train only on the last step."},
|
||||
)
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
dataset_num_proc: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of processes to use for processing the dataset."},
|
||||
)
|
||||
|
@ -13,9 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import textwrap
|
||||
from collections.abc import Callable
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -99,23 +99,25 @@ class PRMTrainer(BaseTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
args: Optional[PRMConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
model: PreTrainedModel | nn.Module | None = None,
|
||||
args: PRMConfig | None = None,
|
||||
data_collator: DataCollator | None = None,
|
||||
train_dataset: Dataset | None = None,
|
||||
eval_dataset: Dataset | dict[str, Dataset] | None = None,
|
||||
processing_class: PreTrainedTokenizerBase
|
||||
| BaseImageProcessor
|
||||
| FeatureExtractionMixin
|
||||
| ProcessorMixin
|
||||
| None = None,
|
||||
model_init: Callable[[], PreTrainedModel] | None = None,
|
||||
compute_metrics: Callable[[EvalPrediction], dict] | None = None,
|
||||
callbacks: list[TrainerCallback] | None = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
||||
None,
|
||||
None,
|
||||
),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
|
||||
peft_config: dict | None = None,
|
||||
):
|
||||
if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
|
||||
model = prepare_peft_model(model, peft_config, args)
|
||||
@ -263,7 +265,9 @@ class PRMTrainer(BaseTrainer):
|
||||
completions_ids = [completion + separator_ids for completion in completions_ids]
|
||||
|
||||
# Create the label
|
||||
labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
|
||||
labels = [
|
||||
[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels, strict=True)
|
||||
]
|
||||
|
||||
# Join the completions and labels steps
|
||||
completion_ids = list(chain(*completions_ids))
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
@ -92,7 +92,7 @@ class RewardConfig(TrainingArguments):
|
||||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
||||
},
|
||||
)
|
||||
bf16: Optional[bool] = field(
|
||||
bf16: bool | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA "
|
||||
@ -102,14 +102,14 @@ class RewardConfig(TrainingArguments):
|
||||
)
|
||||
|
||||
# Parameters that control the model
|
||||
model_init_kwargs: Optional[dict[str, Any]] = field(
|
||||
model_init_kwargs: dict[str, Any] | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
|
||||
"the `RewardTrainer` is provided as a string."
|
||||
},
|
||||
)
|
||||
chat_template_path: Optional[str] = field(
|
||||
chat_template_path: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local "
|
||||
@ -124,37 +124,37 @@ class RewardConfig(TrainingArguments):
|
||||
)
|
||||
|
||||
# Parameters that control the data preprocessing
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
dataset_num_proc: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Number of processes to use for processing the dataset."},
|
||||
)
|
||||
eos_token: Optional[str] = field(
|
||||
eos_token: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`."
|
||||
},
|
||||
)
|
||||
pad_token: Optional[str] = field(
|
||||
pad_token: str | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
|
||||
"is also `None`, it falls back to `processing_class.eos_token`."
|
||||
},
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
max_length: int | None = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from"
|
||||
"the right. If `None`, no truncation is applied."
|
||||
},
|
||||
)
|
||||
pad_to_multiple_of: Optional[int] = field(
|
||||
pad_to_multiple_of: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
|
||||
)
|
||||
|
||||
# Parameters that control the training
|
||||
center_rewards_coefficient: Optional[float] = field(
|
||||
center_rewards_coefficient: float | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by "
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user