mirror of https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00

FEAT Add LoRA-FA to PEFT (#2468)

Adds LoRA with frozen A (LoRA-FA) to PEFT. Paper: https://arxiv.org/abs/2308.03303
docs/source/developer_guides/lora.md
@@ -271,7 +271,40 @@ The same logic applies to `alpha_pattern`. If you're in doubt, don't try to get

## Optimizers

-LoRA training can optionally include special purpose optimizers. Currently the only such optimizer is LoRA+.
+LoRA training can optionally include special purpose optimizers. Currently PEFT supports LoRA-FA and LoRA+.

### LoRA-FA Optimizer

LoRA training can be more effective and efficient using LoRA-FA, as described in [LoRA-FA](https://arxiv.org/abs/2308.03303). LoRA-FA reduces activation memory consumption by fixing the matrix $A$ and only tuning the matrix $B$. During training, the gradient of $B$ is optimized to approximate the full parameter fine-tuning gradient. Moreover, the memory consumption of LoRA-FA is not sensitive to the rank (since it avoids storing the input activations of $A$), so performance can be improved by enlarging the LoRA rank without increasing memory consumption.
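Concretely (a sketch in the notation suggested by the implementation, where $s = \alpha/r$ is the LoRA scaling factor, $g^B$ is the raw gradient of $B$, and $g$ is the full fine-tuning gradient of the adapted weight), the optimizer replaces $g^B$ with

$$\tilde{g}^B = \frac{1}{s^2} \, g^B \left(A A^\top\right)^{-1},$$

which minimizes $\lVert s \, \tilde{g}^B A - g \rVert_F^2$, i.e. it picks the update to $B$ whose effect on the adapted weight is closest to a full fine-tuning step.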
```py
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer
from transformers import AutoModelForCausalLM, Trainer, get_cosine_schedule_with_warmup

base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

config = LoraConfig(...)
model = get_peft_model(base_model, config)

optimizer = create_lorafa_optimizer(
    model=model,
    r=128,
    lora_alpha=32,
    lr=7e-5,
)

scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000,
)

trainer = Trainer(
    ...,
    optimizers=(optimizer, scheduler),
)
```

### LoRA+ optimized LoRA
examples/lorafa_finetune/README.md (new file, 115 lines)
@@ -0,0 +1,115 @@

# LoRA-FA: Memory-efficient Low-rank Adaptation for Large Language Models Fine-tuning

## Introduction

[LoRA-FA](https://arxiv.org/abs/2308.03303) is a novel parameter-efficient fine-tuning method that freezes the projection-down layer (matrix A) during LoRA training, and thus consumes less GPU memory by eliminating the need to store the activations of the input tensors (X). Furthermore, LoRA-FA narrows the gap between the weight update produced by low-rank fine-tuning and the update produced by full fine-tuning. In short, LoRA-FA reduces memory consumption and achieves performance superior to vanilla LoRA.
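The intuition (a sketch of the paper's argument): computing the gradient of B only requires B's low-rank input A·X, not A's full-width input X, so the activation stored per adapted layer shrinks roughly from O(batch × seq_len × hidden_dim) to O(batch × seq_len × r), with the rank r much smaller than the hidden dimension.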
## Quick start

```python
import torch
from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

lora_rank = 16
lora_alpha = 32

lora_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    bias="none",
)
peft_model = get_peft_model(model, lora_config)
optimizer = create_lorafa_optimizer(
    model=peft_model,
    r=lora_rank,
    lora_alpha=lora_alpha,
    lr=7e-5,
)
# you can also use a scheduler; for better model performance we recommend
# get_cosine_schedule_with_warmup from transformers
scheduler = None

# NOTE: the raw text dataset needs to be tokenized and collated before it is
# passed to Trainer; see lorafa_finetuning.py in this directory for a complete example
trainer = Trainer(
    model=peft_model,
    train_dataset=dataset,
    tokenizer=tokenizer,
    optimizers=(optimizer, scheduler),
)
trainer.train()
peft_model.save_pretrained("lorafa-llama-3-8b-inst")
```

The only change in your code is to pass the LoRA-FA optimizer to the trainer (if training with a trainer). Do not forget `from peft.optimizers import create_lorafa_optimizer`!
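If you are not training with a trainer at all, the returned optimizer can be driven like any other `torch.optim.Optimizer` (it subclasses it). A minimal sketch, where `dataloader` is an assumed source of tokenized batches that include labels:

```python
# manual training loop (sketch): LoRA-FA uses the standard step/zero_grad cycle;
# `dataloader` is assumed to yield tokenized batches with labels
for batch in dataloader:
    loss = peft_model(**batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```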
## Example

In this directory, we also provide a simple example of fine-tuning with the LoRA-FA optimizer.

### Run on CPU, single-GPU or multi-GPU

By default, the command below 👇 loads the model, sets it up with PEFT using the LoRA config, and trains it with the LoRA-FA optimizer.

1. CPU

You can simply run LoRA-FA as below:

```bash
python lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --output_dir path/to/output --lorafa
```

2. Single-GPU

Run the fine-tuning script on 1 GPU:

```bash
CUDA_VISIBLE_DEVICES=0 python lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --output_dir path/to/output --lorafa
```

3. Multi-GPU

LoRA-FA can also be run on multiple GPUs with 🤗 Accelerate:

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 accelerate launch lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --output_dir path/to/output --lorafa
```

`accelerate launch` will automatically configure multi-GPU training for you. You can also use `accelerate launch` in a single-GPU scenario.
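For example (a sketch; `--num_processes` is a generic `accelerate launch` option rather than something specific to this script):

```bash
# run the same script through accelerate on a single GPU
accelerate launch --num_processes 1 lorafa_finetuning.py --base_model_name_or_path meta-llama/Meta-Llama-3-8B --dataset_name_or_path meta-math/MetaMathQA-40K --output_dir path/to/output --lorafa
```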
### Use the model from 🤗

You can load and use the model like any other 🤗 model.

```python
from peft import AutoPeftModelForCausalLM

# loads the base model together with the LoRA-FA adapter saved above
model = AutoPeftModelForCausalLM.from_pretrained("lorafa-llama-3-8b-inst")
```

## Best practices for fine-tuning Llama with LoRA-FA: the hyperparameters

Achieving optimal LoRA fine-tuning can be challenging because there are more hyperparameters to consider than in full fine-tuning. Not only does the commonly used learning rate need adjusting, but the ideal LoRA rank also varies with the specific model and task, and further factors such as LoRA alpha and sequence length matter as well. To assist with this, we have created a repository of reproducible best practices in the [LoRA-FA examples](https://github.com/AaronZLT/lorafa) for reference. It showcases the optimal LoRA-FA fine-tuning hyperparameters for different models across various datasets, which significantly reduces the time and effort spent on hyperparameter tuning and may also provide insights for tuning other training hyperparameters. We encourage you to experiment and fine-tune on your own downstream tasks as well.
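As one illustration (a sketch; the rank values and learning rate below are placeholders, not tuned recommendations from the linked repository), LoRA-FA's rank-insensitive memory footprint makes it cheap to sweep the rank at a fixed memory budget:

```python
# hypothetical rank sweep at a fixed memory budget; placeholder values only
for r in (16, 64, 128):
    # reload a fresh base model for each trial so adapters do not accumulate
    base = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    peft_model = get_peft_model(base, LoraConfig(r=r, lora_alpha=32, bias="none"))
    optimizer = create_lorafa_optimizer(model=peft_model, r=r, lora_alpha=32, lr=7e-5)
    # ... train with the Trainer set up as in the quick start, then compare eval metrics ...
```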
## LoRA-FA's advantages and limitations

By not storing the input activations of adapter A, LoRA-FA uses less memory for fine-tuning than LoRA. For instance, when fine-tuning Llama-2-7b-chat-hf with a batch size of 8 and a sequence length of 1024, LoRA-FA requires 36GB of memory to store activations, allowing it to run successfully on an 80GB GPU. In contrast, LoRA requires at least 60GB of memory for activations, leading to an out-of-memory (OOM) error. Additionally, the memory consumption of LoRA-FA is not sensitive to the rank, allowing for performance improvements by increasing the LoRA rank without additional memory usage. LoRA-FA further narrows the performance gap with full fine-tuning by minimizing the discrepancy between the low-rank gradient and the full gradient, enabling it to achieve performance on par with, or even superior to, vanilla LoRA.

Despite these advantages, LoRA-FA is inherently limited by its low-rank approximation and by potential catastrophic forgetting. The gradient approximation can also impact training throughput, since every optimizer step inverts an r × r matrix (A Aᵀ) for each adapted weight. Addressing these limitations, especially the approximation accuracy and the forgetting phenomena, is a promising direction for future research.

## Citation

```
@misc{zhang2023lorafamemoryefficientlowrankadaptation,
      title={LoRA-FA: Memory-efficient Low-rank Adaptation for Large Language Models Fine-tuning},
      author={Longteng Zhang and Lin Zhang and Shaohuai Shi and Xiaowen Chu and Bo Li},
      year={2023},
      eprint={2308.03303},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2308.03303},
}
```
examples/lorafa_finetune/lorafa_finetuning.py (new file, 214 lines)
@@ -0,0 +1,214 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.optimizers import create_lorafa_optimizer


def train_model(
    base_model_name_or_path: str,
    dataset_name_or_path: str,
    output_dir: str,
    batch_size: int,
    num_epochs: int,
    lr: float,
    cutoff_len: int,
    quantize: bool,
    eval_step: int,
    save_step: int,
    lora_rank: int,
    lora_alpha: int,
    lora_dropout: float,
    lora_target_modules: Optional[str],
    lorafa: bool,
):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    compute_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
    device_map = "cuda" if torch.cuda.is_available() else None

    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path)

    # load model
    if quantize:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type="nf4",
            ),
            torch_dtype=compute_dtype,
            device_map=device_map,
        )
        # setup for quantized training
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name_or_path, torch_dtype=compute_dtype, device_map=device_map
        )

    # LoRA config for the PEFT model
    if lora_target_modules is not None:
        if lora_target_modules == "all-linear":
            target_modules = "all-linear"
        else:
            target_modules = lora_target_modules.split(",")
    else:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
    )

    # get the peft model with LoRA config
    model = get_peft_model(model, lora_config)

    tokenizer.pad_token = tokenizer.eos_token

    # Load the dataset
    dataset = load_dataset(dataset_name_or_path)

    def tokenize_function(examples):
        inputs = tokenizer(examples["query"], padding="max_length", truncation=True, max_length=cutoff_len)
        outputs = tokenizer(examples["response"], padding="max_length", truncation=True, max_length=cutoff_len)
        inputs["labels"] = outputs["input_ids"].copy()
        return inputs

    # Tokenize the dataset and prepare for training
    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
    dataset = tokenized_datasets["train"].train_test_split(test_size=0.1, shuffle=True, seed=42)
    train_dataset = dataset["train"]
    eval_dataset = dataset["test"]

    # Data collator to dynamically pad the batched examples
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # Define training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=eval_step,
        save_steps=save_step,
        save_total_limit=2,
        gradient_accumulation_steps=1,
        bf16=compute_dtype == torch.bfloat16,
        fp16=compute_dtype == torch.float16,
        learning_rate=lr,
    )

    # Here we initialize the LoRA-FA optimizer.
    # After this, all lora_A adapters are fixed and only the lora_B adapters are trainable.
    if lorafa:
        optimizer = create_lorafa_optimizer(
            model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr, weight_decay=training_args.weight_decay
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
            optimizers=(optimizer, None),
        )
    else:
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=data_collator,
        )

    # Start model training
    trainer.train()

    # Save the model and tokenizer locally
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune Meta-Llama-3-8B-Instruct with LoRA-FA and PEFT")
    parser.add_argument(
        "--base_model_name_or_path",
        type=str,
        default="meta-llama/Meta-Llama-3-8B-Instruct",
        help="Base model name or path",
    )
    parser.add_argument(
        "--dataset_name_or_path", type=str, default="meta-math/MetaMathQA-40K", help="Dataset name or path"
    )
    parser.add_argument("--output_dir", type=str, help="Output directory for the fine-tuned model")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=7e-5, help="Learning rate")
    parser.add_argument("--cutoff_len", type=int, default=1024, help="Cutoff length for tokenization")
    parser.add_argument("--quantize", action="store_true", help="Use quantization")
    parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
    parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
    parser.add_argument("--lora_rank", type=int, default=16, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout rate")
    parser.add_argument(
        "--lora_target_modules", type=str, default=None, help="Comma-separated list of target modules for LoRA"
    )
    parser.add_argument("--lorafa", action="store_true", help="Use LoRA-FA Optimizer")

    args = parser.parse_args()

    train_model(
        base_model_name_or_path=args.base_model_name_or_path,
        dataset_name_or_path=args.dataset_name_or_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        lr=args.lr,
        cutoff_len=args.cutoff_len,
        quantize=args.quantize,
        eval_step=args.eval_step,
        save_step=args.save_step,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        lorafa=args.lorafa,
    )
src/peft/optimizers/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2024-present the HuggingFace Inc. team.
+# Copyright 2025-present the HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .lorafa import create_lorafa_optimizer
 from .loraplus import create_loraplus_optimizer
 
 
-__all__ = ["create_loraplus_optimizer"]
+__all__ = ["create_lorafa_optimizer", "create_loraplus_optimizer"]
src/peft/optimizers/lorafa.py (new file, 253 lines)
@@ -0,0 +1,253 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains the implementation of the LoRA-FA optimizer.
"""

from __future__ import annotations

import math
from collections.abc import Iterable
from typing import Callable

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from torch.optim import Optimizer

from ..peft_model import PeftModel


class LoraFAOptimizer(Optimizer):
    """
    Implements the LoRA-FA optimizer, designed specifically for training Low-Rank Adaptation (LoRA) parameters
    efficiently. Note that LoraFAOptimizer is based on adamw-hf in transformers, with only the LoRA part modified.
    Without LoRA it falls back to adamw-hf.

    Args:
        params (Iterable[nn.parameter.Parameter]): Parameters to optimize.
        lr (float, optional): Learning rate (default: 1e-3).
        betas (Tuple[float, float], optional):
            Coefficients for computing running averages of the gradient and squared gradient (default: (0.9, 0.999)).
        eps (float, optional): Term added to the denominator to improve numerical stability (default: 1e-6).
        weight_decay (float, optional): Weight decay (L2 penalty) (default: 0.0).
        correct_bias (bool, optional): Whether to apply bias correction as in the original Adam (default: True).

    Args accepted by the `step` sub-function:
        closure (Callable, optional): A closure that reevaluates the model and returns the loss.

    Reference:
        - LoRA-FA: https://arxiv.org/abs/2308.03303
    """

    def __init__(
        self,
        params: Iterable[nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0")
        defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
        }
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """
        Performs a single optimization step.

        Arguments:
            closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            scaling_factor = group["scaling_factor"]
            param_list = []
            name_list = []
            for p, n in zip(group["params"], group["names"]):
                # Skip non-LoRA parameters without gradients; LoRA parameters pass through
                # because we also need the frozen (gradient-less) lora_A.
                if "lora" not in n and p.grad is None:
                    continue
                grad = p.grad

                if "lora" in n:
                    param_list.append(p)
                    name_list.append(n)
                    if len(param_list) == 2:
                        name = n[: n.find("lora")] + "lora"
                    elif len(param_list) == 1:
                        continue
                else:
                    name = n
                # At this point param_list contains a pair of A and B adapters,
                # i.e., param_list -> [A, B]

                state = self.state[name]
                # State initialization
                if len(state) == 0:
                    if len(param_list) == 2:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg_B"] = torch.zeros_like(param_list[1])
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq_B"] = torch.zeros_like(param_list[1])
                    else:
                        state["step"] = 0
                        # Exponential moving average of gradient values
                        state["exp_avg"] = torch.zeros_like(p)
                        # Exponential moving average of squared gradient values
                        state["exp_avg_sq"] = torch.zeros_like(p)

                # Below is the LoRA-FA part
                # 1. Correct the gradient of B as
                #        g^B = (r / alpha)^2 * g_{LoRA}^B (A A^T)^{-1}
                #    (A has shape (r, in_features), so A A^T is r x r). This is the minimizer of
                #        min_{g^B} || \hat{g}_{LoRA-FA} - g ||_F^2,
                #    the distance between the weight update implied by the low-rank step and the
                #    full fine-tuning gradient g.
                # 2. After the corrected gradient of B is ready, update the optimizer state.
                if len(param_list) == 2:
                    A = param_list[0]
                    B = param_list[1]
                    grad_B_orin = B.grad

                    # projection
                    delta = 1e-8

                    # computing the inverse matrix
                    AA_T = A @ A.T
                    AA_T_inv = torch.linalg.pinv(AA_T + delta * torch.eye(A.shape[0]).to(A.device))

                    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                        with autocast(dtype=torch.bfloat16):
                            grad_B = (1 / scaling_factor**2) * (grad_B_orin @ AA_T_inv)
                    else:
                        grad_B = (1 / scaling_factor**2) * (grad_B_orin @ AA_T_inv)

                    if grad_B.dtype != B.grad.dtype:
                        grad_B = grad_B.to(B.grad.dtype)

                    exp_avg_B, exp_avg_sq_B = state["exp_avg_B"], state["exp_avg_sq_B"]
                    beta1, beta2 = group["betas"]
                    state["step"] += 1
                    exp_avg_B.mul_(beta1).add_(grad_B, alpha=(1.0 - beta1))
                    exp_avg_sq_B.mul_(beta2).addcmul_(grad_B, grad_B, value=1.0 - beta2)

                    denom_B = exp_avg_sq_B.sqrt().add_(group["eps"])
                    step_size = group["lr"]
                    if group["correct_bias"]:  # No bias correction for Bert
                        bias_correction1 = 1.0 - beta1 ** state["step"]
                        bias_correction2 = 1.0 - beta2 ** state["step"]
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
                    B.addcdiv_(exp_avg_B, denom_B, value=-step_size)
                    if group["weight_decay"] > 0.0:
                        B.add_(B, alpha=(-group["lr"] * group["weight_decay"]))
                    param_list = []
                    name_list = []

                # Below is the original AdamW
                else:
                    exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                    beta1, beta2 = group["betas"]

                    state["step"] += 1

                    # Decay the first and second moment running average coefficient
                    # In-place operations to update the averages at the same time
                    exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
                    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                    denom = exp_avg_sq.sqrt().add_(group["eps"])

                    step_size = group["lr"]
                    if group["correct_bias"]:  # No bias correction for Bert
                        bias_correction1 = 1.0 - beta1 ** state["step"]
                        bias_correction2 = 1.0 - beta2 ** state["step"]
                        step_size = step_size * math.sqrt(bias_correction2) / bias_correction1

                    p.addcdiv_(exp_avg, denom, value=-step_size)

                    # Just adding the square of the weights to the loss function is *not*
                    # the correct way of using L2 regularization/weight decay with Adam,
                    # since that will interact with the m and v parameters in strange ways.
                    #
                    # Instead we want to decay the weights in a manner that doesn't interact
                    # with the m/v parameters. This is equivalent to adding the square
                    # of the weights to the loss with plain (non-momentum) SGD.
                    # Add weight decay at the end (fixed version)
                    if group["weight_decay"] > 0.0:
                        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))

        return loss


def create_lorafa_optimizer(
    model: PeftModel, r: int, lora_alpha: int, lr: float, weight_decay: float = 0.0, use_rslora: bool = False
) -> Optimizer:
    """
    Helper function to instantiate a LoRA-FA optimizer specifically configured for a given model using the LoRA
    method.

    This function will:
    - Disable gradient updates for the "lora_A" parameters (these are frozen during LoRA-FA training).
    - Compute the scaling factor based on the provided `lora_alpha` and rank `r` for proper gradient projection.
    - Create and configure parameter groups for the optimizer, including the specified learning rate, weight decay,
      and additional optimizer options.

    For hyperparameters, LoRA-FA uses the same hyperparameters as AdamW, except for the LoRA hyperparameters (r,
    lora_alpha, use_rslora). One can use the same hyperparameters, such as lr and weight_decay, as one would use for
    AdamW in LoRA tuning.

    Args:
        model (PeftModel): The model containing LoRA-adapted parameters.
        r (int): Rank of the LoRA decomposition.
        lora_alpha (int): Scaling factor for the LoRA parameterization.
        lr (float): Learning rate for optimizer updates.
        weight_decay (float): Weight decay for AdamW.
        use_rslora (bool):
            Whether to use rsLoRA. With rsLoRA, the LoRA scaling factor becomes lora_alpha / math.sqrt(r) instead of
            lora_alpha / r.

    Returns:
        Optimizer: Configured LoRA-FA optimizer instance ready for training.
    """
    for name, param in model.named_parameters():
        if "lora_A" in name:
            param.requires_grad_(False)
    lora_scaling = lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r
    param_groups = [
        {
            "params": model.parameters(),
            "lr": lr,
            "names": [name for name, _ in model.named_parameters()],
            "scaling_factor": lora_scaling,
            "betas": (0.9, 0.999),
            "weight_decay": weight_decay,
        }
    ]
    return LoraFAOptimizer(param_groups)
tests/test_lorafa.py (new file, 152 lines)
@@ -0,0 +1,152 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

import torch
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer

from .testing_utils import torch_device


class SimpleNet(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.embedding = nn.Embedding(100, 20)
        self.layer_norm = nn.LayerNorm(20)
        self.lin0 = nn.Linear(20, 20, bias=bias)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 16, bias=bias)

    def forward(self, X):
        X = self.lin0(self.layer_norm(self.embedding(X)))
        X = self.relu(X)
        X = self.lin1(X)
        return X


def test_lorafa_init_default():
    """
    Test if the optimizer is correctly created
    """
    lora_rank = 16
    lora_alpha = 32
    lr = 7e-5

    model = SimpleNet()
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["lin0", "lin1"],
        bias="none",
    )
    model = get_peft_model(model, config)
    optimizer = create_lorafa_optimizer(model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr)

    assert math.isclose(optimizer.param_groups[0]["scaling_factor"], lora_alpha / lora_rank, rel_tol=1e-9, abs_tol=0.0)

    all_A_fixed = True
    all_B_trainable = True

    assert optimizer is not None

    for name, param in model.named_parameters():
        if "lora_A" in name:
            all_A_fixed &= not param.requires_grad
        elif "lora_B" in name:
            all_B_trainable &= param.requires_grad

    assert all_A_fixed and all_B_trainable


def test_lorafa_init_rslora():
    """
    Test if the optimizer is correctly created when use_rslora = True
    """
    lora_rank = 16
    lora_alpha = 32
    lr = 7e-5

    model = SimpleNet()
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["lin0", "lin1"],
        bias="none",
    )
    model = get_peft_model(model, config)
    optimizer = create_lorafa_optimizer(model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr, use_rslora=True)
    assert math.isclose(
        optimizer.param_groups[0]["scaling_factor"], lora_alpha / math.sqrt(lora_rank), rel_tol=1e-9, abs_tol=0.0
    )


def test_LoraFAOptimizer_step():
    """
    Test if the optimizer's step function runs without any exception and checks specific conditions on lora_A and
    lora_B weights.
    """
    lora_rank = 16
    lora_alpha = 32
    lr = 7e-5
    num_steps = 5

    model = SimpleNet()
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["lin0", "lin1"],
        bias="none",
    )
    model = get_peft_model(model, config).to(torch_device)
    optimizer = create_lorafa_optimizer(model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr)
    loss = torch.nn.CrossEntropyLoss()

    # Save initial weights of lora_A
    initial_lora_A_weights = {name: param.clone() for name, param in model.named_parameters() if "lora_A" in name}
    # Ensure lora_B is initialized to zero
    for name, param in model.named_parameters():
        if "lora_B" in name:
            assert torch.all(param == 0), f"lora_B weights not initialized to zero for {name}"

    for _ in range(num_steps):  # Run the optimizer step multiple times
        # Generate random input and label for each step
        x = torch.randint(100, (2, 4, 10)).to(torch_device)
        output = model(x).permute(0, 3, 1, 2)
        label = torch.randint(16, (2, 4, 10)).to(torch_device)

        # Calculate loss and perform backward pass
        loss_value = loss(output, label)
        loss_value.backward()

        # Perform optimizer step
        optimizer.step()

        # Zero the gradients after each step to prevent accumulation
        optimizer.zero_grad()

    # Check that lora_A weights have not changed
    for name, param in model.named_parameters():
        if "lora_A" in name:
            assert torch.equal(param, initial_lora_A_weights[name]), f"lora_A weights changed for {name}"

    # Check that lora_B weights are now non-zero
    for name, param in model.named_parameters():
        if "lora_B" in name:
            assert torch.any(param != 0), f"lora_B weights are still zero for {name}"