Files
peft/examples/waveft_finetuning/waveft_finetuning.py
Shantanu Gupta 1a1f97263d CHORE Replace deprecated torch_dtype with dtype (#2837)
Note: Diffusers is left as is for now, might need an update later.
2025-10-16 14:59:09 +02:00

190 lines
6.6 KiB
Python

# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from peft import (
WaveFTConfig,
get_peft_model,
)
def train(
base_model: str,
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "waveft",
batch_size: int = 16,
num_epochs: int = 1,
learning_rate: float = 3e-4,
cutoff_len: int = 256,
val_set_size: int = 16,
eval_step: int = 100,
save_step: int = 100,
device_map: str = "auto",
waveft_n_frequency: int = 2592,
waveft_target_modules: list[str] = None,
waveft_scaling: float = 25.0,
waveft_wavelet_family: str = "db1",
waveft_use_idwt: bool = True,
dtype: str = "float16",
seed: Optional[int] = None,
):
# Set device_map to the right place when enabling DDP.
world_size = int(os.environ.get("WORLD_SIZE", 0)) or int(os.environ.get("PMI_SIZE", 0))
if world_size > 1 and device_map != "cpu":
from accelerate import Accelerator
device_map = {"": Accelerator().process_index}
# Set seed
if seed is not None:
set_seed(seed)
model_kwargs = {"dtype": getattr(torch, dtype), "device_map": device_map}
model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
# For some tokenizer with no pad token like llama
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
def tokenize(prompt, add_eos_token=True):
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(example):
full_prompt = generate_prompt(example)
tokenized_full_prompt = tokenize(full_prompt)
return tokenized_full_prompt
config = WaveFTConfig(
n_frequency=waveft_n_frequency,
scaling=waveft_scaling,
wavelet_family=waveft_wavelet_family,
use_idwt=waveft_use_idwt,
target_modules=waveft_target_modules,
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
data = load_dataset(data_path)
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=batch_size,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
logging_steps=100,
optim="adamw_torch",
eval_strategy="steps",
save_strategy="steps",
eval_steps=eval_step,
save_steps=save_step,
output_dir=output_dir,
save_total_limit=3,
load_best_model_at_end=True,
ddp_find_unused_parameters=False if world_size > 1 else None,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
trainer.train()
model.save_pretrained(output_dir)
def generate_prompt(example):
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{example["instruction"]}
### Response:
{example["output"]}"""
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--base_model", type=str)
parser.add_argument("--data_path", type=str, default="yahma/alpaca-cleaned")
parser.add_argument("--output_dir", type=str, default="waveft")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--cutoff_len", type=int, default=256)
parser.add_argument("--val_set_size", type=int, default=16)
parser.add_argument("--eval_step", type=int, default=100)
parser.add_argument("--save_step", type=int, default=100)
parser.add_argument("--device_map", type=str, default="auto")
parser.add_argument("--waveft_n_frequency", type=int, default=2592)
parser.add_argument("--waveft_target_modules", type=str, default=None)
parser.add_argument("--waveft_scaling", type=float, default=25.0)
parser.add_argument("--waveft_wavelet_family", type=str, default="db1")
parser.add_argument("--waveft_use_idwt", action="store_true", default=True)
parser.add_argument("--dtype", type=str, default="float16")
parser.add_argument("--seed", type=int, default=None)
args = parser.parse_args()
train(
base_model=args.base_model,
data_path=args.data_path,
output_dir=args.output_dir,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
learning_rate=args.learning_rate,
cutoff_len=args.cutoff_len,
val_set_size=args.val_set_size,
eval_step=args.eval_step,
save_step=args.save_step,
device_map=args.device_map,
waveft_n_frequency=args.waveft_n_frequency,
waveft_target_modules=args.waveft_target_modules,
waveft_scaling=args.waveft_scaling,
waveft_wavelet_family=args.waveft_wavelet_family,
waveft_use_idwt=args.waveft_use_idwt,
dtype=args.dtype,
seed=args.seed,
)