FEAT Add DeLoRA (#2780)

Implements DeLoRA: "Decoupling Angles and Strength in Low-rank
Adaptation" (https://huggingface.co/papers/2503.18225).

Similar to DoRA, DeLoRA decouples the angular learning from the
adaptation strength, but it also allows to limit the norm of the change.
This way, DeLoRA promises to reduce the risk of catastrophic forgetting
and to be more robust to hyper-parameter settings such as the learning
rate.
This commit is contained in:
Massimo Bini
2025-10-17 15:24:46 +01:00
committed by GitHub
parent 8d8aa0b716
commit 2813b9c4bf
23 changed files with 1059 additions and 2 deletions

View File

@ -136,6 +136,8 @@
title: RoAd title: RoAd
- local: package_reference/waveft - local: package_reference/waveft
title: WaveFT title: WaveFT
- local: package_reference/delora
title: DeLoRA
title: Adapters title: Adapters
- sections: - sections:

View File

@ -0,0 +1,35 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# DeLoRA: Decoupled Low-rank Adaptation
[DeLoRA](https://huggingface.co/papers/2503.18225) is a parameter-efficient fine-tuning technique that implicitly maintains a Frobenius boundary with respect to the pretrained weights by normalizing and scaling learnable low-rank matrices. This effectively decouples the learning of directions (BA term) and magnitude (boundary term) of the weight updates, avoiding catastrophic shifts in the adapted weights and enhancing robustness to hyperparameter choices.
Note:
- use 10-100x larger learning rate than standard LoRA variants (typical values from 1e-3/1e-2/..)
- do not set a too small initial boundary parameter lambda (typical values are around 10/15/..)
- setting different lambdas to different layers is possible
The abstract from the paper is:
> Parameter-Efficient FineTuning (PEFT) methods have recently gained significant popularity thanks to the widespread availability of large-scale pretrained models. These methods allow for quick adaptation to downstream tasks with minimal computational cost. However, popular finetuning methods such as LoRA exhibit limited robustness when it comes to hyperparameter choices or extended training regimes, preventing optimal out-of-the-box performance. In contrast, bounded approaches, such as ETHER, provide greater robustness but are limited to extremely low-rank adaptations and fixed-strength transformations, reducing their adaptation expressive power. In this work, we propose Decoupled Low-rank Adaptation (DeLoRA), a novel finetuning method that normalizes and scales learnable low-rank matrices. By bounding the distance of the transformation, DeLoRA effectively decouples the angular learning from the adaptation strength, enhancing robustness without compromising performance. Through evaluations on subject-driven image generation, natural language understanding, and instruction tuning, we show that DeLoRA matches or surpasses performance of competing PEFT methods, while exhibiting stronger robustness.
## DeloraConfig
[[autodoc]] tuners.delora.config.DeloraConfig
## DeloraModel
[[autodoc]] tuners.delora.model.DeloraModel

View File

@ -0,0 +1,102 @@
# DeLoRA: Decoupled Low-Rank Adaptation
## Introduction
[DeLoRA](https://huggingface.co/papers/2503.18225) tackles finetuning in a Frobenius-norm bounded setup: this allows to prevent divergence from the pretrained model, effectively decoupling the learning of angles and magnitudes.
This is done by (i) normalization of the BA low-rank matrices, which bound the updates' Frobenius norm, (ii) learnable scaling lambda, which controls the update's boundary/magnitude, (iii) layer-wise scaling of ||W||, to adapt each update's norm to the original weights' norm.
## Quick start
With respect to your standard PEFT training procedure with LoRA, simply swap your `LoraConfig` for a `DeloraConfig`. Note however that `lora_alpha` parameter is replaced by `delora_lambda` parameter which sets an upper bound to the Frobenius norm of the weight change.
```python
import torch
from peft import DeloraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer.pad_token_id = tokenizer.eos_token_id
delora_config = DeloraConfig(r=32, delora_lambda=15)
peft_model = get_peft_model(model, delora_config)
peft_model.print_trainable_parameters()
dataset = load_dataset("imdb", split="train[:1%]")
training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
model=peft_model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer,
)
trainer.train()
peft_model.save_pretrained("delora-llama-3-8b")
```
To utilize the fine-tuned DeLoRA modules, simply run the following command:
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B", dtype=torch.bfloat16, device_map="auto"
)
peft_model = PeftModel.from_pretrained(model, "delora-llama-3-8b")
```
## Advanced Usage
In this script the default DeLoRA layers are the query and value layers of the Llama model. Adding adapters on more layers will increase memory usage. If you wish to choose a different set of layers for DeLoRA to be applied on, you can simply define it using:
```bash
python examples/delora_finetuning/delora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --delora_target_modules "q_proj,k_proj,v_proj,o_proj"
```
Using different lambdas for different layers is also possible by setting `lambda_pattern`.
### Fine-tune
```bash
python delora_finetuning.py \
--base_model "PATH_TO_MODEL" \
--data_path "PATH_TO_DATASET" \
--output_dir "PATH_TO_OUTPUT_DIR" \
--batch_size 1 \
--num_epochs 3 \
--learning_rate 3e-3 \
--cutoff_len 512 \
--val_set_size 500 \
--eval_step 10 \
--save_step 100 \
--device "auto" \
--rank 32 \
--delora_lambda 15 \
--module_dropout 0.1 \
--delora_target_modules "q_proj,v_proj" \
--hub_model_id "YOUR_HF_REPO" \
--push_to_hub
```
## Additional Notes
### Best practices
- use 10-100x larger learning rate than standard LoRA variants (typical values from 1e-3/1e-2/..)
- do not set a too small initial boundary parameter lambda (typical values are around 10/15/..)
### DeLoRA vs DoRA
DeLoRA might feel quite similar to DoRA (given the similar target of decoupling angular from magnitude learning), however it presents key differences: (i) DoRA applies normalization and scaling operations on the fully finetuned weights ($W + \Delta W$), (ii) DoRA's normalization operation is performed on the column space of the weight matrices.
Conversely DeLoRA (i) introduces the normalization and scaling operations directly on the weight updates $\Delta W$, better preventing divergence from the pretrained model, and (ii) normalizes the inner low-dimensional space, which enforces a Frobenius-norm boundary to the weight updates.
## Citation
```
@inproceedings{bini2025decouplinganglesstrengthlowrank,
title={Decoupling Angles and Strength in Low-rank Adaptation},
author={Massimo Bini and Leander Girrbach and Zeynep Akata},
year={2025},
booktitle={International Conference on Learning Representations (ICLR)},
}
```

View File

@ -0,0 +1,189 @@
# This script is based on examples/randlora_finetuning/randlora_finetuning.py
import os
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from peft import DeloraConfig, get_peft_model
def train_model(
base_model: str,
data_path: str,
output_dir: str,
batch_size: int,
num_epochs: int,
learning_rate: float,
cutoff_len: int,
val_set_size: int,
eval_step: int,
save_step: int,
device: str,
rank: int,
delora_lambda: int,
module_dropout: float,
target_modules: str,
hub_model_id: str,
push_to_hub: bool,
):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
hf_token = os.getenv("HF_TOKEN")
# Setup device
device = torch.device(device)
print(f"Using device: {device}")
# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
# Compute type
device_type = device.type
device_module = getattr(torch, device_type, torch.cuda)
bf16_supported = device_module.is_available() and device_module.is_bf16_supported()
dtype = torch.bfloat16 if bf16_supported else torch.float32
# Load the base model
model = AutoModelForCausalLM.from_pretrained(
base_model,
dtype=dtype,
)
# DeLoRA config for the PEFT model
peft_config = DeloraConfig(
r=rank,
delora_lambda=delora_lambda,
target_modules=(target_modules.split(",") if target_modules else None),
module_dropout=module_dropout,
bias="none",
)
# get the peft model with DeLoRA config
model = get_peft_model(model, peft_config)
model.to(device) # MODEL TO ACCELERATOR
tokenizer.pad_token = tokenizer.eos_token
# Load the dataset
dataset = load_dataset(data_path)
def tokenize_function(examples):
inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len)
inputs["labels"] = inputs["input_ids"].copy() # setting labels for a language modeling task
return inputs
# Tokenize the dataset and prepare for training
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
# Data collator to dynamically pad the batched examples
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# Compute the total amount of training step for warmup
max_steps = int((len(dataset) // batch_size) * num_epochs)
# Define training arguments
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
warmup_steps=int(max_steps * 0.1), # 10% of total trainig steps
weight_decay=0.0,
logging_steps=eval_step,
save_steps=save_step,
save_total_limit=2,
push_to_hub=push_to_hub,
hub_model_id=hub_model_id,
gradient_accumulation_steps=16,
learning_rate=learning_rate,
hub_token=hf_token,
label_names=["labels"],
)
# Clear accelerator cache to free memory
device_module.empty_cache()
# Initialize the Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["test"],
data_collator=data_collator,
)
# Start model training
trainer.train()
# Save and push the trained model and tokenizer
if push_to_hub:
# Push the main model to the hub
trainer.push_to_hub(commit_message="Fine-tuned model")
# Save the model and tokenizer locally
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Fine-tune LLaMA with DeLoRA")
parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name")
parser.add_argument(
"--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name"
)
parser.add_argument(
"--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--learning_rate", type=float, default=3e-3, help="Learning rate")
parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size")
parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
parser.add_argument("--device", type=str, default="auto", help="Device to use for training")
parser.add_argument("--rank", type=int, default=32, help="DeLoRA basis rank")
parser.add_argument("--delora_lambda", type=int, default=640, help="DeLoRA alpha")
parser.add_argument("--module_dropout", type=float, default=0.05, help="DeLoRA dropout rate")
parser.add_argument(
"--target_modules", type=str, default=None, help="Comma-separated list of target modules for DeLoRA"
)
parser.add_argument(
"--hub_model_id",
type=str,
default="path/to/repo",
help="Repository name to push the model on the Hugging Face Hub",
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub")
args = parser.parse_args()
if args.device == "auto":
args.device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
train_model(
base_model=args.base_model,
data_path=args.data_path,
output_dir=args.output_dir,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
learning_rate=args.learning_rate,
cutoff_len=args.cutoff_len,
val_set_size=args.val_set_size,
eval_step=args.eval_step,
save_step=args.save_step,
device=args.device,
rank=args.rank,
delora_lambda=args.delora_lambda,
module_dropout=args.module_dropout,
target_modules=args.target_modules,
hub_model_id=args.hub_model_id,
push_to_hub=args.push_to_hub,
)

View File

@ -0,0 +1,20 @@
{
"lambda_pattern": {},
"auto_mapping": null,
"base_model_name_or_path": null,
"bias": "none",
"exclude_modules": null,
"inference_mode": false,
"init_weights": true,
"layers_pattern": null,
"layers_to_transform": null,
"delora_lambda": 15,
"module_dropout": 0.0,
"modules_to_save": null,
"peft_type": "DELORA",
"r": 32,
"rank_pattern": {},
"revision": null,
"target_modules": null,
"task_type": "CAUSAL_LM"
}

View File

@ -0,0 +1,6 @@
{
"optimizer_kwargs": {
"lr": 1e-3
}
}

View File

@ -59,6 +59,8 @@ from .tuners import (
C3AModel, C3AModel,
CPTConfig, CPTConfig,
CPTEmbedding, CPTEmbedding,
DeloraConfig,
DeloraModel,
EvaConfig, EvaConfig,
FourierFTConfig, FourierFTConfig,
FourierFTModel, FourierFTModel,
@ -154,6 +156,8 @@ __all__ = [
"C3AModel", "C3AModel",
"CPTConfig", "CPTConfig",
"CPTEmbedding", "CPTEmbedding",
"DeloraConfig",
"DeloraModel",
"EvaConfig", "EvaConfig",
"FourierFTConfig", "FourierFTConfig",
"FourierFTModel", "FourierFTModel",

View File

@ -18,6 +18,7 @@ from .boft import BOFTConfig, BOFTModel
from .bone import BoneConfig, BoneModel from .bone import BoneConfig, BoneModel
from .c3a import C3AConfig, C3AModel from .c3a import C3AConfig, C3AModel
from .cpt import CPTConfig, CPTEmbedding from .cpt import CPTConfig, CPTEmbedding
from .delora import DeloraConfig, DeloraModel
from .fourierft import FourierFTConfig, FourierFTModel from .fourierft import FourierFTConfig, FourierFTModel
from .hra import HRAConfig, HRAModel from .hra import HRAConfig, HRAModel
from .ia3 import IA3Config, IA3Model from .ia3 import IA3Config, IA3Model
@ -67,6 +68,8 @@ __all__ = [
"C3AModel", "C3AModel",
"CPTConfig", "CPTConfig",
"CPTEmbedding", "CPTEmbedding",
"DeloraConfig",
"DeloraModel",
"EvaConfig", "EvaConfig",
"FourierFTConfig", "FourierFTConfig",
"FourierFTModel", "FourierFTModel",

View File

@ -0,0 +1,23 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from peft.utils import register_peft_method
from .config import DeloraConfig
from .layer import DeloraLayer, DeloraLinear
from .model import DeloraModel
__all__ = ["DeloraConfig", "DeloraLayer", "DeloraLinear", "DeloraModel"]
register_peft_method(name="delora", model_cls=DeloraModel, config_cls=DeloraConfig)

View File

@ -0,0 +1,154 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Union
from peft.config import PeftConfig
from peft.utils import PeftType
@dataclass
class DeloraConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a [`DeloraModel`].
Args:
r (`int`):
The rank of the DeLoRA adapter.
delora_lambda (`int`):
The initial value of the boundary of the DeLoRA adapter. This variable sets an upper bound to the Frobenius
norm of the weight change, avoiding the finetuned model to deviate too much from the original model.
module_dropout (`float`):
The dropout probability for disabling DeLoRA modules during training.
target_modules (`Optional[Union[List[str], str]]`):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
strings, either an exact match will be performed or it is checked if the name of the module ends with any
of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen,
excluding the output layer. If this is not specified, modules will be chosen according to the model
architecture. If the architecture is not known, an error will be raised -- in this case, you should specify
the target modules manually.
exclude_modules (`Optional[Union[List[str], str]]`):
The names of the modules to not apply the adapter. When passing a string, a regex match will be performed.
When passing a list of strings, either an exact match will be performed or it is checked if the name of the
module ends with any of the passed strings.
bias (`str`):
Bias type for DeLoRA. Can be 'none', 'all' or 'delora_only'. If 'all' or 'delora_only', the corresponding
biases will be updated during training. Be aware that this means that, even when disabling the adapters,
the model will not produce the same output as the base model would have without adaptation.
init_weights (`bool`):
Whether to perform initialization of adapter weights. If `True` (default): A is initialized with kaiming
uniform initialization, while B is initialized with zeros. If `False`: A and B are both initialized with
kaiming uniform, immediately contributing a non-zero delta. This is generally discouraged for normal use.
layers_to_transform (`Union[List[int], int]`):
The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
that are specified in this list. If a single integer is passed, it will apply the transformations on the
layer at this index.
layers_pattern (`Optional[Union[List[str], str]]`):
The layer pattern name, used only if `layers_to_transform` is different from `None`. This should target the
`nn.ModuleList` of the model, which is often called `'layers'` or `'h'`.
rank_pattern (`dict`):
The mapping from layer names or regexp expression to ranks which are different from the default rank
specified by `r`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`.
lambda_pattern (`dict`):
The mapping from layer names or regexp expression to lambdas which are different from the default lambda
specified by `delora_lambda`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`.
modules_to_save (`Optional[List[str]]`):
List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
"""
r: int = field(default=8, metadata={"help": "DeLoRA rank"})
delora_lambda: int = field(
default=15,
metadata={
"help": "The initial value of the boundary of the DeLoRA adapter. This variable sets an upper bound to the "
"Frobenius norm of the weight change, avoiding the finetuned model to deviate too much from the original model."
},
)
module_dropout: float = field(
default=0.0, metadata={"help": "The dropout probability for disabling DeLoRA modules during training"}
)
target_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to replace with DeLoRA."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
"This can also be a wildcard 'all-linear' which matches all linear layers except the output layer."
},
)
exclude_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={"help": "List of module names or regex expression of the module names to exclude from DeLoRA."},
)
bias: str = field(default="none", metadata={"help": "Bias type for DeLoRA. Can be 'none' or 'all'"})
init_weights: bool = field(
default=True,
metadata={
"help": "Whether to perform initialization of adapter weights. If `True` (default): A is initialized with kaiming uniform "
"initialization, while B is initialized with zeros. If `False`: A and B are both initialized with kaiming uniform, "
"immediately contributing a non-zero delta. This is generally discouraged for normal use."
},
)
layers_to_transform: Optional[Union[list[int], int]] = field(
default=None,
metadata={
"help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that "
"are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
},
)
layers_pattern: Optional[Union[list[str], str]] = field(
default=None,
metadata={
"help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the "
"common layers pattern. This should target the `nn.ModuleList` of the model, which is often called `'layers'` or `'h'`."
},
)
rank_pattern: Optional[dict] = field(
default_factory=dict,
metadata={
"help": "The mapping from layer names or regexp expression to ranks which are different from the default rank specified "
"by `r`. For example, `{'^model.decoder.layers.0.encoder_attn.k_proj': 16}`."
},
)
lambda_pattern: Optional[dict] = field(
default_factory=dict,
metadata={
"help": "The mapping from layer names or regexp expression to lambdas which are different from the default lambda specified by `delora_lambda`."
},
)
modules_to_save: Optional[list[str]] = field(
default=None,
metadata={
"help": "List of modules apart from DeLoRA layers to be set as trainable and saved in the final checkpoint. "
"For example, in Sequence Classification or Token Classification tasks, the final layer `classifier/score` "
"are randomly initialized and as such need to be trainable and saved."
},
)
def __post_init__(self):
super().__post_init__()
# PeftType enum members are uppercase; use DELORA
self.peft_type = PeftType.DELORA
self.target_modules = (
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
)
# if target_modules is a regex expression, then layers_to_transform should be None
if isinstance(self.target_modules, str) and self.layers_to_transform is not None:
raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.")
# check for layers_to_transform and layers_pattern
if self.layers_pattern and not self.layers_to_transform:
raise ValueError("When `layers_pattern` is specified, `layers_to_transform` must also be specified. ")

View File

@ -0,0 +1,269 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import warnings
from typing import Any, Optional
import torch
import torch.nn as nn
from peft.tuners._buffer_dict import BufferDict
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
class DeloraLayer(BaseTunerLayer):
# All names of layers that may contain (trainable) adapter weights
adapter_layer_names = (
"delora_A",
"delora_B",
"delora_lambda",
)
# All names of other parameters that may contain adapter-related parameters
other_param_names = (
"r",
"module_dropout",
"delora_w_norm",
)
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.base_layer = base_layer
self.r = {}
self.module_dropout = nn.ModuleDict({})
self.delora_A = nn.ParameterDict({})
self.delora_B = nn.ParameterDict({})
self.delora_lambda = nn.ParameterDict({})
# Use persistent buffers so they are included in state_dict and saved.
self.delora_w_norm = BufferDict({}, persistent=True)
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
self.kwargs = kwargs
base_layer_mod = self.get_base_layer()
if isinstance(base_layer_mod, nn.Linear):
self.in_features, self.out_features = base_layer_mod.in_features, base_layer_mod.out_features
else:
raise ValueError(f"Unsupported layer type {type(base_layer_mod)}")
@staticmethod
def _compute_delta(
A: torch.Tensor, B: torch.Tensor, delora_lambda: torch.Tensor, r: int, w_norm: torch.Tensor
) -> torch.Tensor:
"""Compute delta = B @ diag(delora_lambda/r / (||A_i||*||B^j||)) @ A, scaled by provided w_norm (per-input channel)"""
An = torch.clamp(A.norm(dim=1), min=1e-4)
Bn = torch.clamp(B.norm(dim=0), min=1e-4)
diag = torch.diag_embed(delora_lambda / r / (An * Bn))
delta = B @ diag @ A
delta = delta * w_norm.unsqueeze(0)
return delta
def get_delta_weight(self, adapter: str) -> torch.Tensor:
if adapter not in self.delora_A or adapter not in self.delora_B:
raise ValueError(f"Adapter {adapter} not found.")
delta = self._compute_delta(
self.delora_A[adapter],
self.delora_B[adapter],
self.delora_lambda[adapter],
self.r[adapter],
self.delora_w_norm[adapter],
)
return delta
def update_layer(
self,
adapter_name: str,
r: int,
delora_lambda: float,
module_dropout: float,
init_weights: bool = True,
inference_mode: bool = False,
**kwargs: Any,
) -> None:
"""Internal function to create delora adapter
Args:
adapter_name (`str`): Name for the adapter to add.
r (`int`): Rank for the added adapter.
delora_lambda (`float`): Boundary for the adapter's norm.
module_dropout (`float`): The dropout probability for disabling adapter during training.
init_weights (`bool`): Whether to initialize weights.
"""
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
self.delora_A[adapter_name] = nn.Parameter(torch.empty(r, self.in_features))
self.delora_B[adapter_name] = nn.Parameter(torch.empty(self.out_features, r))
self.delora_lambda[adapter_name] = nn.Parameter(torch.empty(1))
if module_dropout > 0.0:
module_dropout_layer = nn.Dropout(p=module_dropout)
else:
module_dropout_layer = nn.Identity()
self.module_dropout.update(nn.ModuleDict({adapter_name: module_dropout_layer}))
# Initialize weights
self.reset_delora_parameters(adapter_name, init_weights, delora_lambda)
# Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_delora_parameters(
self,
adapter_name: str,
init_weights: bool = True,
delora_lambda: float = 15.0,
) -> None:
if adapter_name not in self.delora_A.keys():
return
if init_weights is True:
nn.init.kaiming_uniform_(self.delora_A[adapter_name], a=math.sqrt(5))
nn.init.zeros_(self.delora_B[adapter_name])
else:
nn.init.kaiming_uniform_(self.delora_A[adapter_name], a=math.sqrt(5))
nn.init.kaiming_uniform_(self.delora_B[adapter_name], a=math.sqrt(5))
self.delora_lambda[adapter_name].data.fill_(float(delora_lambda))
# capture a fixed norm for this adapter to use for future delta computations
with torch.no_grad():
w = self.get_base_layer().weight
if w.device.type != "meta":
w_norm = torch.norm(w.data, dim=0).detach()
else:
# For meta tensors, we can't compute the norm, so use a default value
w_norm = torch.ones(w.shape[1], device=w.device)
self.delora_w_norm[adapter_name] = w_norm
class DeloraLinear(nn.Module, DeloraLayer):
# DeLoRA implemented in a dense layer
def __init__(
self,
base_layer,
adapter_name: str,
r: int,
delora_lambda: float,
module_dropout: float,
init_weights: bool = True,
**kwargs,
) -> None:
super().__init__()
DeloraLayer.__init__(self, base_layer, **kwargs)
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, delora_lambda, module_dropout, init_weights)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
return
for active_adapter in adapter_names:
if active_adapter in self.delora_A.keys():
base_layer = self.get_base_layer()
delta_weight = (
self.get_delta_weight(active_adapter)
.detach()
.to(dtype=base_layer.weight.dtype, device=base_layer.weight.device)
)
with torch.no_grad():
if safe_merge:
orig_weights = base_layer.weight.data.clone()
orig_weights = orig_weights + delta_weight
if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in merged weights for adapter {active_adapter}; aborting merge"
)
base_layer.weight.data = orig_weights
else:
base_layer.weight.data.add_(delta_weight)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
Unmerge all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.delora_A.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
if not self.active_adapters:
return self.base_layer(x, *args, **kwargs).to(previous_dtype)
base_out = self.base_layer(x, *args, **kwargs)
add_out = torch.zeros_like(base_out)
for adapter in self.active_adapters:
if adapter not in self.delora_A:
continue
x_d = self.module_dropout[adapter](x)
# Decomposed delta calculation
# 1. (x * w_norm) @ A.T
h = nn.functional.linear(x_d * self.delora_w_norm[adapter], self.delora_A[adapter])
# 2. h @ diag
An = torch.clamp(self.delora_A[adapter].norm(dim=1), min=1e-4)
Bn = torch.clamp(self.delora_B[adapter].norm(dim=0), min=1e-4)
scaling = (self.delora_lambda[adapter] / self.r[adapter]) / (An * Bn)
h = h * scaling
# 3. h @ B.T
h = nn.functional.linear(h, self.delora_B[adapter])
add_out += h
result = base_out + add_out.to(base_out.dtype)
result = result.to(previous_dtype)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "delora." + rep

View File

@ -0,0 +1,105 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils import (
TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING,
)
from peft.utils.other import get_pattern_key
from .config import DeloraConfig
from .layer import DeloraLayer, DeloraLinear
class DeloraModel(BaseTuner):
"""
Creates DeLoRA model from a pretrained transformers model.
The method is described in detail in [TODO].
Args:
model ([`torch.nn.Module`]): The model to be adapted.
config ([`DeloraConfig`]): The configuration of the DeLoRA model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
Returns:
`torch.nn.Module`: The DeLoRA model.
**Attributes**:
- **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`DeloraConfig`]): The configuration of the DeLoRA model.
"""
prefix: str = "delora_"
tuner_layer_cls = DeloraLayer
target_module_mapping = TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING
def _check_new_adapter_config(self, config: DeloraConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
"""
super()._check_new_adapter_config(config)
def _create_and_replace(
self,
delora_config,
adapter_name,
target,
target_name,
parent,
current_key,
**optional_kwargs,
):
if current_key is None:
raise ValueError("Current Key shouldn't be `None`")
# Regexp matching - Find key which matches current target_name in patterns provided
r_key = get_pattern_key(delora_config.rank_pattern.keys(), current_key)
lambda_key = get_pattern_key(delora_config.lambda_pattern.keys(), current_key)
r = delora_config.rank_pattern.get(r_key, delora_config.r)
delora_lambda = delora_config.lambda_pattern.get(lambda_key, delora_config.delora_lambda)
kwargs = {
"r": r,
"delora_lambda": delora_lambda,
"module_dropout": delora_config.module_dropout,
"init_weights": delora_config.init_weights,
}
if isinstance(target, DeloraLinear):
target.update_layer(adapter_name, **kwargs)
else:
new_module = self._create_new_module(delora_config, adapter_name, target, **kwargs)
if adapter_name != self.active_adapter:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)
@staticmethod
def _create_new_module(delora_config, adapter_name, target, **kwargs):
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if isinstance(target_base_layer, torch.nn.Linear):
new_module = DeloraLinear(target, adapter_name, **kwargs)
return new_module

View File

@ -22,6 +22,7 @@ from .other import (
TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
@ -77,6 +78,7 @@ __all__ = [
"TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING",

View File

@ -108,6 +108,7 @@ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_LOHA_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy() TRANSFORMERS_MODELS_TO_LOKR_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()

View File

@ -46,6 +46,7 @@ from .constants import (
TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
@ -86,6 +87,7 @@ __all__ = [
"TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_BOFT_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_BONE_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_C3A_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_DELORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_HRA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING",

View File

@ -46,6 +46,7 @@ class PeftType(str, enum.Enum):
- C3A - C3A
- ROAD - ROAD
- WAVEFT - WAVEFT
- DELORA
""" """
PROMPT_TUNING = "PROMPT_TUNING" PROMPT_TUNING = "PROMPT_TUNING"
@ -76,6 +77,7 @@ class PeftType(str, enum.Enum):
SHIRA = "SHIRA" SHIRA = "SHIRA"
C3A = "C3A" C3A = "C3A"
WAVEFT = "WAVEFT" WAVEFT = "WAVEFT"
DELORA = "DELORA"
class TaskType(str, enum.Enum): class TaskType(str, enum.Enum):

View File

@ -36,6 +36,7 @@ from peft import (
BOFTConfig, BOFTConfig,
BoneConfig, BoneConfig,
C3AConfig, C3AConfig,
DeloraConfig,
FourierFTConfig, FourierFTConfig,
HRAConfig, HRAConfig,
IA3Config, IA3Config,
@ -848,6 +849,19 @@ TEST_CASES = [
WaveFTConfig, WaveFTConfig,
{"target_modules": "lin0", "n_frequency": 16, "wavelet_family": "db1", "proportional_parameters": True}, {"target_modules": "lin0", "n_frequency": 16, "wavelet_family": "db1", "proportional_parameters": True},
), ),
##########
# DeLoRA #
##########
("Vanilla MLP 1 DeLoRA", "MLP", DeloraConfig, {"target_modules": "lin0"}),
("Vanilla MLP 2 DeLoRA", "MLP", DeloraConfig, {"target_modules": ["lin0"]}),
("Vanilla MLP 3 DeLoRA", "MLP", DeloraConfig, {"target_modules": ["lin1"]}),
("Vanilla MLP 4 DeLoRA", "MLP", DeloraConfig, {"target_modules": ["lin0", "lin1"]}),
(
"Vanilla MLP 5 DeLoRA",
"MLP",
DeloraConfig,
{"target_modules": ["lin0"], "module_dropout": 0.1},
),
] ]
ALL_PEFT_CONFIG_CLASSES = sorted({row[2] for row in TEST_CASES}, key=lambda cls: cls.__name__) ALL_PEFT_CONFIG_CLASSES = sorted({row[2] for row in TEST_CASES}, key=lambda cls: cls.__name__)
@ -1118,6 +1132,20 @@ MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [
{"target_modules": ["lin0"], "init_weights": False, "n_frequency": 8}, {"target_modules": ["lin0"], "init_weights": False, "n_frequency": 8},
{"target_modules": ["lin1"], "init_weights": False, "n_frequency": 8}, {"target_modules": ["lin1"], "init_weights": False, "n_frequency": 8},
), ),
(
"DeLoRA Same",
"delora",
DeloraConfig,
{"target_modules": ["lin0"], "init_weights": False},
{"target_modules": ["lin0"], "init_weights": False},
),
(
"DeLoRA Different",
"delora",
DeloraConfig,
{"target_modules": ["lin0"], "init_weights": False},
{"target_modules": ["lin1"], "init_weights": False},
),
] ]
PREFIXES = { PREFIXES = {
@ -1138,6 +1166,7 @@ PREFIXES = {
BoneConfig: "bone_", BoneConfig: "bone_",
RoadConfig: "road_", RoadConfig: "road_",
MissConfig: "miss_", MissConfig: "miss_",
DeloraConfig: "delora_",
TrainableTokensConfig: "trainable_tokens_", TrainableTokensConfig: "trainable_tokens_",
WaveFTConfig: "waveft_", WaveFTConfig: "waveft_",
} }

View File

@ -32,6 +32,7 @@ from peft import (
BoneConfig, BoneConfig,
C3AConfig, C3AConfig,
CPTConfig, CPTConfig,
DeloraConfig,
FourierFTConfig, FourierFTConfig,
HRAConfig, HRAConfig,
IA3Config, IA3Config,
@ -119,6 +120,14 @@ ALL_CONFIGS = [
"cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4], "cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4],
}, },
), ),
(
DeloraConfig,
{
"task_type": "CAUSAL_LM",
"target_modules": None,
"r": 2,
},
),
( (
FourierFTConfig, FourierFTConfig,
{ {
@ -290,8 +299,9 @@ def _skip_if_not_conv1d_supported(model_id, config_cls):
ShiraConfig, ShiraConfig,
C3AConfig, C3AConfig,
MissConfig, MissConfig,
DeloraConfig,
]: ]:
pytest.skip("Skipping BOFT/HRA/OFT/Bone/Road/SHiRA/C3A/MiSS for GPT2LMHeadModel") pytest.skip("Skipping BOFT/HRA/OFT/Bone/Road/SHiRA/C3A/MiSS/DeLoRA for GPT2LMHeadModel")
def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls): def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls):
@ -304,8 +314,9 @@ def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls):
C3AConfig, C3AConfig,
RoadConfig, RoadConfig,
MissConfig, MissConfig,
DeloraConfig,
]: ]:
pytest.skip("Skipping AdaLora/BOFT/HRA/OFT/Bone/MiSS for GPT2LMHeadModel") pytest.skip("Skipping AdaLora/BOFT/HRA/OFT/Bone/MiSS/DeLoRA for GPT2LMHeadModel")
def _skip_alora_no_activation(config_cls, config_kwargs): def _skip_alora_no_activation(config_cls, config_kwargs):

View File

@ -22,6 +22,7 @@ from peft import (
BOFTConfig, BOFTConfig,
BoneConfig, BoneConfig,
C3AConfig, C3AConfig,
DeloraConfig,
FourierFTConfig, FourierFTConfig,
HRAConfig, HRAConfig,
IA3Config, IA3Config,
@ -82,6 +83,14 @@ ALL_CONFIGS = [
"task_type": "SEQ_2_SEQ_LM", "task_type": "SEQ_2_SEQ_LM",
}, },
), ),
(
DeloraConfig,
{
"task_type": "SEQ_2_SEQ_LM",
"target_modules": None,
"r": 2,
},
),
( (
FourierFTConfig, FourierFTConfig,
{ {

View File

@ -20,6 +20,7 @@ from peft import (
BOFTConfig, BOFTConfig,
BoneConfig, BoneConfig,
C3AConfig, C3AConfig,
DeloraConfig,
FourierFTConfig, FourierFTConfig,
HRAConfig, HRAConfig,
IA3Config, IA3Config,
@ -81,6 +82,14 @@ ALL_CONFIGS = [
"r": 2, "r": 2,
}, },
), ),
(
DeloraConfig,
{
"task_type": "FEATURE_EXTRACTION",
"target_modules": None,
"r": 2,
},
),
( (
FourierFTConfig, FourierFTConfig,
{ {

View File

@ -36,6 +36,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import ( from peft import (
AdaLoraConfig, AdaLoraConfig,
C3AConfig, C3AConfig,
DeloraConfig,
EvaConfig, EvaConfig,
IA3Config, IA3Config,
LoftQConfig, LoftQConfig,
@ -2093,6 +2094,68 @@ class TestRoadInitialization:
get_peft_model(model, config) get_peft_model(model, config)
class TestDeLoRAInitialization:
"""Basic sanity tests for the DeLoRA tuner."""
torch_device = infer_device()
def get_model(self, bias=True):
class MLP(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.lin0 = nn.Linear(10, 30, bias=bias)
self.lin1 = nn.Linear(30, 2, bias=bias)
def forward(self, X):
X = self.lin0(X)
X = self.lin1(X)
return X
return MLP(bias=bias).to(self.torch_device).eval()
@pytest.fixture
def data(self):
torch.manual_seed(0)
return torch.randn(4, 10, device=self.torch_device)
def test_delora_injection_keeps_output_default(self, data):
# With init_weights=True (default), initial forward should match base model
torch.manual_seed(0)
base = self.get_model()
y_base = base(data)
cfg = DeloraConfig(target_modules=["lin0"], r=8, delora_lambda=15, init_weights=True)
model = get_peft_model(base, cfg)
y_peft = model(data)
assert torch.allclose(y_base, y_peft, atol=1e-6, rtol=1e-6)
def test_delora_param_shapes(self):
base = self.get_model()
in_f, out_f = base.lin0.in_features, base.lin0.out_features
r = 4
cfg = DeloraConfig(target_modules=["lin0"], r=r, delora_lambda=15, init_weights=True)
model = get_peft_model(base, cfg)
layer = model.lin0 # DeloraLinear wrapper
assert hasattr(layer, "delora_A") and hasattr(layer, "delora_B") and hasattr(layer, "delora_lambda")
A = layer.delora_A["default"]
B = layer.delora_B["default"]
delora_lambda = layer.delora_lambda["default"]
assert tuple(A.shape) == (r, in_f)
assert tuple(B.shape) == (out_f, r)
assert tuple(delora_lambda.shape) == (1,)
def test_init_weights_false_shifts_output(self, data):
# With init_weights=False, there should be an initial delta to the base model output
base = self.get_model()
y_base = base(data)
cfg = DeloraConfig(target_modules=["lin0"], r=8, delora_lambda=15, init_weights=False)
model = get_peft_model(base, cfg)
y_peft = model(data)
assert not torch.allclose(y_base, y_peft, atol=1e-6, rtol=1e-6)
class TestNoInfiniteRecursionDeepspeed: class TestNoInfiniteRecursionDeepspeed:
# see #1892 for details # see #1892 for details
classes = [ classes = [

View File

@ -20,6 +20,7 @@ from peft import (
BOFTConfig, BOFTConfig,
BoneConfig, BoneConfig,
C3AConfig, C3AConfig,
DeloraConfig,
FourierFTConfig, FourierFTConfig,
HRAConfig, HRAConfig,
IA3Config, IA3Config,
@ -82,6 +83,14 @@ ALL_CONFIGS = [
"r": 2, "r": 2,
}, },
), ),
(
DeloraConfig,
{
"task_type": "SEQ_CLS",
"target_modules": None,
"r": 2,
},
),
( (
FourierFTConfig, FourierFTConfig,
{ {

View File

@ -35,6 +35,7 @@ from peft import (
BOFTConfig, BOFTConfig,
BoneConfig, BoneConfig,
CPTConfig, CPTConfig,
DeloraConfig,
FourierFTConfig, FourierFTConfig,
HRAConfig, HRAConfig,
IA3Config, IA3Config,
@ -168,6 +169,12 @@ CONFIG_TESTING_KWARGS = (
"cpt_mask": [1, 1, 1, 1, 1, 1, 1, 1], "cpt_mask": [1, 1, 1, 1, 1, 1, 1, 1],
"cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4], "cpt_tokens_type_mask": [1, 2, 2, 2, 3, 3, 4, 4],
}, },
# DeLoRA
{
"r": 8,
"target_modules": None,
"bias": "none",
},
) )
CLASSES_MAPPING = { CLASSES_MAPPING = {
@ -187,6 +194,7 @@ CLASSES_MAPPING = {
"miss": (MissConfig, CONFIG_TESTING_KWARGS[12]), "miss": (MissConfig, CONFIG_TESTING_KWARGS[12]),
"lora+trainable_tokens": (LoraConfig, CONFIG_TESTING_KWARGS[13]), "lora+trainable_tokens": (LoraConfig, CONFIG_TESTING_KWARGS[13]),
"randlora": (RandLoraConfig, CONFIG_TESTING_KWARGS[14]), "randlora": (RandLoraConfig, CONFIG_TESTING_KWARGS[14]),
"delora": (DeloraConfig, CONFIG_TESTING_KWARGS[17]),
} }
DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[15])} DECODER_MODELS_EXTRA = {"cpt": (CPTConfig, CONFIG_TESTING_KWARGS[15])}