trl/examples/research_projects/stack_llama_2/scripts/sft_llama2.py
Matthew Saad c115453fba Update sft_llama2.py (#678)
Add a num_workers argument; fix the error on line 103 when streaming is set to False.
2023-08-23 16:56:31 +02:00


# Fine-Tune Llama2-7b on SE paired dataset
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset
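

# Hyperparameters for the run; parsed from the command line with HfArgumentParser below.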
@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
    log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
    dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
    subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
    split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
    size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
    streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
    shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
    seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
    num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
    max_steps: Optional[int] = field(default=500, metadata={"help": "the maximum number of sgd steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=10, metadata={"help": "the saving frequency"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "the per device train batch size"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "the per device eval batch size"})
    gradient_accumulation_steps: Optional[int] = field(default=2, metadata={"help": "the gradient accumulation steps"})
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )
    group_by_length: Optional[bool] = field(default=False, metadata={"help": "whether to group by length"})
    packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
    learning_rate: Optional[float] = field(default=1e-4, metadata={"help": "the learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    num_warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

if script_args.group_by_length and script_args.packing:
    raise ValueError("Cannot use both packing and group by length")


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    return total_characters / total_tokens


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def prepare_sample_text(example):
    """Prepare the text from a sample of the dataset."""
    text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
    return text


def create_datasets(tokenizer, args):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=None)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=prepare_sample_text,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=prepare_sample_text,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset
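

# Quantization config: load the base model in 4-bit NF4 with bfloat16 compute (QLoRA-style fine-tuning).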
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
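
# Load the quantized base model on a single GPU; use_auth_token is needed because the Llama 2 weights are gated.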
base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
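# The KV cache is only useful at generation time and conflicts with gradient checkpointing, so disable it for training.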
base_model.config.use_cache = False
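
# LoRA configuration: train low-rank adapters on the attention query/value projections only.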
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
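
# Trainer hyperparameters built from the parsed script arguments; training runs in bf16.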
training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    max_steps=script_args.max_steps,
    report_to=script_args.log_with,
    save_steps=script_args.save_steps,
    group_by_length=script_args.group_by_length,
    lr_scheduler_type=script_args.lr_scheduler_type,
    warmup_steps=script_args.num_warmup_steps,
    optim=script_args.optimizer_type,
    bf16=True,
    remove_unused_columns=False,
    run_name="sft_llama2",
)
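
# Build the packed constant-length train/eval datasets and fine-tune with SFTTrainer; only the LoRA parameters are trainable.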
train_dataset, eval_dataset = create_datasets(tokenizer, script_args)

trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=script_args.packing,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,
)
trainer.train()
trainer.save_model(script_args.output_dir)
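
# Save the trained LoRA adapter weights to a dedicated "final_checkpoint" directory.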
output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)
# Free memory for merging weights
del base_model
torch.cuda.empty_cache()
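
# Reload the adapter on top of the base model in bfloat16 and merge the LoRA weights so the result can be used standalone.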
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()
output_merged_dir = os.path.join(script_args.output_dir, "final_merged_checkpoint")
model.save_pretrained(output_merged_dir, safe_serialization=True)
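
# Illustrative launch command (flags map to the ScriptArguments fields above; adjust values and paths for your setup):
#   accelerate launch sft_llama2.py --output_dir="./sft" --max_steps=500 --logging_steps=10 --save_steps=10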