# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 0. imports
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed

from trl import DPOConfig, DPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    num_proc: Optional[int] = field(default=24, metadata={"help": "the number of processes for dataset filtering"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="../sft/results/final_checkpoint",
        metadata={"help": "the location of the SFT model name or path"},
    )
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )

    gradient_checkpointing_use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
    )

    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
    load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
    model_dtype: Optional[str] = field(
        default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
    )

    # instrumentation
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, '
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`, `"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See "
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    seed: Optional[int] = field(
        default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
    )

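# Example invocation (illustrative only): HfArgumentParser turns each field above into a CLI flag,
# so, assuming this file is saved as `dpo_llama2.py`, a run could look like the sketch below.
# Paths and values are placeholders; any flag left out falls back to the defaults defined above.
#
#   accelerate launch dpo_llama2.py \
#       --model_name_or_path="./sft/results/final_checkpoint" \
#       --output_dir="./dpo_results" \
#       --beta=0.1 \
#       --max_steps=1000 \
#       --report_to="tensorboard"

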
def get_stack_exchange_paired(
    data_dir: str = "data/rl",
    cache_dir: Optional[str] = None,
    num_proc=24,
) -> Dataset:
    """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': list[str],
        'chosen': list[str],
        'rejected': list[str],
    }

    Prompts are structured as follows:
      "Question: " + <prompt> + "\n\nAnswer: "
    """
    dataset = load_dataset(
        "lvwerra/stack-exchange-paired",
        split="train",
        cache_dir=cache_dir,
        data_dir=data_dir,
        verification_mode="no_checks",
    )
    original_columns = dataset.column_names

    def return_prompt_and_responses(samples) -> dict[str, list[str]]:
        return {
            "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
            "chosen": samples["response_j"],
            "rejected": samples["response_k"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )

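# For illustration only: after the map above, a single record has this shape (values invented):
#
#   {
#       "prompt": "Question: How do I reverse a list in Python?\n\nAnswer: ",
#       "chosen": "You can use slicing: `my_list[::-1]` returns a reversed copy ...",
#       "rejected": "Python lists cannot be reversed ...",
#   }
#
# "chosen" holds the higher-rated Stack Exchange answer (response_j) and "rejected" the lower-rated
# one (response_k), which is the (prompt, chosen, rejected) format DPOTrainer expects.

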
if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    set_seed(script_args.seed)

    # 1. load a pretrained model
    torch_dtype = torch.float
    if script_args.model_dtype == "float16":
        torch_dtype = torch.float16
    elif script_args.model_dtype == "bfloat16":
        torch_dtype = torch.bfloat16

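    # The policy model below is the SFT checkpoint, loaded (by default) in 4-bit via bitsandbytes.
    # Note (depends on your transformers version): newer releases prefer passing a BitsAndBytesConfig
    # through `quantization_config=` instead of the `load_in_4bit=` shortcut used here.
    # `device_map={"": local_process_index}` places the whole model on this process's GPU, and
    # `use_cache` is disabled because the KV cache is incompatible with gradient checkpointing.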
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        load_in_4bit=script_args.load_in_4bit,
        device_map={"": Accelerator().local_process_index},
    )
    model.config.use_cache = False

    if script_args.ignore_bias_buffers:
        # torch distributed hack: DDP cannot broadcast boolean buffers (e.g. attention bias/mask
        # buffers), so they are excluded from DDP parameter/buffer syncing.
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

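    # The Llama 2 tokenizer ships without a padding token, so the EOS token is reused for padding;
    # DPOTrainer needs a pad token to batch the chosen/rejected completions.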
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the Stack-exchange paired dataset
    train_dataset = get_stack_exchange_paired(data_dir="data/rl")
    train_dataset = train_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
        num_proc=script_args.num_proc,
    )

    # 3. Load evaluation dataset
    eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation")
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
        num_proc=script_args.num_proc,
    )

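    # Note: the two filters above compare Python string lengths (characters), which is only a
    # rough proxy for token counts; exact truncation to max_prompt_length/max_length happens
    # inside the trainer when the pairs are tokenized.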
    # 4. initialize training arguments:
    training_args = DPOConfig(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        eval_strategy="steps",
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        optim=script_args.optimizer_type,
        bf16=True,
        remove_unused_columns=False,
        run_name="dpo_llama2",
        gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
        seed=script_args.seed,
        beta=script_args.beta,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

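    # `beta` in the config above is the DPO loss temperature: smaller values let the policy drift
    # further from the (implicit) reference model, larger values keep it closer; 0.1 is a common default.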
    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

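    # Note on `target_modules` above: entries are matched against module names by suffix. For
    # Llama-2 only q_proj/k_proj/v_proj exist; the GPT-style names (out_proj, fc_in, fc_out, wte)
    # simply match nothing here, and peft only raises an error if none of the targets match at all.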
    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        ref_model=None,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
    )

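    # With `peft_config` given and `ref_model=None`, the trainer wraps the 4-bit base model with
    # LoRA adapters and uses the same base model with the adapters disabled as the implicit
    # reference, so a second full copy of the model does not need to be loaded.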
    # 6. train
    dpo_trainer.train()
    dpo_trainer.save_model(script_args.output_dir)

    # 7. save
    output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
    dpo_trainer.model.save_pretrained(output_dir)
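    # Because training used a LoRA `peft_config`, the directories written above contain adapter
    # weights rather than a full standalone model. A rough sketch for merging the adapter back into
    # the base model afterwards (assumes peft is installed; paths reuse the ones above):
    #
    #   from peft import AutoPeftModelForCausalLM
    #
    #   merged = AutoPeftModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.float16)
    #   merged = merged.merge_and_unload()
    #   merged.save_pretrained(os.path.join(script_args.output_dir, "merged_final_checkpoint"))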