# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, logging, set_seed

from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset


"""
Fine-Tune Llama-7b on SE paired dataset
"""
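

# Illustrative launch command (a sketch only: the script file name, model path, and flag
# values below are placeholders/assumptions, not values prescribed by the repository):
#
#   python supervised_finetuning.py \
#       --model_path=<path-or-hub-id-of-llama-7b> \
#       --streaming \
#       --seq_length=1024 \
#       --bf16 \
#       --output_dir=./checkpoints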


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired")
    parser.add_argument("--subset", type=str, default="data/finetune")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--size_valid_set", type=int, default=4000)
    parser.add_argument("--streaming", action="store_true")
    parser.add_argument("--shuffle_buffer", type=int, default=5000)

    parser.add_argument("--seq_length", type=int, default=1024)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--eos_token_id", type=int, default=49152)

    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--num_warmup_steps", type=int, default=100)
    parser.add_argument("--weight_decay", type=float, default=0.05)

    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False)
    parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./checkpoints")
    parser.add_argument("--log_freq", default=1, type=int)
    parser.add_argument("--eval_freq", default=1000, type=int)
    parser.add_argument("--save_freq", default=1000, type=int)

    return parser.parse_args()
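

# The function below samples a few hundred examples to estimate how many characters of raw
# text correspond to one token. ConstantLengthDataset uses this ratio to decide roughly how
# much text to buffer when packing fixed-length training sequences.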
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    return total_characters / total_tokens


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def prepare_sample_text(example):
    """Prepare the text from a sample of the dataset."""
    text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
    return text
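

# In the paired Stack Exchange dataset, `response_j` is the preferred answer of each pair,
# so prepare_sample_text trains the model on questions paired with the better answer.
# create_datasets below turns the chosen split into packed ConstantLengthDataset streams
# for training and validation.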
def create_datasets(tokenizer, args):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=prepare_sample_text,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=prepare_sample_text,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset
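

# run_training loads the base model in 8-bit and trains LoRA adapters on top of it via
# SFTTrainer's peft_config, so only a small fraction of the parameters receives gradients
# (see print_trainable_parameters).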
def run_training(args, train_data, val_data):
    print("Loading the model")

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    train_data.start_iteration = 0

    print("Starting main loop")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        eval_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        fp16=args.fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="llama-7b-finetuned",
        report_to="wandb",
        ddp_find_unused_parameters=False,
    )
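
    # Under DDP, each process loads its own 8-bit copy of the base model;
    # device_map={"": Accelerator().process_index} pins the whole model to that process's GPU.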
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, load_in_8bit=True, device_map={"": Accelerator().process_index}
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        peft_config=lora_config,
        packing=True,
    )

    print_trainable_parameters(trainer.model)

    print("Training...")
    trainer.train()
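
    # trainer.model is a PEFT-wrapped model here, so save_pretrained below writes only the
    # LoRA adapter weights; the 8-bit base model itself is not duplicated on disk.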
print("Saving last checkpoint of the model")
|
|
trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
|
|
|
|
|
|


def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    train_dataset, eval_dataset = create_datasets(tokenizer, args)
    run_training(args, train_dataset, eval_dataset)


if __name__ == "__main__":
    args = get_args()
    assert args.model_path != "", "Please provide the llama model path"

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_error()

    main(args)
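

# A possible follow-up step (not part of this script): merge the saved LoRA adapter into the
# base model for standalone inference. A minimal sketch, assuming the checkpoint path used
# above and peft's AutoPeftModelForCausalLM:
#
#   from peft import AutoPeftModelForCausalLM
#
#   model = AutoPeftModelForCausalLM.from_pretrained("./checkpoints/final_checkpoint/")
#   merged = model.merge_and_unload()
#   merged.save_pretrained("./checkpoints/merged_model/")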