mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
16 Commits
Author and date columns were not captured by the mirror; the compared commits are:

- 3595eb00e0
- 9afd901d0f
- e04432d5e3
- 75c1c47fcc
- a5788ac99b
- 3bbe7e0407
- edf60e826b
- 5d1deb1445
- 476c4b8dc0
- e823458a6a
- 1c0d8bca15
- 363369a717
- aba4df02c1
- 98226473e4
- 87f4c70e60
- 995f1174da
@@ -41,6 +41,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
     --dataset_name $DATASET_NAME \
     --output_dir $OUTPUT_DIR \
     --max_steps $MAX_STEPS \
+    --dataset_text_field 'text' \
     --per_device_train_batch_size $BATCH_SIZE \
     --max_seq_length $SEQ_LEN \
     $EXTRA_TRAINING_ARGS
@@ -56,4 +57,4 @@ echo "Starting program..."
   echo "Operation Failed!"
   exit 1
 }
-exit 0
+exit 0
@@ -151,7 +151,7 @@ for epoch in tqdm(range(epochs), "epoch: "):
     ppo_trainer.log_stats(stats, batch, rewards)
 
 #### Save model
-ppo_trainer.save_model("my_ppo_model")
+ppo_trainer.save_pretrained("my_ppo_model")
 ```
 
 ## Logging
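For reference, a model saved this way can be reloaded with the standard `from_pretrained` API (a minimal sketch, reusing the `my_ppo_model` directory from the hunk above):

```python
from trl import AutoModelForCausalLMWithValueHead

# The directory written by save_pretrained() is a regular checkpoint, so it
# can be reloaded directly (value head weights included).
model = AutoModelForCausalLMWithValueHead.from_pretrained("my_ppo_model")
```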
@@ -231,6 +231,26 @@ def load_model_and_tokenizer(args):
     return model, tokenizer
 
 
+def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
+    if tokenizer.pad_token_id is None:
+        pad_token_id = tokenizer.eos_token_id
+    else:
+        pad_token_id = tokenizer.pad_token_id
+
+    all_eos_token_ids = []
+
+    if eos_tokens is not None:
+        all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
+
+    if eos_token_ids is not None:
+        all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
+
+    if len(all_eos_token_ids) == 0:
+        all_eos_token_ids.append(tokenizer.eos_token_id)
+
+    return pad_token_id, all_eos_token_ids
+
+
 def chat_cli():
     parser = TrlParser(ChatArguments)
     args = parser.parse_args_into_dataclasses()[0]
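To illustrate the new helper, a minimal sketch (with `parse_eos_tokens` from the hunk above in scope; the token string and id below are illustrative stand-ins for the CLI flags):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT-2 defines no pad token, so padding falls back to its EOS id; the
# comma-separated token string and the raw id both resolve to 50256 here.
pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, "<|endoftext|>", "50256")
print(pad_token_id, eos_token_ids)  # 50256 [50256, 50256]
```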
@@ -252,6 +272,8 @@ def chat_cli():
     model, tokenizer = load_model_and_tokenizer(args)
     generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
+
+    pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
 
     interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
     interface.clear()
     chat = clear_chat_history(current_args.system_prompt)
@@ -322,8 +344,8 @@ def chat_cli():
             top_k=current_args.top_k,
             top_p=current_args.top_p,
             repetition_penalty=current_args.repetition_penalty,
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=pad_token_id,
+            eos_token_id=eos_token_ids,
         )
 
         thread = Thread(target=model.generate, kwargs=generation_kwargs)
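Note that `eos_token_id` now receives a list of ids rather than a single id, which recent `transformers` versions accept. A minimal sketch of the behavior (model and stop tokens are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# With a list, decoding stops at whichever of the ids is generated first.
stop_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(".")]
output = model.generate(
    **inputs,
    max_new_tokens=16,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=stop_ids,
)
print(tokenizer.decode(output[0]))
```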
@@ -114,8 +114,9 @@ if __name__ == "__main__":
     ################
     # Dataset
     ################
     raw_datasets = load_dataset(args.dataset_name)
-    train_dataset = raw_datasets["train"]
-    eval_dataset = raw_datasets["test"]
+
+    train_dataset = raw_datasets[args.dataset_train_name]
+    eval_dataset = raw_datasets[args.dataset_test_name]
 
     ################
     # Optional rich context managers
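For context, the two new arguments let the script handle datasets whose splits are not literally named `train` and `test`. A sketch with a dataset that ships a `validation` split:

```python
from datasets import load_dataset

# wikitext ships "train", "validation" and "test" splits; running the script
# with --dataset_test_name validation would make it evaluate on:
raw_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")
train_dataset = raw_datasets["train"]       # args.dataset_train_name
eval_dataset = raw_datasets["validation"]   # args.dataset_test_name
```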
@@ -79,6 +79,7 @@ if TRL_USE_RICH:
     from rich.logging import RichHandler
 
 import torch
+from accelerate import Accelerator
 from datasets import load_dataset
 
 from tqdm.rich import tqdm
@@ -111,7 +112,7 @@ if __name__ == "__main__":
     ################
     # Model, Tokenizer & Processor
     ################
-    LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
+    LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
 
     torch_dtype = (
         model_config.torch_dtype
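The updated template honors `add_generation_prompt`, which `apply_chat_template` forwards into the Jinja context. A sketch of the effect (an arbitrary tokenizer is used purely for rendering, with the `LLAVA_CHAT_TEMPLATE` string from the hunk above):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer can render the template
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE      # the template defined above

messages = [
    {"role": "user", "content": [{"type": "text", "text": "What is shown?"}, {"type": "image"}]}
]
# With add_generation_prompt=True the rendered prompt now ends in "ASSISTANT: ",
# cueing the model to answer instead of continuing the user turn.
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
```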
@@ -205,3 +206,5 @@ if __name__ == "__main__":
     with save_context:
         trainer.save_model(training_args.output_dir)
         trainer.push_to_hub()
+        if Accelerator().is_main_process:
+            processor.push_to_hub(training_args.hub_model_id)
setup.py
@@ -58,7 +58,7 @@ import os
 from setuptools import find_packages, setup
 
 
-__version__ = "0.8.2"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+__version__ = "0.8.5"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
 
 REQUIRED_PKGS = [
     "torch>=1.4.0",
@@ -20,7 +20,7 @@ import unittest
 def test_sft_cli():
     try:
         subprocess.run(
-            "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine",
+            "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine --dataset_text_field text",
             shell=True,
             check=True,
         )
@@ -1142,6 +1142,55 @@ class PPOTrainerTester(unittest.TestCase):
 
         assert generations_single == generations_batched
 
+    def test_generation_with_ref_model(self):
+        dummy_dataset = self._init_dummy_dataset()
+        model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
+        tokenizer = AutoTokenizer.from_pretrained("gpt2")
+
+        # Negate the weights in the last layer of the ref model so it never
+        # outputs the same things as the primary model
+        ref_model = copy.deepcopy(model)
+        lm_head_weight = ref_model.pretrained_model.lm_head.weight
+        lm_head_weight.data = -lm_head_weight.data
+
+        ppo_trainer = PPOTrainer(
+            config=self.ppo_config,
+            model=model,
+            ref_model=ref_model,
+            tokenizer=tokenizer,
+            dataset=dummy_dataset,
+        )
+
+        input_texts = ["this is a test", "this is another, longer test"]
+
+        generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": tokenizer.eos_token_id}
+
+        tokenizer.pad_token = tokenizer.eos_token
+
+        model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
+
+        generations_batched, ref_generations_batched = ppo_trainer.generate(
+            model_inputs, batch_size=2, generate_ref_response=True, **generation_kwargs
+        )
+        generations_batched = tokenizer.batch_decode(generations_batched)
+        ref_generations_batched = tokenizer.batch_decode(ref_generations_batched)
+
+        generations_single = []
+        ref_generations_single = []
+        for inputs in model_inputs:
+            generation, ref_generation = ppo_trainer.generate(inputs, generate_ref_response=True, **generation_kwargs)
+            generations_single.append(generation.squeeze())
+            ref_generations_single.append(ref_generation.squeeze())
+
+        generations_single = tokenizer.batch_decode(generations_single)
+        ref_generations_single = tokenizer.batch_decode(ref_generations_single)
+
+        assert generations_single == generations_batched
+        assert ref_generations_single == ref_generations_batched
+
+        assert generations_batched != ref_generations_batched
+        assert generations_single != ref_generations_single
+
     def test_grad_accumulation(self):
         dummy_dataset = self._init_dummy_dataset()
 
@@ -1,6 +1,6 @@
 # flake8: noqa
 
-__version__ = "0.8.2"
+__version__ = "0.8.5"
 
 from typing import TYPE_CHECKING
 from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
@@ -51,7 +51,7 @@ _import_structure = {
         "SFTTrainer",
     ],
     "commands": [],
-    "commands.utils": ["SftArgumentParser", "init_zero_verbose", "TrlParser", "DpoArgumentParser"],
+    "commands.cli_utils": ["init_zero_verbose", "SftScriptArguments", "DpoScriptArguments", "TrlParser"],
     "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "RichProgressCallback"],
     "multitask_prompt_tuning": [
         "MultitaskPromptEmbedding",
@@ -115,7 +115,7 @@ if TYPE_CHECKING:
         SFTTrainer,
     )
     from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
-    from .commands.utils import init_zero_verbose, SftScriptArguments, DpoScriptArguments, TrlParser
+    from .commands.cli_utils import init_zero_verbose, SftScriptArguments, DpoScriptArguments, TrlParser
 
     try:
         if not is_diffusers_available():
@@ -15,6 +15,7 @@
 # limitations under the License.
 import inspect
 import os
+import sys
 from copy import deepcopy
 from dataclasses import asdict, dataclass, field, fields
 from typing import Any, List
@@ -137,7 +138,9 @@ def init_zero_verbose():
 @dataclass
 class SftScriptArguments:
     dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
-    dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"})
+    dataset_text_field: str = field(default=None, metadata={"help": "the text field of the dataset"})
+    dataset_train_name: str = field(default="train", metadata={"help": "the name of the training set of the dataset"})
+    dataset_test_name: str = field(default="test", metadata={"help": "the name of the test set of the dataset"})
     max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"})
     packing: bool = field(default=False, metadata={"help": "Whether to apply data packing or not during training"})
     config: str = field(default=None, metadata={"help": "Path to the optional config file"})
@@ -197,6 +200,14 @@ class ChatArguments:
     top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"})
     top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"})
     repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"})
+    eos_tokens: str = field(
+        default=None,
+        metadata={"help": "EOS tokens to stop the generation. If multiple they should be comma separated"},
+    )
+    eos_token_ids: str = field(
+        default=None,
+        metadata={"help": "EOS token IDs to stop the generation. If multiple they should be comma separated"},
+    )
     # model loading
     model_revision: str = field(
         default="main",
@@ -269,6 +280,24 @@ class TrlParser(HfArgumentParser):
         return dataclasses
 
+    def parse_args_and_config(self):
+        # Hack to force-replace the `output_dir` from the YAML file if one was
+        # not passed on the command line
+        if "--config" in sys.argv:
+            config_index = sys.argv.index("--config") + 1
+            config_path = sys.argv[config_index]
+
+            with open(config_path) as yaml_file:
+                yaml_config = yaml.safe_load(yaml_file)
+
+            output_dir = yaml_config.get("output_dir")
+
+            if output_dir is not None:
+                if "--output_dir" in sys.argv:
+                    output_dir_index = sys.argv.index("--output_dir")
+                    sys.argv[output_dir_index + 1] = output_dir
+                else:
+                    sys.argv.extend(["--output_dir", output_dir])
+
         dataclasses = self.parse_args_into_dataclasses(return_remaining_strings=True)
         # Pop the last element which should be the remaining strings
         dataclasses = self.update_dataclasses_with_config(dataclasses[:-1])
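A sketch of the resulting `--config` flow (assuming YAML keys map onto the parsed dataclasses; the file name and flags below are illustrative):

```python
import sys

import yaml
from transformers import TrainingArguments
from trl.commands.cli_utils import SftScriptArguments, TrlParser

# A minimal YAML config; output_dir is the key that parse_args_and_config()
# force-injects into sys.argv before the regular dataclass parsing runs.
with open("config.yaml", "w") as f:
    yaml.safe_dump({"output_dir": "out-sft"}, f)

sys.argv = ["sft.py", "--config", "config.yaml", "--dataset_text_field", "text"]
parser = TrlParser((SftScriptArguments, TrainingArguments))
args, training_args = parser.parse_args_and_config()
print(training_args.output_dir)  # "out-sft", injected from the YAML file
```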
@@ -738,7 +738,7 @@ class CPOTrainer(Trainer):
         metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
         metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
         metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
-        metrics[f"{prefix}nll_loss"] = policy_nll_loss.cpu().mean()
+        metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
 
         return loss, metrics
@@ -16,7 +16,7 @@ import inspect
 import random
 import warnings
 from collections import defaultdict
-from contextlib import nullcontext
+from contextlib import contextmanager, nullcontext
 from copy import deepcopy
 from functools import wraps
 from operator import itemgetter
@@ -257,6 +257,10 @@ class KTOTrainer(Trainer):
         compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
             The function to use to compute the metrics. Must take a `EvalPrediction` and return
             a dictionary string to metric values.
+        model_adapter_name (`str`, defaults to `None`):
+            Name of the train target PEFT adapter, when using LoRA with multiple adapters.
+        ref_adapter_name (`str`, defaults to `None`):
+            Name of the reference PEFT adapter, when using LoRA with multiple adapters.
     """
 
     _tag_names = ["trl", "kto"]
@@ -276,6 +280,8 @@ class KTOTrainer(Trainer):
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
         peft_config: Optional[Dict] = None,
         compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
+        model_adapter_name: Optional[str] = None,
+        ref_adapter_name: Optional[str] = None,
     ):
         if type(args) == TrainingArguments:
             raise ValueError("Please use `KTOConfig` instead of `TrainingArguments`.")
@@ -392,6 +398,8 @@ class KTOTrainer(Trainer):
         self.is_encoder_decoder = args.is_encoder_decoder
 
         self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
+        self.model_adapter_name = model_adapter_name
+        self.ref_adapter_name = ref_adapter_name
 
         if ref_model:
             self.ref_model = ref_model
@@ -677,6 +685,18 @@ class KTOTrainer(Trainer):
         model.eval()
         return model
 
+    @contextmanager
+    def null_ref_context(self):
+        """Context manager for handling null reference model (that is, peft adapter manipulation)."""
+        with self.accelerator.unwrap_model(
+            self.model
+        ).disable_adapter() if self.is_peft_model and not self.ref_adapter_name else nullcontext():
+            if self.ref_adapter_name:
+                self.model.set_adapter(self.ref_adapter_name)
+            yield
+            if self.ref_adapter_name:
+                self.model.set_adapter(self.model_adapter_name or "default")
+
     def get_train_dataloader(self) -> DataLoader:
         """
         Returns the training [`~torch.utils.data.DataLoader`].
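For reference, the new arguments target setups where the policy and the reference are two LoRA adapters on one base model, so no second copy of the weights is needed. A minimal sketch (adapter names, dataset, and hyperparameters are illustrative):

```python
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# One base model carrying two LoRA adapters: "default" is the policy being
# trained, while a second adapter stands in for the frozen reference.
lora_cfg = dict(task_type="CAUSAL_LM", r=8, target_modules=["c_attn"])
model = get_peft_model(model, LoraConfig(**lora_cfg))
model.add_adapter("reference", LoraConfig(**lora_cfg))

train_dataset = Dataset.from_dict(
    {"prompt": ["Hi", "Hey"], "completion": [" there!", " you!"], "label": [True, False]}
)

trainer = KTOTrainer(
    model,
    None,  # ref_model=None: null_ref_context() swaps adapters instead
    args=KTOConfig(output_dir="kto-adapter-demo", per_device_train_batch_size=2, max_steps=1),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    model_adapter_name="default",
    ref_adapter_name="reference",
)
```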
@@ -775,9 +795,7 @@ class KTOTrainer(Trainer):
         """Computes log probabilities of the reference model for a single padded batch of a KTO specific dataset."""
         with torch.no_grad():
             if self.ref_model is None:
-                with self.accelerator.unwrap_model(
-                    self.model
-                ).disable_adapter() if self.is_peft_model else nullcontext():
+                with self.null_ref_context():
                     if self.is_encoder_decoder:
                         completion_logits = self.model(
                             padded_batch["prompt_input_ids"],
@@ -1029,7 +1047,7 @@ class KTOTrainer(Trainer):
         else:
             with torch.no_grad():
                 if self.ref_model is None:
-                    with self.accelerator.unwrap_model(self.model).disable_adapter():
+                    with self.null_ref_context():
                         (
                             reference_chosen_logps,
                             reference_rejected_logps,
@@ -498,7 +498,7 @@ class PPOTrainer(BaseTrainer):
 
         if generate_ref_response:
             with unwrap_model_for_generation(
-                self.model, self.accelerator, is_peft_model=self.is_peft_model
+                ref_model, self.accelerator, is_peft_model=self.is_peft_model
             ) as unwrapped_model:
                 ref_response = unwrapped_model.generate(
                     input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
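With this fix, `generate_ref_response=True` really samples from the reference model rather than the policy. A minimal sketch mirroring the new test above (model name and generation settings are illustrative):

```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

ppo_trainer = PPOTrainer(
    config=PPOConfig(batch_size=1, mini_batch_size=1),
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)

query = tokenizer("this is a test", return_tensors="pt").input_ids.squeeze()
# The second return value is now produced by ref_model, not the policy.
response, ref_response = ppo_trainer.generate(
    query, generate_ref_response=True, do_sample=False, max_new_tokens=4,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(response.squeeze()))
print(tokenizer.decode(ref_response.squeeze()))
```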