Compare commits

...

16 Commits

Author SHA1 Message Date
3595eb00e0 Release: v0.8.5 (#1555) 2024-04-18 13:56:36 +02:00
9afd901d0f enable multiple eos tokens (#1553) 2024-04-18 12:19:18 +02:00
e04432d5e3 FIX: make the train / test fields modulable (#1551)
* make the train / test fields modulable

* format

* fix --output_dir issue
2024-04-18 11:33:30 +02:00
75c1c47fcc set dev version (#1548) 2024-04-17 17:25:01 +02:00
a5788ac99b Release: v0.8.4 (#1547) 2024-04-17 17:19:28 +02:00
3bbe7e0407 Fixed ref model not used in PPO generation (#1534) 2024-04-17 07:22:56 -07:00
edf60e826b Update run_sft.sh (#1546) 2024-04-17 16:17:05 +02:00
5d1deb1445 CLI: Set dataset_text_field to None to allow ChatML automatic template (#1545)
* Update cli_utils.py

* Update test_cli.py
2024-04-17 14:45:14 +02:00
476c4b8dc0 [KTO] support to load the adapter twice (#1542)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2024-04-16 17:43:40 +02:00
e823458a6a save_model -> save_pretrained in ppo_trainer.mdx (#1537) 2024-04-15 09:35:03 +02:00
1c0d8bca15 VSFT hotfix - adds gen prompt to template and processor to hub (#1532)
* adds gen prompt to template and processor to hub

* fixes hub model id, removes Path
2024-04-12 20:14:12 +02:00
363369a717 [CPO] fix memory leak due to retained value (#1531) 2024-04-12 15:32:01 +02:00
aba4df02c1 set dev version (#1529) 2024-04-12 12:37:34 +02:00
98226473e4 Release: v0.8.3 (#1528) 2024-04-12 12:22:05 +02:00
87f4c70e60 [CLI] fix imports (#1527) 2024-04-12 12:17:05 +02:00
995f1174da set dev version (#1523) 2024-04-11 15:51:57 +02:00
13 changed files with 143 additions and 20 deletions

View File

@@ -41,6 +41,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--dataset_text_field 'text' \
--per_device_train_batch_size $BATCH_SIZE \
--max_seq_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
@@ -56,4 +57,4 @@ echo "Starting program..."
echo "Operation Failed!"
exit 1
}
exit 0
exit 0

View File

@@ -151,7 +151,7 @@ for epoch in tqdm(range(epochs), "epoch: "):
ppo_trainer.log_stats(stats, batch, rewards)
#### Save model
ppo_trainer.save_model("my_ppo_model")
ppo_trainer.save_pretrained("my_ppo_model")
```
## Logging

View File

@@ -231,6 +231,26 @@ def load_model_and_tokenizer(args):
return model, tokenizer
def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
if tokenizer.pad_token_id is None:
pad_token_id = tokenizer.eos_token_id
else:
pad_token_id = tokenizer.pad_token_id
all_eos_token_ids = []
if eos_tokens is not None:
all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
if eos_token_ids is not None:
all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
if len(all_eos_token_ids) == 0:
all_eos_token_ids.append(tokenizer.eos_token_id)
return pad_token_id, all_eos_token_ids
def chat_cli():
parser = TrlParser(ChatArguments)
args = parser.parse_args_into_dataclasses()[0]
@@ -252,6 +272,8 @@ def chat_cli():
model, tokenizer = load_model_and_tokenizer(args)
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
interface.clear()
chat = clear_chat_history(current_args.system_prompt)
@@ -322,8 +344,8 @@ def chat_cli():
top_k=current_args.top_k,
top_p=current_args.top_p,
repetition_penalty=current_args.repetition_penalty,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_ids,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
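The hunk above (from the chat CLI script) adds a `parse_eos_tokens` helper so generation can stop on several tokens. A minimal sketch of how it resolves its return values, assuming the helper is in scope and using a hypothetical `gpt2` tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical tokenizer choice

# The comma-separated strings mirror the new --eos_tokens / --eos_token_ids CLI flags.
pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, eos_tokens="<|endoftext|>", eos_token_ids=None)

# pad_token_id falls back to tokenizer.eos_token_id when no pad token is configured,
# and eos_token_ids is a list, so `generate` can stop on any of several tokens.
```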

View File

@@ -114,8 +114,9 @@ if __name__ == "__main__":
# Dataset
################
raw_datasets = load_dataset(args.dataset_name)
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
train_dataset = raw_datasets[args.dataset_train_name]
eval_dataset = raw_datasets[args.dataset_test_name]
################
# Optional rich context managers
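For context, the change above replaces the hard-coded "train" / "test" split names in the SFT example script with the new `dataset_train_name` / `dataset_test_name` arguments. A hedged illustration with a dataset whose evaluation split is not called `test` (the dataset choice is an assumption):

```python
from datasets import load_dataset

# wikitext-2 ships "train" / "validation" / "test" splits; with the new arguments the
# script can point evaluation at "validation" instead of requiring a split literally named "test".
raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")
train_dataset = raw_datasets["train"]        # --dataset_train_name train
eval_dataset = raw_datasets["validation"]    # --dataset_test_name validation
```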

View File

@@ -79,6 +79,7 @@ if TRL_USE_RICH:
from rich.logging import RichHandler
import torch
from accelerate import Accelerator
from datasets import load_dataset
from tqdm.rich import tqdm
@@ -111,7 +112,7 @@ if __name__ == "__main__":
################
# Model, Tokenizer & Processor
################
LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
torch_dtype = (
model_config.torch_dtype
@@ -205,3 +206,5 @@ if __name__ == "__main__":
with save_context:
trainer.save_model(training_args.output_dir)
trainer.push_to_hub()
if Accelerator().is_main_process:
processor.push_to_hub(training_args.hub_model_id)
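The template change above makes `LLAVA_CHAT_TEMPLATE` honour `add_generation_prompt`, so a prompt rendered for inference ends with a trailing `ASSISTANT: ` turn. A sketch of the effect via `apply_chat_template` (the checkpoint name is an assumption):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")  # assumed checkpoint

messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is in the photo?"}]},
]

prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=LLAVA_CHAT_TEMPLATE,  # the template defined in the diff above
    tokenize=False,
    add_generation_prompt=True,         # the rendered string now ends with "ASSISTANT: "
)
```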

View File

@@ -58,7 +58,7 @@ import os
from setuptools import find_packages, setup
__version__ = "0.8.2" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.8.5" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"torch>=1.4.0",

View File

@@ -20,7 +20,7 @@ import unittest
def test_sft_cli():
try:
subprocess.run(
"trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
"trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine --dataset_text_field text",
shell=True,
check=True,
)
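Roughly, the CLI call exercised by this test corresponds to the following Python usage of `SFTTrainer`. This is a sketch, not the exact code path taken by `trl sft`, and the explicit `dataset_text_field="text"` mirrors the flag that now has to be passed because its default is `None`:

```python
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

trainer = SFTTrainer(
    model="HuggingFaceM4/tiny-random-LlamaForCausalLM",
    args=TrainingArguments(
        output_dir="tmp-sft", max_steps=1, learning_rate=1e-4, lr_scheduler_type="cosine"
    ),
    train_dataset=load_dataset("imdb", split="train"),
    dataset_text_field="text",  # default is now None, which triggers automatic ChatML templating instead
    max_seq_length=512,
)
trainer.train()
```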

View File

@@ -1142,6 +1142,55 @@ class PPOTrainerTester(unittest.TestCase):
assert generations_single == generations_batched
def test_generation_with_ref_model(self):
dummy_dataset = self._init_dummy_dataset()
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Negate the weights in the last layer of the ref model so it never
# outputs the same things as the primary model
ref_model = copy.deepcopy(model)
lm_head_weight = ref_model.pretrained_model.lm_head.weight
lm_head_weight.data = -lm_head_weight.data
ppo_trainer = PPOTrainer(
config=self.ppo_config,
model=model,
ref_model=ref_model,
tokenizer=tokenizer,
dataset=dummy_dataset,
)
input_texts = ["this is a test", "this is another, longer test"]
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": tokenizer.eos_token_id}
tokenizer.pad_token = tokenizer.eos_token
model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
generations_batched, ref_generations_batched = ppo_trainer.generate(
model_inputs, batch_size=2, generate_ref_response=True, **generation_kwargs
)
generations_batched = tokenizer.batch_decode(generations_batched)
ref_generations_batched = tokenizer.batch_decode(ref_generations_batched)
generations_single = []
ref_generations_single = []
for inputs in model_inputs:
generation, ref_generation = ppo_trainer.generate(inputs, generate_ref_response=True, **generation_kwargs)
generations_single.append(generation.squeeze())
ref_generations_single.append(ref_generation.squeeze())
generations_single = tokenizer.batch_decode(generations_single)
ref_generations_single = tokenizer.batch_decode(ref_generations_single)
assert generations_single == generations_batched
assert ref_generations_single == ref_generations_batched
assert generations_batched != ref_generations_batched
assert generations_single != ref_generations_single
def test_grad_accumulation(self):
dummy_dataset = self._init_dummy_dataset()

View File

@@ -1,6 +1,6 @@
# flake8: noqa
__version__ = "0.8.2"
__version__ = "0.8.5"
from typing import TYPE_CHECKING
from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
@@ -51,7 +51,7 @@ _import_structure = {
"SFTTrainer",
],
"commands": [],
"commands.utils": ["SftArgumentParser", "init_zero_verbose", "TrlParser", "DpoArgumentParser"],
"commands.cli_utils": ["init_zero_verbose", "SftScriptArguments", "DpoScriptArguments", "TrlParser"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "RichProgressCallback"],
"multitask_prompt_tuning": [
"MultitaskPromptEmbedding",
@@ -115,7 +115,7 @@ if TYPE_CHECKING:
SFTTrainer,
)
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
from .commands.utils import init_zero_verbose, SftScriptArguments, DpoScriptArguments, TrlParser
from .commands.cli_utils import init_zero_verbose, SftScriptArguments, DpoScriptArguments, TrlParser
try:
if not is_diffusers_available():

View File

@@ -15,6 +15,7 @@
# limitations under the License.
import inspect
import os
import sys
from copy import deepcopy
from dataclasses import asdict, dataclass, field, fields
from typing import Any, List
@@ -137,7 +138,9 @@ def init_zero_verbose():
@dataclass
class SftScriptArguments:
dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"})
dataset_text_field: str = field(default=None, metadata={"help": "the text field of the dataset"})
dataset_train_name: str = field(default="train", metadata={"help": "the name of the training set of the dataset"})
dataset_test_name: str = field(default="test", metadata={"help": "the name of the test set of the dataset"})
max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"})
packing: bool = field(default=False, metadata={"help": "Whether to apply data packing or not during training"})
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
@@ -197,6 +200,14 @@ class ChatArguments:
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"})
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"})
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"})
eos_tokens: str = field(
default=None,
metadata={"help": "EOS tokens to stop the generation. If multiple, they should be comma-separated"},
)
eos_token_ids: str = field(
default=None,
metadata={"help": "EOS token IDs to stop the generation. If multiple, they should be comma-separated"},
)
# model loading
model_revision: str = field(
default="main",
@@ -269,6 +280,24 @@ class TrlParser(HfArgumentParser):
return dataclasses
def parse_args_and_config(self):
# Hack to force-replace the `output_dir` from the YAML config file if it was not
# passed as `--output_dir` on the command line
if "--config" in sys.argv:
config_index = sys.argv.index("--config") + 1
config_path = sys.argv[config_index]
with open(config_path) as yaml_file:
yaml_config = yaml.safe_load(yaml_file)
output_dir = yaml_config.get("output_dir")
if output_dir is not None:
if "--output_dir" in sys.argv:
output_dir_index = sys.argv.index("--output_dir")
sys.argv[output_dir_index + 1] = output_dir
else:
sys.argv.extend(["--output_dir", output_dir])
dataclasses = self.parse_args_into_dataclasses(return_remaining_strings=True)
# Pop the last element which should be the remaining strings
dataclasses = self.update_dataclasses_with_config(dataclasses[:-1])
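The new pre-processing in `parse_args_and_config` copies `output_dir` from the YAML passed via `--config` into `sys.argv` when the flag is absent. A sketch of the interaction it implements (the file name and values are hypothetical):

```python
import sys

import yaml

# A config such as:
#
#   # sft_config.yaml
#   output_dir: ./sft-output
#   max_steps: 100
#
# used as `trl sft --config sft_config.yaml ...` without --output_dir should behave
# as if `--output_dir ./sft-output` had been passed explicitly.
with open("sft_config.yaml") as f:
    output_dir = yaml.safe_load(f).get("output_dir")

if output_dir is not None and "--output_dir" not in sys.argv:
    sys.argv.extend(["--output_dir", output_dir])
```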

View File

@@ -738,7 +738,7 @@ class CPOTrainer(Trainer):
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.cpu().mean()
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
return loss, metrics

View File

@@ -16,7 +16,7 @@ import inspect
import random
import warnings
from collections import defaultdict
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from functools import wraps
from operator import itemgetter
@@ -257,6 +257,10 @@ class KTOTrainer(Trainer):
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
The function to use to compute the metrics. Must take an `EvalPrediction` and return
a dictionary mapping strings to metric values.
model_adapter_name (`str`, defaults to `None`):
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
ref_adapter_name (`str`, defaults to `None`):
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
"""
_tag_names = ["trl", "kto"]
@@ -276,6 +280,8 @@ class KTOTrainer(Trainer):
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[Dict] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
model_adapter_name: Optional[str] = None,
ref_adapter_name: Optional[str] = None,
):
if type(args) == TrainingArguments:
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
@@ -392,6 +398,8 @@ class KTOTrainer(Trainer):
self.is_encoder_decoder = args.is_encoder_decoder
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
self.model_adapter_name = model_adapter_name
self.ref_adapter_name = ref_adapter_name
if ref_model:
self.ref_model = ref_model
@@ -677,6 +685,18 @@ class KTOTrainer(Trainer):
model.eval()
return model
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
if self.ref_adapter_name:
self.model.set_adapter(self.ref_adapter_name)
yield
if self.ref_adapter_name:
self.model.set_adapter(self.model_adapter_name or "default")
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].
@@ -775,9 +795,7 @@ class KTOTrainer(Trainer):
"""Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
with torch.no_grad():
if self.ref_model is None:
with self.accelerator.unwrap_model(
self.model
).disable_adapter() if self.is_peft_model else nullcontext():
with self.null_ref_context():
if self.is_encoder_decoder:
completion_logits = self.model(
padded_batch["prompt_input_ids"],
@@ -1029,7 +1047,7 @@ class KTOTrainer(Trainer):
else:
with torch.no_grad():
if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).disable_adapter():
with self.null_ref_context():
(
reference_chosen_logps,
reference_rejected_logps,
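Together, `model_adapter_name`, `ref_adapter_name`, and the `null_ref_context` manager let a single PEFT model hold both the trained adapter and a frozen copy used as the reference, so no separate `ref_model` has to be kept in memory. A hedged sketch of that setup; the checkpoint, adapter path, and dataset are assumptions, not the repository's own example:

```python
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

base = AutoModelForCausalLM.from_pretrained("gpt2")   # assumed base model
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Load the same LoRA weights twice: one copy keeps training, the other stays frozen
# and is switched to by null_ref_context when reference log-probs are computed.
model = PeftModel.from_pretrained(base, "my-lora-adapter", adapter_name="train", is_trainable=True)
model.load_adapter("my-lora-adapter", adapter_name="reference")

trainer = KTOTrainer(
    model=model,
    ref_model=None,                          # no duplicated model; adapters are swapped instead
    args=KTOConfig(output_dir="kto-out"),
    train_dataset=load_dataset("trl-lib/kto-mix-14k", split="train"),  # assumed dataset
    tokenizer=tokenizer,
    model_adapter_name="train",
    ref_adapter_name="reference",
)
```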

View File

@@ -498,7 +498,7 @@ class PPOTrainer(BaseTrainer):
if generate_ref_response:
with unwrap_model_for_generation(
self.model, self.accelerator, is_peft_model=self.is_peft_model
ref_model, self.accelerator, is_peft_model=self.is_peft_model
) as unwrapped_model:
ref_response = unwrapped_model.generate(
input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs