Compare commits

...

5 Commits

Author SHA1 Message Date
01142bb7c4 Version 0.11.1 -> 0.11.2 2024-10-07 15:59:17 +00:00
d3fb4860e2 Fix RLOO checkpointing (#2114)
* Fix RLOO checkpointing for transformers>=4.45.0

* Add missing import

* Fix pre-commit issues

* Added test for RLOO checkpointing

* Ensure that tokenizer matches SFT and Reward model

* Pre-commit formatting

* processing class

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-10-07 15:47:38 +00:00
86ad7a7e85 Version 0.11.0 -> 0.11.1 2024-09-24 15:47:35 +00:00
3aa1996986 [online-dpo] allow parse-args as list of floats (#2108)
* use a separate argument for list of floats

* do super first

* fix docstrings

* typos

* use list of floats only

* check if it has len

* fix docstring

* fix suggestion

* fix default

* Update trl/trainer/online_dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/xpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/nash_md_config.py

* additional tests

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-09-24 15:43:39 +00:00
e4935e13fc Release v0.11.0 2024-09-19 09:49:54 +02:00
8 changed files with 159 additions and 20 deletions

View File: setup.py

@@ -73,7 +73,7 @@ import os
from setuptools import find_packages, setup
__version__ = "0.11.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.11.2" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"torch>=1.4.0",

View File: tests/test_rloo_trainer.py

@@ -13,8 +13,14 @@
# limitations under the License.
import platform
import subprocess
import tempfile
import unittest
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import RLOOConfig, RLOOTrainer
def test():
@@ -26,6 +32,8 @@ python examples/scripts/rloo/rloo.py \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path EleutherAI/pythia-14m \
--sft_model_path EleutherAI/pythia-14m \
--reward_model_path EleutherAI/pythia-14m \
--missing_eos_penalty 1.0 \
--save_strategy no \
--stop_token eos
@@ -71,3 +79,42 @@ def test_rloo_reward():
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
vec_advantages = rlhf_reward - baseline
torch.testing.assert_close(vec_advantages.flatten(), advantages)
class RLOOTrainerTester(unittest.TestCase):
def setUp(self):
self.sft_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
self.reward_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
self.policy_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
self.reward_model = AutoModelForSequenceClassification.from_pretrained(self.reward_model_id)
self.policy_ref_model = AutoModelForCausalLM.from_pretrained(self.sft_model_id)
self.tokenizer = AutoTokenizer.from_pretrained(self.sft_model_id, padding_side="left")
self.tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}"
self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
def test_rloo_checkpoint(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = RLOOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
total_episodes=1,
report_to="none",
)
dummy_text = {"content": "Hello World!", "role": "user"}
dummy_data = self.tokenizer.apply_chat_template(dummy_text)
dummy_dataset = Dataset.from_dict({"input_ids": dummy_data})
trainer = RLOOTrainer(
config=training_args,
policy=self.policy_model,
reward_model=self.reward_model,
ref_policy=self.policy_ref_model,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
trainer._save_checkpoint(trainer.model, trial=None)
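The new checkpoint test can be exercised on its own. A minimal sketch, assuming the file above lives at tests/test_rloo_trainer.py in the TRL checkout (that path is an assumption, not stated in this diff):

# Hedged example: run only the new RLOO checkpoint test via pytest.
import pytest

pytest.main(["tests/test_rloo_trainer.py", "-k", "test_rloo_checkpoint", "-q"])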

View File: tests/test_trainer_args.py

@@ -16,6 +16,7 @@ import tempfile
import unittest
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import (
@@ -27,12 +28,18 @@ from trl import (
DPOTrainer,
KTOConfig,
KTOTrainer,
NashMDConfig,
NashMDTrainer,
OnlineDPOConfig,
OnlineDPOTrainer,
ORPOConfig,
ORPOTrainer,
RewardConfig,
RewardTrainer,
SFTConfig,
SFTTrainer,
XPOConfig,
XPOTrainer,
)
@@ -219,8 +226,30 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True})
self.assertEqual(trainer.args.dataset_num_proc, 4)
def test_online_dpo(self):
tokenizer = AutoTokenizer.from_pretrained("gpt2")
@parameterized.expand([(False,), (True,)])
def test_nash_md(self, mixtures_coef_list):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = NashMDConfig(
tmp_dir,
mixture_coef=0.5 if not mixtures_coef_list else [0.5, 0.6],
)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = NashMDTrainer(
args=training_args,
tokenizer=tokenizer,
model=model,
ref_model=ref_model,
reward_model=reward_model,
train_dataset=dataset,
)
self.assertEqual(trainer.args.mixture_coef, 0.5 if not mixtures_coef_list else [0.5, 0.6])
@parameterized.expand([(False,), (True,)])
def test_online_dpo(self, beta_list):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
args = OnlineDPOConfig(
@@ -228,13 +257,14 @@ class TrainerArgTester(unittest.TestCase):
max_new_tokens=42,
temperature=0.5,
missing_eos_penalty=0.33,
beta=0.6,
beta=0.6 if not beta_list else [0.6, 0.7],
loss_type="hinge",
dataset_num_proc=4,
)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = OnlineDPOTrainer(
args=args,
tokenizer=tokenizer,
@@ -246,7 +276,7 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.max_new_tokens, 42)
self.assertEqual(trainer.args.temperature, 0.5)
self.assertEqual(trainer.args.missing_eos_penalty, 0.33)
self.assertEqual(trainer.args.beta, 0.6)
self.assertEqual(trainer.args.beta, 0.6 if not beta_list else [0.6, 0.7])
self.assertEqual(trainer.args.loss_type, "hinge")
self.assertEqual(trainer.args.dataset_num_proc, 4)
@@ -278,6 +308,27 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.disable_dropout, False)
self.assertEqual(trainer.args.label_pad_token_id, -99)
def test_reward(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = RewardConfig(
tmp_dir,
max_length=256,
dataset_num_proc=4,
center_rewards_coefficient=0.1,
)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = RewardTrainer(
model=model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.dataset_num_proc, 4)
self.assertEqual(trainer.args.center_rewards_coefficient, 0.1)
def test_sft(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -308,3 +359,25 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.eval_packing, True)
self.assertEqual(trainer.args.num_of_sequences, 32)
self.assertEqual(trainer.args.chars_per_token, 4.2)
@parameterized.expand([(False,), (True,)])
def test_xpo(self, alpha_list):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = XPOConfig(
tmp_dir,
alpha=0.5 if not alpha_list else [0.5, 0.6],
)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
ref_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m")
reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
trainer = XPOTrainer(
args=training_args,
tokenizer=tokenizer,
model=model,
ref_model=ref_model,
reward_model=reward_model,
train_dataset=dataset,
)
self.assertEqual(trainer.args.alpha, 0.5 if not alpha_list else [0.5, 0.6])

View File: trl/__init__.py

@@ -14,7 +14,7 @@
# flake8: noqa
__version__ = "0.11.0.dev0"
__version__ = "0.11.2"
from typing import TYPE_CHECKING
from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable

View File: trl/trainer/nash_md_config.py

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Union
from dataclasses import dataclass, field
from typing import List
from trl.trainer.online_dpo_config import OnlineDPOConfig
@@ -32,4 +32,9 @@ class NashMDConfig(OnlineDPOConfig):
epochs.
"""
mixture_coef: Union[float, List[float]] = 0.5
mixture_coef: List[float] = field(default_factory=lambda: [0.5])
def __post_init__(self):
super().__post_init__()
if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1:
self.mixture_coef = self.mixture_coef[0]
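A minimal sketch of the new mixture_coef behavior when constructing the config directly (the output directory name below is arbitrary): a multi-element list is kept for per-epoch use, while a single-element list collapses back to a float in __post_init__.

from trl import NashMDConfig

# A list with more than one value is kept as-is.
args = NashMDConfig("out", mixture_coef=[0.5, 0.6])
assert args.mixture_coef == [0.5, 0.6]

# A single-element list collapses to a plain float in __post_init__.
args = NashMDConfig("out", mixture_coef=[0.5])
assert args.mixture_coef == 0.5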

View File: trl/trainer/online_dpo_config.py

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Literal, Optional, Union
from dataclasses import dataclass, field
from typing import List, Literal, Optional
from transformers import TrainingArguments
@@ -63,7 +63,12 @@ class OnlineDPOConfig(TrainingArguments):
max_new_tokens: int = 64
temperature: float = 0.9
missing_eos_penalty: Optional[float] = None
beta: Union[float, List[float]] = 0.1
beta: List[float] = field(default_factory=lambda: [0.1])
loss_type: Literal["sigmoid", "ipo"] = "sigmoid"
dataset_num_proc: Optional[int] = None
disable_dropout: bool = True
def __post_init__(self):
super().__post_init__()
if hasattr(self.beta, "__len__") and len(self.beta) == 1:
self.beta = self.beta[0]
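The same pattern on beta is what lets "--beta 0.6 0.7" be parsed as a list from the command line. A minimal sketch, assuming HfArgumentParser's standard handling of List[float] fields (one --beta flag followed by one or more values):

from transformers import HfArgumentParser
from trl import OnlineDPOConfig

parser = HfArgumentParser(OnlineDPOConfig)

# Two values stay a list; per the docstrings, later entries apply to later epochs.
(args,) = parser.parse_args_into_dataclasses(["--output_dir", "out", "--beta", "0.6", "0.7"])
assert args.beta == [0.6, 0.7]

# A single value arrives as a one-element list and __post_init__ collapses it to a float.
(args,) = parser.parse_args_into_dataclasses(["--output_dir", "out", "--beta", "0.6"])
assert args.beta == 0.6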

View File: trl/trainer/rloo_trainer.py

@@ -39,7 +39,7 @@ from transformers import (
)
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, PrinterCallback
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
@@ -149,10 +149,6 @@ class RLOOTrainer(Trainer):
#########
### trainer specifics
#########
self.state = OnlineTrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
)
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
@@ -160,6 +156,14 @@ class RLOOTrainer(Trainer):
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
self.control = TrainerControl()
self.state = OnlineTrainerState(
is_local_process_zero=self.is_local_process_zero(),
is_world_process_zero=self.is_world_process_zero(),
stateful_callbacks=[
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
],
)
self.current_flos = 0
self.hp_search_backend = None
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
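The reordering above is needed because, with transformers>=4.45.0, the trainer state records which callbacks can export their own state, so the state can only be built after the callback handler exists. A minimal, self-contained sketch of that pattern; the trl.trainer.utils import path for OnlineTrainerState is an assumption about where that class lives:

from transformers.trainer_callback import ExportableState, PrinterCallback, TrainerControl
from trl.trainer.utils import OnlineTrainerState  # assumed location of OnlineTrainerState

# In recent transformers releases TrainerControl implements ExportableState while
# PrinterCallback does not, so the filter keeps only callbacks whose state can be
# saved and restored with the checkpoint.
control = TrainerControl()
callbacks = [PrinterCallback(), control]
state = OnlineTrainerState(
    is_local_process_zero=True,
    is_world_process_zero=True,
    stateful_callbacks=[cb for cb in callbacks if isinstance(cb, ExportableState)],
)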

View File: trl/trainer/xpo_config.py

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Union
from dataclasses import dataclass, field
from typing import List
from trl.trainer.online_dpo_config import OnlineDPOConfig
@@ -30,4 +30,9 @@ class XPOConfig(OnlineDPOConfig):
Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch and the last alpha is used for the rest of the epochs.
"""
alpha: Union[float, List[float]] = 1e-5
alpha: List[float] = field(default_factory=lambda: [1e-5])
def __post_init__(self):
super().__post_init__()
if hasattr(self.alpha, "__len__") and len(self.alpha) == 1:
self.alpha = self.alpha[0]
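As the alpha docstring above describes, a list value is meant to be consumed per epoch, with the last entry reused once the list runs out. A hedged sketch of that selection rule; coef_for_epoch is a hypothetical helper for illustration only, not a function in TRL:

from typing import List, Union

def coef_for_epoch(coef: Union[float, List[float]], epoch: int) -> float:
    # Hypothetical helper: pick the value for the current epoch, reusing the
    # last entry after the list is exhausted; a plain float is returned as-is.
    if isinstance(coef, list):
        return coef[min(epoch, len(coef) - 1)]
    return coef

assert coef_for_epoch([1e-5, 5e-5], epoch=0) == 1e-5
assert coef_for_epoch([1e-5, 5e-5], epoch=3) == 5e-5  # last value reused
assert coef_for_epoch(1e-5, epoch=3) == 1e-5          # plain float unchanged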