Compare commits

...

3 Commits

Author SHA1 Message Date
86ad7a7e85 Version 0.11.0 -> 0.11.1 2024-09-24 15:47:35 +00:00
3aa1996986 [online-dpo] allow parse-args as list of floats (#2108)
* use a separate argument for list of floats

* do super first

* fix docstrings

* typos

* use list of floats only

* check if it has len

* fix docstring

* fix suggestion

* fix default

* Update trl/trainer/online_dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/xpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

* additional tests

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-09-24 15:43:39 +00:00
e4935e13fc Release v0.11.0 2024-09-19 09:49:54 +02:00
6 changed files with 103 additions and 15 deletions
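
In short: across the three config files below, `beta` (OnlineDPOConfig), `mixture_coef` (NashMDConfig), and `alpha` (XPOConfig) change from a scalar float to a list of floats, and a `__post_init__` hook collapses a single-element list back to a scalar. A minimal, self-contained sketch of that pattern (the `DemoConfig` name is illustrative, not a TRL class):

from dataclasses import dataclass, field
from typing import List


@dataclass
class DemoConfig:
    # One value per epoch; a single-element list collapses to a plain float.
    beta: List[float] = field(default_factory=lambda: [0.1])

    def __post_init__(self):
        # Same guard as in the diffs below: only collapse sized containers of length 1.
        if hasattr(self.beta, "__len__") and len(self.beta) == 1:
            self.beta = self.beta[0]


assert DemoConfig().beta == 0.1                        # default collapses to a scalar
assert DemoConfig(beta=[0.6]).beta == 0.6              # so does an explicit one-element list
assert DemoConfig(beta=[0.6, 0.7]).beta == [0.6, 0.7]  # longer lists pass through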

setup.py

@@ -73,7 +73,7 @@ import os
 from setuptools import find_packages, setup

-__version__ = "0.11.0.dev0"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+__version__ = "0.11.1"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)

 REQUIRED_PKGS = [
     "torch>=1.4.0",

tests/test_trainers_args.py

@@ -16,6 +16,7 @@ import tempfile
 import unittest

 from datasets import load_dataset
+from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

 from trl import (
@@ -27,12 +28,18 @@ from trl import (
     DPOTrainer,
     KTOConfig,
     KTOTrainer,
+    NashMDConfig,
+    NashMDTrainer,
     OnlineDPOConfig,
     OnlineDPOTrainer,
     ORPOConfig,
     ORPOTrainer,
+    RewardConfig,
+    RewardTrainer,
     SFTConfig,
     SFTTrainer,
+    XPOConfig,
+    XPOTrainer,
 )
@@ -219,8 +226,30 @@ class TrainerArgTester(unittest.TestCase):
             self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
             self.assertEqual(trainer.args.dataset_num_proc, 4)

-    def test_online_dpo(self):
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    @parameterized.expand([(False,), (True,)])
+    def test_nash_md(self, mixtures_coef_list):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = NashMDConfig(
+                tmp_dir,
+                mixture_coef=0.5 if not mixtures_coef_list else [0.5, 0.6],
+            )
+            model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
+            ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
+            reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
+            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
+            trainer = NashMDTrainer(
+                args=training_args,
+                tokenizer=tokenizer,
+                model=model,
+                ref_model=ref_model,
+                reward_model=reward_model,
+                train_dataset=dataset,
+            )
+            self.assertEqual(trainer.args.mixture_coef, 0.5 if not mixtures_coef_list else [0.5, 0.6])
+
+    @parameterized.expand([(False,), (True,)])
+    def test_online_dpo(self, beta_list):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         with tempfile.TemporaryDirectory() as tmp_dir:
             args = OnlineDPOConfig(
@@ -228,13 +257,14 @@ class TrainerArgTester(unittest.TestCase):
                 max_new_tokens=42,
                 temperature=0.5,
                 missing_eos_penalty=0.33,
-                beta=0.6,
+                beta=0.6 if not beta_list else [0.6, 0.7],
                 loss_type="hinge",
                 dataset_num_proc=4,
             )
             model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
             ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
             reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
+            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
             trainer = OnlineDPOTrainer(
                 args=args,
                 tokenizer=tokenizer,
@@ -246,7 +276,7 @@ class TrainerArgTester(unittest.TestCase):
             self.assertEqual(trainer.args.max_new_tokens, 42)
             self.assertEqual(trainer.args.temperature, 0.5)
             self.assertEqual(trainer.args.missing_eos_penalty, 0.33)
-            self.assertEqual(trainer.args.beta, 0.6)
+            self.assertEqual(trainer.args.beta, 0.6 if not beta_list else [0.6, 0.7])
             self.assertEqual(trainer.args.loss_type, "hinge")
             self.assertEqual(trainer.args.dataset_num_proc, 4)
@@ -278,6 +308,27 @@ class TrainerArgTester(unittest.TestCase):
             self.assertEqual(trainer.args.disable_dropout, False)
             self.assertEqual(trainer.args.label_pad_token_id, -99)

+    def test_reward(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = RewardConfig(
+                tmp_dir,
+                max_length=256,
+                dataset_num_proc=4,
+                center_rewards_coefficient=0.1,
+            )
+            model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
+            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
+            trainer = RewardTrainer(
+                model=model,
+                args=training_args,
+                train_dataset=dataset,
+                tokenizer=tokenizer,
+            )
+            self.assertEqual(trainer.args.max_length, 256)
+            self.assertEqual(trainer.args.dataset_num_proc, 4)
+            self.assertEqual(trainer.args.center_rewards_coefficient, 0.1)
+
     def test_sft(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -308,3 +359,25 @@ class TrainerArgTester(unittest.TestCase):
             self.assertEqual(trainer.args.eval_packing, True)
             self.assertEqual(trainer.args.num_of_sequences, 32)
             self.assertEqual(trainer.args.chars_per_token, 4.2)
+
+    @parameterized.expand([(False,), (True,)])
+    def test_xpo(self, alpha_list):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = XPOConfig(
+                tmp_dir,
+                alpha=0.5 if not alpha_list else [0.5, 0.6],
+            )
+            model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
+            ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
+            reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
+            tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
+            trainer = XPOTrainer(
+                args=training_args,
+                tokenizer=tokenizer,
+                model=model,
+                ref_model=ref_model,
+                reward_model=reward_model,
+                train_dataset=dataset,
+            )
+            self.assertEqual(trainer.args.alpha, 0.5 if not alpha_list else [0.5, 0.6])
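
For context on the test diff above: `parameterized.expand` generates one test method per argument tuple, so each of these tests runs once with a scalar and once with a two-element list. A minimal sketch of the mechanism, independent of TRL:

import unittest

from parameterized import parameterized


class DemoTest(unittest.TestCase):
    @parameterized.expand([(False,), (True,)])
    def test_coef(self, use_list):
        # Two generated tests; parameterized appends the case index (and a
        # printable form of the arguments) to each generated test's name.
        coef = [0.5, 0.6] if use_list else 0.5
        if use_list:
            self.assertEqual(len(coef), 2)
        else:
            self.assertIsInstance(coef, float)


if __name__ == "__main__":
    unittest.main()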

trl/__init__.py

@@ -14,7 +14,7 @@
 # flake8: noqa

-__version__ = "0.11.0.dev0"
+__version__ = "0.11.1"

 from typing import TYPE_CHECKING

 from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable

trl/trainer/nash_md_config.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
-from typing import List, Union
+from dataclasses import dataclass, field
+from typing import List

 from trl.trainer.online_dpo_config import OnlineDPOConfig
@@ -32,4 +32,9 @@ class NashMDConfig(OnlineDPOConfig):
             epochs.
     """

-    mixture_coef: Union[float, List[float]] = 0.5
+    mixture_coef: List[float] = field(default_factory=lambda: [0.5])
+
+    def __post_init__(self):
+        super().__post_init__()
+        if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1:
+            self.mixture_coef = self.mixture_coef[0]
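
Note the ordering in the new hook: `super().__post_init__()` runs before the `mixture_coef` collapse (the "do super first" item in the commit message), so inherited fields such as `beta` are normalized first. A standalone sketch of that inheritance pattern, with illustrative class names rather than the TRL ones:

from dataclasses import dataclass, field
from typing import List


@dataclass
class BaseConfig:
    beta: List[float] = field(default_factory=lambda: [0.1])

    def __post_init__(self):
        if hasattr(self.beta, "__len__") and len(self.beta) == 1:
            self.beta = self.beta[0]


@dataclass
class ChildConfig(BaseConfig):
    mixture_coef: List[float] = field(default_factory=lambda: [0.5])

    def __post_init__(self):
        super().__post_init__()  # collapse the parent's fields first
        if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1:
            self.mixture_coef = self.mixture_coef[0]


cfg = ChildConfig(mixture_coef=[0.5, 0.6])
assert cfg.beta == 0.1 and cfg.mixture_coef == [0.5, 0.6]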

trl/trainer/online_dpo_config.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
-from typing import List, Literal, Optional, Union
+from dataclasses import dataclass, field
+from typing import List, Literal, Optional

 from transformers import TrainingArguments
@@ -63,7 +63,12 @@ class OnlineDPOConfig(TrainingArguments):
     max_new_tokens: int = 64
     temperature: float = 0.9
     missing_eos_penalty: Optional[float] = None
-    beta: Union[float, List[float]] = 0.1
+    beta: List[float] = field(default_factory=lambda: [0.1])
     loss_type: Literal["sigmoid", "ipo"] = "sigmoid"
     dataset_num_proc: Optional[int] = None
     disable_dropout: bool = True
+
+    def __post_init__(self):
+        super().__post_init__()
+        if hasattr(self.beta, "__len__") and len(self.beta) == 1:
+            self.beta = self.beta[0]
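
The PR title's "parse-args" is about populating these configs from the command line: `HfArgumentParser` maps a `List[float]` dataclass field to an `nargs="+"` argument, so `--beta 0.6 0.7` parses to a list and `--beta 0.6` parses to `[0.6]`, which `__post_init__` then collapses to a scalar. A hedged sketch with a standalone dataclass (not `OnlineDPOConfig` itself, which also needs the required `TrainingArguments` fields):

from dataclasses import dataclass, field
from typing import List

from transformers import HfArgumentParser


@dataclass
class ScriptArgs:
    # HfArgumentParser turns List[float] into argparse nargs="+" with type=float.
    beta: List[float] = field(default_factory=lambda: [0.1])


parser = HfArgumentParser(ScriptArgs)
(args,) = parser.parse_args_into_dataclasses(args=["--beta", "0.6", "0.7"])
print(args.beta)  # [0.6, 0.7]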

trl/trainer/xpo_config.py

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
-from typing import List, Union
+from dataclasses import dataclass, field
+from typing import List

 from trl.trainer.online_dpo_config import OnlineDPOConfig
@@ -30,4 +30,9 @@ class XPOConfig(OnlineDPOConfig):
             Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new
             epoch and the last alpha is used for the rest of the epochs.
     """

-    alpha: Union[float, List[float]] = 1e-5
+    alpha: List[float] = field(default_factory=lambda: [1e-5])
+
+    def __post_init__(self):
+        super().__post_init__()
+        if hasattr(self.alpha, "__len__") and len(self.alpha) == 1:
+            self.alpha = self.alpha[0]
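
The diff stops at the configs; how the trainers consume a multi-value list (one value per epoch, with the last value reused afterwards, per the docstrings) is not part of this compare. A hypothetical helper, named here for illustration only, matching that documented rule:

from typing import List, Union


def coef_for_epoch(coef: Union[float, List[float]], epoch: int) -> float:
    """Return the value for this epoch; once the list is exhausted, keep the last one."""
    if isinstance(coef, (int, float)):
        return coef
    return coef[epoch] if epoch < len(coef) else coef[-1]


assert coef_for_epoch(1e-5, 7) == 1e-5
assert coef_for_epoch([0.5, 0.6], 0) == 0.5
assert coef_for_epoch([0.5, 0.6], 5) == 0.6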