Files
trl/examples/scripts/cpo.py
Haoran Xu d1df79f83c Add CPOTrainer (#1382)
* add CPOTrainer

* add docs

* fix formatting

* removed precompute_ref_log_probs arg

* remove precompute_ref_log_probs

* typos

* finish cpo trainer doc

* remove redundant lines

* typo

* formatting

* compute chosen nll loss also for enc-dec models

* fix gradient error of inplace operation for enc-dec models

* formatting

* use CPOConfig

* formatting

* use model_init_kwargs from CPOConfig

* comments in example

* fix doc string

* fix typo in docstring

* update year

* fixed typo

* use preference dataset

* fix learning rate

* move dataset_num_proc to configs

* Update cpo paper link from HF: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* update description for CPO: cpo_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* remove _prepare_deepspeed for cpo

Because CPO does not need to initialize a reference model

* Add explanation to CPO loss

* format

* fix bug when lengths are given

* add CPOTrainer to README

* fix grammar

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2024-03-22 21:32:45 +01:00

122 lines
3.9 KiB
Python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Run the CPO training script with the following command with some example arguments.
In general, the optimal configuration for CPO will be similar to that of DPO:
# regular:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-6 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-aligned-cpo" \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns
# peft:
python examples/scripts/cpo.py \
--model_name_or_path=gpt2 \
--per_device_train_batch_size 4 \
--max_steps 1000 \
--learning_rate 8e-5 \
--gradient_accumulation_steps 1 \
--logging_steps 10 \
--eval_steps 500 \
--output_dir="gpt2-lora-aligned-cpo" \
--optim rmsprop \
--warmup_steps 150 \
--report_to wandb \
--bf16 \
--logging_first_step \
--no_remove_unused_columns \
--use_peft \
--lora_r=16 \
--lora_alpha=16
"""
import multiprocessing
from dataclasses import dataclass, field
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config
@dataclass
class ScriptArguments:
    """Arguments specific to this example script.

    Parsed by ``HfArgumentParser`` together with ``CPOConfig`` and ``ModelConfig``;
    the ``help`` metadata is what appears in ``--help`` output.
    """

    # Name (or path) handed to `datasets.load_dataset` in the main block.
    dataset: str = field(
        metadata={"help": "The name of the dataset to use."},
        default="trl-internal-testing/hh-rlhf-trl-style",
    )
if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
    args, cpo_args, model_config = parser.parse_args_into_dataclasses()

    ################
    # Model & Tokenizer
    ################
    # Forward --trust_remote_code so models/tokenizers with custom code can be loaded.
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    # Build the PEFT config once; the same object is handed to the trainer below.
    peft_config = get_peft_config(model_config)
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
    )
    # Reuse EOS as the pad token when the tokenizer defines none, so batches can be padded.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
    ds = load_dataset(args.dataset)
    # NOTE(review): `debug` appears to be the option inherited from TrainingArguments,
    # reused here as a "quick run" switch that shrinks every split — confirm intent.
    if cpo_args.debug:
        for key in ds:
            # Clamp so a split with fewer than 50 rows does not make `select` fail.
            ds[key] = ds[key].select(range(min(50, len(ds[key]))))

    # Fall back to a ChatML-style template when the tokenizer ships without one.
    if tokenizer.chat_template is None:
        tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

    def process(row):
        """Render the `chosen` and `rejected` conversations of one example to text.

        Each field is passed through the tokenizer's chat template with
        ``tokenize=False``, so the row ends up holding strings.
        """
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row

    # Debug runs stay single-process; otherwise honour CPOConfig.dataset_num_proc and
    # fall back to every available core when it is unset.
    if cpo_args.debug:
        num_proc = 1
    else:
        num_proc = cpo_args.dataset_num_proc or multiprocessing.cpu_count()
    ds = ds.map(
        process,
        num_proc=num_proc,
        load_from_cache_file=False,
    )
    train_dataset = ds["train"]
    eval_dataset = ds["test"]

    ################
    # Training
    ################
    trainer = CPOTrainer(
        model,
        args=cpo_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        peft_config=peft_config,
    )

    # train and save the model
    trainer.train()
    trainer.save_model(cpo_args.output_dir)