# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-Tune Llama2-7b on SE paired dataset
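#
# Example launch (illustrative only; it assumes this file is saved as sft_llama2.py, and the
# flags map onto ScriptArguments below and onto SFTConfig fields -- adjust values and
# hardware settings to your own setup):
#
#   accelerate launch sft_llama2.py \
#       --model_name="meta-llama/Llama-2-7b-hf" \
#       --output_dir="./sft_llama2" \
#       --max_steps=500 \
#       --per_device_train_batch_size=4 \
#       --learning_rate=1e-4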
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    is_torch_npu_available,
    is_torch_xpu_available,
    set_seed,
)

from trl import SFTConfig, SFTTrainer
from trl.trainer import ConstantLengthDataset


@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
    subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
    split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
    size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
    streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
    shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
    seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
    num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
    use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})

    # LoraConfig
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})


parser = HfArgumentParser((ScriptArguments, SFTConfig))
script_args, training_args = parser.parse_args_into_dataclasses()
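# LoRA adapters are injected only into the attention query/value projections (`q_proj`/`v_proj`);
# the base model's weights stay frozen and only the low-rank adapter parameters are trained.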
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

if training_args.group_by_length and training_args.packing:
    raise ValueError("Cannot use both packing and group by length")

# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used.
# `gradient_checkpointing=True` will cause an error in `Variable._execution_engine.run_backward`.
if training_args.gradient_checkpointing:
    raise ValueError("gradient_checkpointing not supported")

set_seed(training_args.seed)


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    return total_characters / total_tokens


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def prepare_sample_text(example):
    """Prepare the text from a sample of the dataset."""
    text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
    return text


def create_datasets(tokenizer, args, seed=None):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

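    # ConstantLengthDataset packs the formatted samples into fixed-size blocks of
    # `seq_length` tokens; `chars_per_token` is used to estimate how much raw text to
    # buffer before each tokenization pass, so no padding is needed during training.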
    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=prepare_sample_text,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=prepare_sample_text,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset


bnb_config = None
if script_args.use_bnb:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

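# With `use_bnb` enabled, the frozen base model below is loaded with 4-bit NF4 quantization and
# bfloat16 compute (QLoRA-style), which sharply reduces the memory needed to hold the 7B weights.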
base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=bnb_config,
    device_map={"": Accelerator().local_process_index},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False


tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)

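# The train/eval datasets are already packed into constant-length token blocks above, so
# `max_length` is left as None below and the trainer applies no further truncation.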
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    max_length=None,
    formatting_func=prepare_sample_text,
    processing_class=tokenizer,
    args=training_args,
)
trainer.train()
trainer.save_model(training_args.output_dir)

output_dir = os.path.join(training_args.output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)

# Free memory for merging weights
del base_model
if is_torch_xpu_available():
    torch.xpu.empty_cache()
elif is_torch_npu_available():
    torch.npu.empty_cache()
else:
    torch.cuda.empty_cache()

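# Reload the saved adapter on top of the base model in bfloat16 (not 4-bit) and merge the LoRA
# weights into the base weights so the result can be used as a standalone checkpoint.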
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", dtype=torch.bfloat16)
model = model.merge_and_unload()

output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint")
model.save_pretrained(output_merged_dir, safe_serialization=True)
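
# A minimal sketch of loading the merged checkpoint for inference afterwards (illustrative,
# not part of this training script; adjust device placement and dtype to your hardware):
#
#   merged = AutoModelForCausalLM.from_pretrained(
#       output_merged_dir, device_map="auto", dtype=torch.bfloat16
#   )
#   prompt = "Question: How do I reverse a list in Python?\n\nAnswer: "
#   inputs = tokenizer(prompt, return_tensors="pt").to(merged.device)
#   print(tokenizer.decode(merged.generate(**inputs, max_new_tokens=64)[0]))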