# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Main entry point to run the experiments. Contains the general setup and the training code proper.
"""

import argparse
import datetime as dt
import gc
import json
import os
import random
import sys
import textwrap
import time
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, ContextManager, Literal, Optional

import torch
from data import get_train_valid_test_datasets
from torch import nn
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from transformers import GenerationConfig, set_seed
from utils import (
    FILE_NAME_TRAIN_PARAMS,
    BucketIterator,
    TrainResult,
    TrainStatus,
    get_accuracy,
    get_base_model_info,
    get_dataset_info,
    get_file_size,
    get_model,
    get_optimizer_and_scheduler,
    get_peft_branch,
    get_tokenizer,
    get_train_config,
    init_accelerator,
    log_results,
    validate_experiment_path,
)

from peft import AdaLoraConfig, PeftConfig
from peft.utils import CONFIG_NAME, infer_device


# # suppress all warnings
# warnings.filterwarnings("ignore")  # FIXME?

dtype_to_bytes_linear = {"float32": 4, "float16": 2, "bfloat16": 2, "int8": 1, "int4": 0.5}
# if lr scheduler with warmup is used, the ratio of warmup steps to total steps
BUCKET_FACTOR = 20  # number of batches per bucket, increasing this further has diminishing returns
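
# For a rough sense of scale (illustrative numbers only): an adapter with 10M trainable bfloat16 parameters
# takes about 10_000_000 * dtype_to_bytes_linear["bfloat16"] = 20_000_000 bytes, i.e. roughly 20 MB.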


def get_generation_config(*, seq_len, generate_kwargs) -> GenerationConfig:
    # filter out None values so that we don't depend on setting correct defaults in the config
    generation_kwargs = {k: v for k, v in generate_kwargs.items() if v is not None}
    if ("max_length" in generation_kwargs) and ("max_new_tokens" in generation_kwargs):
        # transformers does not support setting both max_length and max_new_tokens, so in this case we take the
        # smaller of the two limits
        new_max_length = min(generation_kwargs["max_new_tokens"] + seq_len, generation_kwargs["max_length"])
        del generation_kwargs["max_new_tokens"]
        generation_kwargs["max_length"] = new_max_length
    # build the config from the filtered dict; using the unfiltered generate_kwargs here would discard the fix above
    generation_config = GenerationConfig(**generation_kwargs)
    return generation_config
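
# Worked example for the reconciliation above (hypothetical numbers): with seq_len=100, max_new_tokens=50, and
# max_length=120, the effective limit is min(50 + 100, 120) = 120, i.e. the stricter of the two constraints wins.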


def evaluate(model, tokenizer, ds, batch_size, generate_kwargs, use_tqdm: bool = False) -> tuple[list[str], list[str]]:
    with torch.inference_mode():
        predictions = []
        responses = []
        pbar = range(0, len(ds), batch_size)
        if use_tqdm:
            pbar = tqdm(pbar)
        for j in pbar:
            sliced = ds[j : j + batch_size]
            # the ground truth responses are not model inputs, so set them aside for the accuracy check
            responses += sliced.pop("response")
            # left-padding, since for decoder-only generation the prompt must end right before the new tokens
            batch = tokenizer.pad(sliced, return_tensors="pt", padding_side="left").to(model.device)
            seq_len = batch["input_ids"].shape[1]
            generation_config = get_generation_config(seq_len=seq_len, generate_kwargs=generate_kwargs)
            outputs = model.generate(**batch, generation_config=generation_config, pad_token_id=tokenizer.eos_token_id)
            predictions += tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return predictions, responses


class DummyGradScaler:
    # no-op drop-in for torch.amp.GradScaler when no mixed precision is being used
    def scale(self, loss):
        return loss

    def unscale_(self, optimizer):
        pass

    def step(self, optimizer):
        optimizer.step()

    def update(self):
        pass
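
# Note: DummyGradScaler mirrors the subset of the torch.amp.GradScaler API that the train loop uses
# (scale -> backward -> unscale_ -> clip -> step -> update), so the loop needs no AMP-specific branching.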


def train(
    *,
    model: nn.Module,
    max_steps: int,
    batch_size: int,
    batch_size_eval: int,
    tokenizer: Any,
    accelerator_memory_init: int,
    eval_steps: int,
    generation_kwargs: dict[str, Any],
    grad_norm_clip: float,
    optimizer_type: str,
    optimizer_kwargs: dict[str, Any],
    query_template: str,
    lr_scheduler_arg: Optional[Literal["cosine"]],
    use_amp: bool,
    is_adalora: bool,
) -> TrainResult:
    accelerator_memory_allocated_log = []
    accelerator_memory_reserved_log = []
    losses = []
    durations = []
    metrics = []
    sample = 0  # keep count of the current sample
    total_samples = 0  # total number of samples over all epochs
    total_tokens = []  # total number of tokens over all epochs

    device_type = infer_device()
    torch_accelerator_module = getattr(torch, device_type, torch.cuda)
    if use_amp:
        grad_scaler: GradScaler | DummyGradScaler = GradScaler(device=device_type)
        autocast_ctx: Callable[[], ContextManager[Any]] = partial(autocast, device_type=device_type)
    else:
        grad_scaler = DummyGradScaler()
        autocast_ctx = nullcontext

    optimizer, lr_scheduler = get_optimizer_and_scheduler(
        model,
        optimizer_type=optimizer_type,
        max_steps=max_steps,
        lr_scheduler_arg=lr_scheduler_arg,
        **optimizer_kwargs,
    )
    # print this after getting the optimizer, in case it modifies requires_grad
    if hasattr(model, "get_nb_trainable_parameters"):
        num_trainable_params, num_params = model.get_nb_trainable_parameters()
    else:
        num_params = model.num_parameters()
        num_trainable_params = num_params
    print_verbose(
        f"trainable params: {num_trainable_params:,d} || all params: {num_params:,d} || "
        f"trainable: {100 * num_trainable_params / num_params:.4f}%"
    )

    status = TrainStatus.FAILED
    tic_train = time.perf_counter()
    eval_time = 0.0
    error_msg = ""

    ds_train, ds_valid, ds_test = get_train_valid_test_datasets(
        tokenizer=tokenizer, query_template=query_template, print_fn=print_verbose
    )
    # note: bucketing by length is only really worth it for the train dataset, since its length is large compared to
    # the batch size
    iterator_train = BucketIterator(
        ds_train,
        batch_size=batch_size,
        bucket_factor=BUCKET_FACTOR,
        delete_cols=["response"],
    )
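
    # Rough sketch of the idea: BucketIterator yields batches drawn from length-sorted buckets of about
    # batch_size * BUCKET_FACTOR samples, so sequences of similar length end up together and padding is minimized
    # (see utils.BucketIterator for the actual implementation).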

    try:
        pbar = tqdm(range(1, max_steps + 1))
        for step, batch in zip(pbar, iterator_train):
            tic = time.perf_counter()

            # create the batch
            tokens_per_sample = [len(i) for i in batch["input_ids"]]
            total_tokens.append(sum(tokens_per_sample) + len(tokens_per_sample))  # add EOS token
            batch = tokenizer.pad(batch, return_tensors="pt").to(model.device)
            actual_batch_size = len(batch["input_ids"])
            total_samples += actual_batch_size
            sample += batch_size
            if sample >= len(ds_train):  # new epoch
                sample = 0

            # add labels, they are automatically shifted by transformers
            labels = batch["input_ids"].clone()
            # We want to ignore the padding tokens except for the first EOS token; if we don't ignore them, the loss
            # will be dominated by padding tokens; if we ignore all of them, the model will not learn to predict the
            # EOS token.
            # TODO: Note that the longest sequence in the batch won't have any PAD/EOS token at the end. This is fine
            # if the batch size is > 1 but should still be fixed eventually.
            for i, num_tokens in enumerate(tokens_per_sample):
                labels[i, num_tokens + 1 :] = -100
            batch["labels"] = labels
            # passing this to the model normalizes the loss by the number of real (non-padding) tokens
            num_items_in_batch = batch["attention_mask"].sum().item()
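
            # Worked example of the masking above (hypothetical numbers): a sample with num_tokens=5 in a batch
            # padded to length 8 keeps labels at positions 0..5 (the real tokens plus the first EOS/PAD), while
            # positions 6 and 7 become -100, which the cross-entropy loss ignores.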

            # train step
            optimizer.zero_grad()
            with autocast_ctx():
                outputs = model(**batch, num_items_in_batch=num_items_in_batch)
                loss = outputs.loss
            grad_scaler.scale(loss).backward()
            if grad_norm_clip:
                # gradients must be unscaled before clipping, otherwise the threshold would apply to scaled values
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
            grad_scaler.step(optimizer)
            grad_scaler.update()
            lr_scheduler.step()

            if is_adalora:
                # AdaLoRA updates its rank budget allocation once per step
                model.base_model.update_and_allocate(step)

            losses.append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            accelerator_memory_allocated_log.append(
                torch_accelerator_module.memory_allocated() - accelerator_memory_init
            )
            accelerator_memory_reserved_log.append(
                torch_accelerator_module.memory_reserved() - accelerator_memory_init
            )
            toc = time.perf_counter()
            durations.append(toc - tic)

            # every couple of steps, evaluate; this can be slow due to generation
            if step % eval_steps == 0:
                tic_eval = time.perf_counter()
                loss_avg = sum(losses[-eval_steps:]) / eval_steps
                memory_allocated_avg = sum(accelerator_memory_allocated_log[-eval_steps:]) / eval_steps
                memory_reserved_avg = sum(accelerator_memory_reserved_log[-eval_steps:]) / eval_steps
                token_sum = sum(total_tokens[-eval_steps:])
                dur_train = sum(durations[-eval_steps:])
                tokens_per_sec = token_sum / dur_train

                model.eval()
                predictions, responses = evaluate(
                    model=model,
                    tokenizer=tokenizer,
                    ds=ds_valid,
                    batch_size=batch_size_eval,
                    generate_kwargs={**generation_kwargs},
                )
                model.train()

                example = random.choice(predictions)
                example = textwrap.shorten(example, width=750)
                example = textwrap.indent(example, "    ")
                print_verbose(f"\nExample prediction:\n{example}\n")
                accuracy = get_accuracy(predictions=predictions, responses=responses)
                num_tokens_generated = sum(sum(mask) for mask in tokenizer(predictions)["attention_mask"])

                toc_eval = time.perf_counter()
                dur_eval = toc_eval - tic_eval
                eval_time += dur_eval
                elapsed = time.perf_counter() - tic_train

                metrics.append(
                    {
                        "step": step,
                        "valid accuracy": accuracy,
                        "train loss": loss_avg,
                        "train samples": total_samples,
                        "train time": dur_train,
                        "eval time": dur_eval,
                        "tokens / sec": tokens_per_sec,
                        "mem allocated avg": memory_allocated_avg,
                        "mem reserved avg": memory_reserved_avg,
                        "elapsed time": elapsed,
                    }
                )

                log_dict = {
                    "step": f"{step:5d}",
                    "samples": f"{total_samples:7d}",
                    "lr": f"{lr_scheduler.get_last_lr()[0]:.2e}",
                    "loss avg": f"{loss_avg:.4f}",
                    "valid acc": f"{accuracy:.3f}",
                    "gen valid tokens": num_tokens_generated,
                    "train time": f"{dur_train:.1f}s",
                    "eval time": f"{dur_eval:.1f}s",
                    "train tokens / sec": f"{tokens_per_sec:.0f}",
                    "mem allocated": f"{memory_allocated_avg:.0f}",
                    "mem reserved": f"{memory_reserved_avg:.0f}",
                    "elapsed time": f"{elapsed // 60:.0f}min {elapsed % 60:.0f}s",
                }
                print_verbose(json.dumps(log_dict))

                # TODO: check if this is still needed
                torch_accelerator_module.empty_cache()
                gc.collect()
print_verbose(f"Training finished after {max_steps} steps, evaluation on test set follows.")
|
|
# test set evaluation
|
|
model.eval()
|
|
predictions, responses = evaluate(
|
|
model=model,
|
|
tokenizer=tokenizer,
|
|
ds=ds_test,
|
|
batch_size=batch_size_eval,
|
|
generate_kwargs={**generation_kwargs, "pad_token_id": tokenizer.eos_token_id},
|
|
use_tqdm=len(ds_test) > 100,
|
|
)
|
|
accuracy = get_accuracy(predictions=predictions, responses=responses)
|
|
metrics.append(
|
|
{
|
|
"step": step,
|
|
"test accuracy": accuracy,
|
|
"train loss": sum(losses[-eval_steps:]) / eval_steps,
|
|
"train samples": total_samples,
|
|
"train total tokens": sum(total_tokens),
|
|
}
|
|
)
|
|
print_verbose(f"Test accuracy: {accuracy:.3f}")
|
|
|
|

    except KeyboardInterrupt:
        print_verbose("canceled training")
        status = TrainStatus.CANCELED
        error_msg = "manually canceled"
    except torch.OutOfMemoryError as exc:
        # ouch, but let's still try to log some results
        print_verbose("out of memory error encountered")
        status = TrainStatus.CANCELED
        error_msg = str(exc)
    except Exception as exc:
        print_verbose(f"encountered an error: {exc}")
        status = TrainStatus.CANCELED
        error_msg = str(exc)

    toc_train = time.perf_counter()
    # wall-clock training time, excluding the time spent on evaluation
    train_time = toc_train - tic_train - eval_time

    if status != TrainStatus.CANCELED:
        status = TrainStatus.SUCCESS
    train_result = TrainResult(
        status=status,
        train_time=train_time,
        accelerator_memory_reserved_log=accelerator_memory_reserved_log,
        losses=losses,
        metrics=metrics,
        error_msg=error_msg,
        num_trainable_params=num_trainable_params,
        num_total_params=num_params,
    )
    return train_result


def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
    tic_total = time.perf_counter()
    start_date = dt.datetime.now(tz=dt.timezone.utc).replace(microsecond=0).isoformat()

    peft_branch = get_peft_branch()
    if peft_branch == "main":
        print_verbose("===== This experiment is categorized as a MAIN run because the PEFT branch is 'main' =====")
    else:
        print_verbose(
            f"===== This experiment is categorized as a TEST run because the PEFT branch is '{peft_branch}' ====="
        )

    # load configs; an experiment directory without a PEFT config (CONFIG_NAME) means full fine-tuning
    peft_config: Optional[PeftConfig] = None
    if os.path.exists(os.path.join(path_experiment, CONFIG_NAME)):
        peft_config = PeftConfig.from_pretrained(path_experiment)
    else:
        print_verbose(f"Could not find a PEFT config at {path_experiment}, performing FULL FINETUNING")
    path_train_config = os.path.join(path_experiment, FILE_NAME_TRAIN_PARAMS)
    train_config = get_train_config(path_train_config)
    set_seed(train_config.seed)

    # initialize objects
    accelerator_memory_init = init_accelerator()
    tokenizer = get_tokenizer(model_id=train_config.model_id, max_seq_length=train_config.max_seq_length)

    model_info = get_base_model_info(train_config.model_id)
    metamath_info = get_dataset_info("meta-math/MetaMathQA")
    gsm8k_info = get_dataset_info("openai/gsm8k")
    model = get_model(
        model_id=train_config.model_id,
        dtype=train_config.dtype,
        compile=train_config.compile,
        attn_implementation=train_config.attn_implementation,
        peft_config=peft_config,
        autocast_adapter_dtype=train_config.autocast_adapter_dtype,
    )
    print_verbose(model)

    # train model
    train_result = train(
        model=model,
        max_steps=train_config.max_steps,
        batch_size=train_config.batch_size,
        batch_size_eval=train_config.batch_size_eval,
        tokenizer=tokenizer,
        accelerator_memory_init=accelerator_memory_init,
        eval_steps=train_config.eval_steps,
        generation_kwargs=train_config.generation_kwargs,
        grad_norm_clip=train_config.grad_norm_clip,
        optimizer_type=train_config.optimizer_type,
        optimizer_kwargs=train_config.optimizer_kwargs,
        query_template=train_config.query_template,
        lr_scheduler_arg=train_config.lr_scheduler,
        use_amp=train_config.use_amp,
        is_adalora=isinstance(peft_config, AdaLoraConfig),
    )

    if train_result.status == TrainStatus.FAILED:
        print_verbose("Training failed, not logging results")
        sys.exit(1)

    file_size = get_file_size(
        model,
        peft_config=peft_config,
        clean=clean,
        print_fn=print_verbose,
    )

    time_total = time.perf_counter() - tic_total
    # log results: print them and save them to file
    log_results(
        experiment_name=experiment_name,
        train_result=train_result,
        accelerator_memory_init=accelerator_memory_init,
        time_total=time_total,
        file_size=file_size,
        model_info=model_info,
        datasets_info={"metamath": metamath_info, "gsm8k": gsm8k_info},
        start_date=start_date,
        train_config=train_config,
        peft_config=peft_config,
        print_fn=print_verbose,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("path_experiment", type=str, help="Path to the experiment directory")
    parser.add_argument(
        "--clean",
        action="store_true",
        help="Delete training artifacts after the run finishes (logs are still saved)",
    )
    args = parser.parse_args()

    experiment_name = validate_experiment_path(args.path_experiment)

    if args.verbose:

        def print_verbose(*args, **kwargs) -> None:
            kwargs["file"] = sys.stderr
            print(*args, **kwargs)

    else:

        def print_verbose(*args, **kwargs) -> None:
            pass

    main(
        path_experiment=args.path_experiment,
        experiment_name=experiment_name,
        clean=args.clean,
    )