# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
|
|
Run the CPO training script with the following command with some example arguments.
|
|
In general, the optimal configuration for CPO will be similar to that of DPO:
|
|
|
|
# regular:
|
|
python examples/scripts/cpo.py \
|
|
--model_name_or_path=gpt2 \
|
|
--per_device_train_batch_size 4 \
|
|
--max_steps 1000 \
|
|
--learning_rate 8e-6 \
|
|
--gradient_accumulation_steps 1 \
|
|
--logging_steps 10 \
|
|
--eval_steps 500 \
|
|
--output_dir="gpt2-aligned-cpo" \
|
|
--warmup_steps 150 \
|
|
--report_to wandb \
|
|
--bf16 \
|
|
--logging_first_step \
|
|
--no_remove_unused_columns
|
|
|
|
# peft:
|
|
python examples/scripts/cpo.py \
|
|
--model_name_or_path=gpt2 \
|
|
--per_device_train_batch_size 4 \
|
|
--max_steps 1000 \
|
|
--learning_rate 8e-5 \
|
|
--gradient_accumulation_steps 1 \
|
|
--logging_steps 10 \
|
|
--eval_steps 500 \
|
|
--output_dir="gpt2-lora-aligned-cpo" \
|
|
--optim rmsprop \
|
|
--warmup_steps 150 \
|
|
--report_to wandb \
|
|
--bf16 \
|
|
--logging_first_step \
|
|
--no_remove_unused_columns \
|
|
--use_peft \
|
|
--lora_r=16 \
|
|
--lora_alpha=16
|
|
"""
|
|
|
|
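# Beyond the flags shown above, CPO-specific hyperparameters (for example the loss
# temperature `--beta`) are exposed through `CPOConfig` and can be passed the same way.
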
import multiprocessing
from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config


@dataclass
class ScriptArguments:
    dataset: str = field(
        default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
    )


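# The preference dataset can be swapped on the command line (e.g. `--dataset <hub-id>`),
# as long as it provides chosen/rejected conversations in the same trl-style schema.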
if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
    args, cpo_args, model_config = parser.parse_args_into_dataclasses()

    ################
    # Model & Tokenizer
    ################
    model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
    peft_config = get_peft_config(model_config)
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
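    # Tokenizers like GPT-2's ship without a pad token; fall back to the EOS token so batches can be padded.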
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
    ds = load_dataset(args.dataset)
    if cpo_args.debug:
        for key in ds:
            ds[key] = ds[key].select(range(50))
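    # If the tokenizer ships without a chat template, fall back to a minimal ChatML-style
    # template (<|im_start|>role ... <|im_end|>) so the preference messages can be rendered.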
    if tokenizer.chat_template is None:
        tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

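    # Flatten the chosen/rejected message lists into plain strings, since the trainer
    # consumes text preference pairs (prompt, chosen, rejected) rather than message lists.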
    def process(row):
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row

    ds = ds.map(
        process,
        num_proc=1 if cpo_args.debug else multiprocessing.cpu_count(),
        load_from_cache_file=False,
    )
    train_dataset = ds["train"]
    eval_dataset = ds["test"]

    ################
    # Training
    ################
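    # CPO optimizes the policy directly on the preference pairs; unlike DPO it does not
    # need a separate frozen reference model, so none is instantiated here.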
    trainer = CPOTrainer(
        model,
        args=cpo_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
    )

    # train and save the model
    trainer.train()
    trainer.save_model(cpo_args.output_dir)
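    # After training, the weights (or the PEFT adapter when --use_peft is set) live in
    # cpo_args.output_dir and can be reloaded from there for evaluation or generation.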