peft/tests/test_lorafa.py
FEAT Add LoRA-FA to PEFT (#2468)
Adds LoRA with frozen A (LoRA-FA) to PEFT.

Paper: https://arxiv.org/abs/2308.03303
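
For context, here is a minimal, self-contained sketch of the usage pattern these tests exercise: wrap a model with a LoraConfig, build the optimizer with create_lorafa_optimizer, then train as usual; lora_A stays frozen and only lora_B is updated. The TinyMLP model and all hyperparameters below are illustrative placeholders, not part of PEFT or of this test file.

# Minimal LoRA-FA usage sketch (illustrative model and hyperparameters).
import torch
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer


class TinyMLP(nn.Module):  # illustrative stand-in model
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(20, 20)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 16)

    def forward(self, x):
        return self.lin1(self.relu(self.lin0(x)))


base_model = TinyMLP()
config = LoraConfig(r=16, lora_alpha=32, target_modules=["lin0", "lin1"], bias="none")
model = get_peft_model(base_model, config)

# The LoRA-FA optimizer keeps lora_A frozen and updates only lora_B.
optimizer = create_lorafa_optimizer(model=model, r=16, lora_alpha=32, lr=7e-5)

x = torch.randn(8, 20)
loss = model(x).pow(2).mean()  # dummy loss, just to drive one update
loss.backward()
optimizer.step()
optimizer.zero_grad()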

# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import math

import torch
from torch import nn

from peft import LoraConfig, get_peft_model
from peft.optimizers import create_lorafa_optimizer

from .testing_utils import torch_device
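
# LoRA-FA (https://arxiv.org/abs/2308.03303) keeps lora_A frozen and trains only lora_B.
# The tests below check that the optimizer reports the expected LoRA scaling in
# param_groups[0]["scaling_factor"] (lora_alpha / r by default, lora_alpha / sqrt(r) with
# use_rslora=True) and that lora_A never changes while lora_B is updated during training.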


class SimpleNet(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.embedding = nn.Embedding(100, 20)
        self.layer_norm = nn.LayerNorm(20)
        self.lin0 = nn.Linear(20, 20, bias=bias)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 16, bias=bias)

    def forward(self, X):
        X = self.lin0(self.layer_norm(self.embedding(X)))
        X = self.relu(X)
        X = self.lin1(X)
        return X


def test_lorafa_init_default():
    """
    Test that the LoRA-FA optimizer is created correctly with the default LoRA scaling.
    """
    lora_rank = 16
    lora_alpha = 32
    lr = 7e-5
    model = SimpleNet()
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["lin0", "lin1"],
        bias="none",
    )
    model = get_peft_model(model, config)

    optimizer = create_lorafa_optimizer(model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr)
    assert optimizer is not None
    # Default LoRA scaling is lora_alpha / r.
    assert math.isclose(optimizer.param_groups[0]["scaling_factor"], lora_alpha / lora_rank, rel_tol=1e-9, abs_tol=0.0)

    # LoRA-FA keeps every lora_A frozen and leaves every lora_B trainable.
    all_A_fixed = True
    all_B_trainable = True
    for name, param in model.named_parameters():
        if "lora_A" in name:
            all_A_fixed &= not param.requires_grad
        elif "lora_B" in name:
            all_B_trainable &= param.requires_grad
    assert all_A_fixed and all_B_trainable


def test_lorafa_init_rslora():
    """
    Test that the optimizer is created correctly when use_rslora=True.
    """
    lora_rank = 16
    lora_alpha = 32
    lr = 7e-5
    model = SimpleNet()
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["lin0", "lin1"],
        bias="none",
    )
    model = get_peft_model(model, config)

    optimizer = create_lorafa_optimizer(model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr, use_rslora=True)
    # With rsLoRA, the scaling is lora_alpha / sqrt(r) instead of lora_alpha / r.
    assert math.isclose(
        optimizer.param_groups[0]["scaling_factor"], lora_alpha / math.sqrt(lora_rank), rel_tol=1e-9, abs_tol=0.0
    )


def test_LoraFAOptimizer_step():
    """
    Test that the optimizer's step function runs without raising and that lora_A stays frozen while lora_B is
    updated away from its zero initialization.
    """
    lora_rank = 16
    lora_alpha = 32
    lr = 7e-5
    num_steps = 5
    model = SimpleNet()
    config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["lin0", "lin1"],
        bias="none",
    )
    model = get_peft_model(model, config).to(torch_device)
    optimizer = create_lorafa_optimizer(model=model, r=lora_rank, lora_alpha=lora_alpha, lr=lr)
    loss = torch.nn.CrossEntropyLoss()

    # Save initial weights of lora_A
    initial_lora_A_weights = {name: param.clone() for name, param in model.named_parameters() if "lora_A" in name}

    # Ensure lora_B is initialized to zero
    for name, param in model.named_parameters():
        if "lora_B" in name:
            assert torch.all(param == 0), f"lora_B weights not initialized to zero for {name}"

    for _ in range(num_steps):  # Run the optimizer step multiple times
        # Generate random input and label for each step
        x = torch.randint(100, (2, 4, 10)).to(torch_device)
        output = model(x).permute(0, 3, 1, 2)
        label = torch.randint(16, (2, 4, 10)).to(torch_device)

        # Calculate loss and perform backward pass
        loss_value = loss(output, label)
        loss_value.backward()

        # Perform optimizer step
        optimizer.step()

        # Zero the gradients after each step to prevent accumulation
        optimizer.zero_grad()

    # Check that lora_A weights have not changed
    for name, param in model.named_parameters():
        if "lora_A" in name:
            assert torch.equal(param, initial_lora_A_weights[name]), f"lora_A weights changed for {name}"

    # Check that lora_B weights are now non-zero
    for name, param in model.named_parameters():
        if "lora_B" in name:
            assert torch.any(param != 0), f"lora_B weights are still zero for {name}"