FEAT Add SHiRA Adapters (#2584)

Implements: Sparse High Rank Adapters

Paper: https://arxiv.org/abs/2406.13175
This commit is contained in:
kkb-code
2025-07-14 02:16:10 -07:00
committed by GitHub
parent 35000fda88
commit a4f9334f12
27 changed files with 1623 additions and 9 deletions

View File

@ -126,6 +126,8 @@
title: Trainable Tokens
- local: package_reference/randlora
title: RandLora
- local: package_reference/shira
title: SHiRA
- local: package_reference/c3a
title: C3A

View File

@ -0,0 +1,35 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Sparse High Rank Adapters
Sparse High Rank Adapters or [SHiRA](https://arxiv.org/abs/2406.13175) is an alternate type of adapter and has been found to have significant advantages over the low rank adapters. Specifically, SHiRA achieves better accuracy than LoRA for a variety of vision and language tasks. It also offers simpler and higher quality multi-adapter fusion by significantly reducing concept loss, a common problem faced by low rank adapters. SHiRA directly finetunes a small number of the base model's parameters to finetune the model on any adaptation task.
SHiRA currently has the following constraint:
- Only `nn.Linear` layers are supported.
The abstract from the paper is:
> Low Rank Adaptation (LoRA) has gained massive attention in the recent generative AI research. One of the main advantages of LoRA is its ability to be fused with pretrained models, adding no overhead during inference. However, from a mobile deployment standpoint, we can either avoid inference overhead in the fused mode but lose the ability to switch adapters rapidly, or suffer significant (up to 30% higher) inference latency while enabling rapid switching in the unfused mode. LoRA also exhibits concept-loss when multiple adapters are used concurrently. In this paper, we propose Sparse High Rank Adapters (SHiRA), a new paradigm which incurs no inference overhead, enables rapid switching, and significantly reduces concept-loss. Specifically, SHiRA can be trained by directly tuning only 1-2% of the base model weights while leaving others unchanged. This results in a highly sparse adapter which can be switched directly in the fused mode. We further provide theoretical and empirical insights on how high sparsity in SHiRA can aid multi-adapter fusion by reducing concept loss. Our extensive experiments on LVMs and LLMs demonstrate that finetuning only a small fraction of the parameters in the base model significantly outperforms LoRA while enabling both rapid switching and multi-adapter fusion. Finally, we provide a latency- and memory-efficient SHiRA implementation based on Parameter-Efficient Finetuning (PEFT) Library which trains at nearly the same speed as LoRA while consuming up to 16% lower peak GPU memory, thus making SHiRA easy to adopt for practical use cases. To demonstrate rapid switching benefits during inference, we show that loading SHiRA on a base model can be 5x-16x faster than LoRA fusion on a CPU.
## ShiraConfig
[[autodoc]] tuners.shira.config.ShiraConfig
## ShiraModel
[[autodoc]] tuners.shira.model.ShiraModel

View File

@ -0,0 +1,73 @@
# Sparse High Rank Adapters
## Introduction
Sparse High Rank Adapters or [SHiRA](https://arxiv.org/abs/2406.13175) is an alternate type of adapter and has been found to have significant advantages over the low rank adapters. Specifically, SHiRA achieves better accuracy than LoRA for a variety of vision and language tasks. It also offers simpler and higher quality multi-adapter fusion by significantly reducing concept loss, a common problem faced by low rank adapters. SHiRA directly finetunes a small number of the base model's parameters to finetune the model on any adaptation task.
## Quick start
```python
import torch
from peft import ShiraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train[:1%]")
shira_config = ShiraConfig(
r=32,
)
peft_model = get_peft_model(model, shira_config)
training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
model=peft_model,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("shira-opt-350m")
```
For more options and a more detailed example code, you can refer to shira finetuning script.
Run the script simply by running:
```bash
python3 examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m
```
If you want to run DDP by [accelerate](https://huggingface.co/docs/accelerate/en/index), please run `accelerate config` to set your ddp config, and run:
```bash
accelerate launch examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m
```
please add `--device_map cpu` if you want to run finetune on CPU.
If you want to train SHiRA with a custom sparse mask function which requires custom keyword arguments, please see the definition of `custom_random_mask_function_with_custom_kwargs` function provided in the `shira_fintuning.py` script. You can run this code using the `--use_custom_random_mask_function_with_custom_kwargs` argument. Without this argument, SHiRA defaults to a random sparse mask. Please run the code as follows. :
```bash
python3 examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m --use_custom_random_mask_function_with_custom_kwargs
```
## Use the model
You can load and use the model as any other 🤗 PEFT model
```python
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
shira_model = PeftModel.from_pretrained(model, "shira-opt-350m")
```
## Citation
```
@inproceedings{NEURIPS2024_18c0102c,
author = {Bhardwaj, Kartikeya and Pandey, Nilesh Prasad and Priyadarshi, Sweta and Ganapathy, Viswanath and Kadambi, Shreya and Esteves, Rafael and Borse, Shubhankar and Whatmough, Paul and Garrepalli, Risheek and Van Baalen, Mart and Teague, Harris and Nagel, Markus},
booktitle = {Advances in Neural Information Processing Systems},
editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang},
pages = {13685--13715},
publisher = {Curran Associates, Inc.},
title = {Sparse High Rank Adapters},
url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/18c0102cb7f1a02c14f0929089b2e576-Paper-Conference.pdf},
volume = {37},
year = {2024}
}
```

View File

@ -0,0 +1,217 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from peft import (
PeftModel,
ShiraConfig,
get_peft_model,
)
def train(
base_model: str = "path/to/model",
data_path: str = "yahma/alpaca-cleaned",
output_dir: str = "shira",
batch_size: int = 16,
num_epochs: int = 1,
learning_rate: float = 3e-4,
cutoff_len: int = 256,
val_set_size: int = 16,
eval_step: int = 100,
save_step: int = 100,
device_map: str = "auto",
shira_r: int = 32,
shira_target_modules: list[str] = None,
torch_dtype: str = "float16",
seed: Optional[int] = None,
use_custom_random_mask_function_with_custom_kwargs: Optional[bool] = False,
):
# Set device_map to the right place when enabling DDP.
world_size = int(os.environ.get("WORLD_SIZE", 0)) or int(os.environ.get("PMI_SIZE", 0))
if world_size > 1 and device_map != "cpu":
from accelerate import Accelerator
device_map = {"": Accelerator().process_index}
# Set seed
if seed is not None:
set_seed(seed)
model_kwargs = {"torch_dtype": getattr(torch, torch_dtype), "device_map": device_map}
model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
# For some tokenizer with no pad token like llama
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
def tokenize(prompt, add_eos_token=True):
result = tokenizer(
prompt,
truncation=True,
max_length=cutoff_len,
padding=False,
return_tensors=None,
)
if (
result["input_ids"][-1] != tokenizer.eos_token_id
and len(result["input_ids"]) < cutoff_len
and add_eos_token
):
result["input_ids"].append(tokenizer.eos_token_id)
result["attention_mask"].append(1)
result["labels"] = result["input_ids"].copy()
return result
def generate_and_tokenize_prompt(example):
full_prompt = generate_prompt(example)
tokenized_full_prompt = tokenize(full_prompt)
return tokenized_full_prompt
def custom_random_mask_function_with_custom_kwargs(custom_arg):
def mask_fn(base_layer, r):
"""
This mask function is similar to the random_mask provided in src/peft/tuners/shira/mask_functions.py except the seed is derived from custom_kwargs.
Please use this as an example to create your own custom sparse masks that may use custom_kwargs. Remember, for a pretrained weight with shape m, n,
mask_fn must return only one mask (shape: m, n) which must be binary 0 or 1 with num_shira_parameters = r(m+n) for linear layers. Device and dtype
of mask must be same as base layer's weight's device and dtype.
"""
new_seed = custom_arg
shape = base_layer.weight.shape
num_shira_weights = r * (shape[0] + shape[1])
random_generator = torch.Generator()
random_generator.manual_seed(new_seed)
idx = (torch.randperm(base_layer.weight.numel(), generator=random_generator)[:num_shira_weights]).to(
base_layer.weight.device
)
val = torch.ones_like(idx.type(base_layer.weight.dtype))
mask = torch.zeros_like(base_layer.weight.view(1, -1))
mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(shape)
return mask
return mask_fn
mask_type = "random" if not use_custom_random_mask_function_with_custom_kwargs else "custom"
config = ShiraConfig(
r=shira_r,
mask_type=mask_type,
target_modules=shira_target_modules,
task_type="CAUSAL_LM",
)
if use_custom_random_mask_function_with_custom_kwargs:
custom_arg = 120
custom_mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg)
config.mask_fn = custom_mask_fn
model = get_peft_model(model, config)
data = load_dataset(data_path)
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
trainer = transformers.Trainer(
model=model,
train_dataset=train_data,
eval_dataset=val_data,
args=transformers.TrainingArguments(
per_device_train_batch_size=batch_size,
warmup_steps=100,
num_train_epochs=num_epochs,
learning_rate=learning_rate,
logging_steps=100,
optim="adamw_torch",
eval_strategy="steps",
save_strategy="steps",
eval_steps=eval_step,
save_steps=save_step,
output_dir=output_dir,
save_total_limit=3,
load_best_model_at_end=True,
ddp_find_unused_parameters=False if world_size > 1 else None,
),
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
),
)
trainer.train()
model.save_pretrained(output_dir)
# Delete the model and load it again from the checkpoint.
del model
model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
model = PeftModel.from_pretrained(model, output_dir)
def generate_prompt(example):
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{example["instruction"]}
### Response:
{example["output"]}"""
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--base_model", type=str, default="path/to/model")
parser.add_argument("--data_path", type=str, default="yahma/alpaca-cleaned")
parser.add_argument("--output_dir", type=str, default="shira")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--cutoff_len", type=int, default=256)
parser.add_argument("--val_set_size", type=int, default=16)
parser.add_argument("--eval_step", type=int, default=100)
parser.add_argument("--save_step", type=int, default=100)
parser.add_argument("--device_map", type=str, default="auto")
parser.add_argument("--shira_r", type=int, default=32)
parser.add_argument("--shira_target_modules", type=str, default=None)
parser.add_argument("--torch_dtype", type=str, default="float16")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--use_custom_random_mask_function_with_custom_kwargs", action="store_true")
args = parser.parse_args()
train(
base_model=args.base_model,
data_path=args.data_path,
output_dir=args.output_dir,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
learning_rate=args.learning_rate,
cutoff_len=args.cutoff_len,
val_set_size=args.val_set_size,
eval_step=args.eval_step,
save_step=args.save_step,
device_map=args.device_map,
shira_r=args.shira_r,
shira_target_modules=args.shira_target_modules,
torch_dtype=args.torch_dtype,
seed=args.seed,
use_custom_random_mask_function_with_custom_kwargs=args.use_custom_random_mask_function_with_custom_kwargs,
)

View File

@ -0,0 +1,15 @@
{
"auto_mapping": null,
"base_model_name_or_path": null,
"fan_in_fan_out": false,
"inference_mode": false,
"init_weights": true,
"mask_type": "random",
"modules_to_save": null,
"peft_type": "SHIRA",
"r": 32,
"random_seed": 42,
"revision": null,
"target_modules": null,
"task_type": null
}

View File

@ -0,0 +1,6 @@
{
"optimizer_kwargs": {
"lr": 3e-4
}
}

View File

@ -91,6 +91,8 @@ from .tuners import (
PromptTuningInit,
RandLoraConfig,
RandLoraModel,
ShiraConfig,
ShiraModel,
TrainableTokensConfig,
TrainableTokensModel,
VBLoRAConfig,
@ -186,6 +188,8 @@ __all__ = [
"PromptTuningInit",
"RandLoraConfig",
"RandLoraModel",
"ShiraConfig",
"ShiraModel",
"TaskType",
"TrainableTokensConfig",
"TrainableTokensModel",

View File

@ -41,6 +41,7 @@ from .poly import PolyConfig, PolyModel
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
from .randlora import RandLoraConfig, RandLoraModel
from .shira import ShiraConfig, ShiraModel
from .trainable_tokens import TrainableTokensConfig, TrainableTokensModel
from .vblora import VBLoRAConfig, VBLoRAModel
from .vera import VeraConfig, VeraModel
@ -95,6 +96,8 @@ __all__ = [
"PromptTuningInit",
"RandLoraConfig",
"RandLoraModel",
"ShiraConfig",
"ShiraModel",
"TrainableTokensConfig",
"TrainableTokensModel",
"VBLoRAConfig",

View File

@ -19,7 +19,7 @@ from typing import Any, Optional, Union
from torch import nn
from tqdm import tqdm
from peft.tuners import adalora, loha, lokr, lora, oft
from peft.tuners import adalora, loha, lokr, lora, oft, shira
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
@ -31,10 +31,25 @@ from peft.utils import (
# Collection of constants used for all tuners
COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.OFT)
PREFIXES = [lora.LoraModel.prefix, lokr.LoKrModel.prefix, loha.LoHaModel.prefix, oft.OFTModel.prefix]
Configs = Union[lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig, oft.OFTConfig]
Layers = (lora.layer.LoraLayer, loha.layer.LoHaLayer, lokr.layer.LoKrLayer, adalora.layer.AdaLoraLayer, oft.OFTLayer)
COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.OFT, PeftType.SHIRA)
PREFIXES = [
lora.LoraModel.prefix,
lokr.LoKrModel.prefix,
loha.LoHaModel.prefix,
oft.OFTModel.prefix,
shira.ShiraModel.prefix,
]
Configs = Union[
lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig, oft.OFTConfig, shira.ShiraConfig
]
Layers = (
lora.layer.LoraLayer,
loha.layer.LoHaLayer,
lokr.layer.LoKrLayer,
adalora.layer.AdaLoraLayer,
oft.OFTLayer,
shira.ShiraLayer,
)
class MixedModel(BaseTuner):
@ -96,6 +111,8 @@ class MixedModel(BaseTuner):
lokr.LoKrModel._create_and_replace(self, config, *args, **kwargs)
elif isinstance(config, oft.OFTConfig):
oft.OFTModel._create_and_replace(self, config, *args, **kwargs)
elif isinstance(config, shira.ShiraConfig):
shira.ShiraModel._create_and_replace(self, config, *args, **kwargs)
else:
raise ValueError(f"Unsupported config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.")
@ -174,6 +191,8 @@ class MixedModel(BaseTuner):
new_module = lokr.LoKrModel._create_new_module(config, adapter_name, target, **kwargs)
elif isinstance(config, oft.OFTConfig):
new_module = oft.OFTModel._create_new_module(config, adapter_name, target, **kwargs)
elif isinstance(config, shira.ShiraConfig):
new_module = shira.ShiraModel._create_new_module(config, adapter_name, target, **kwargs)
else:
raise ValueError(f"Unknown config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.")
return new_module

View File

@ -0,0 +1,27 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from peft.utils import register_peft_method
from .config import ShiraConfig
from .layer import Linear, ShiraLayer
from .model import ShiraModel
__all__ = ["Linear", "ShiraConfig", "ShiraLayer", "ShiraModel"]
register_peft_method(
name="shira", config_cls=ShiraConfig, model_cls=ShiraModel, prefix="shira_", is_mixed_compatible=True
)

View File

@ -0,0 +1,129 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from peft.config import PeftConfig
from peft.utils import PeftType
from .mask_functions import random_mask
@dataclass
class ShiraConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a [`ShiraModel`].
Args:
r (`int`, *optional*, defaults to `32`):
For a given target module, the number of SHiRA parameters is computed as r(m+n), where the original tensor
dimensions are m x n. This means the number of SHiRA parameters is the same as that for a LoRA adapter.
SHiRA is a high rank adapter. Setting this r parameter does not restrict the rank to this value.
mask_type (`str`, defaults to `random`):
Type of mask function. Defaults to a random sparse mask. An optional user-defined mask_fn to compute the
mask value can also be supplied by instantiating `config = ShiraConfig(...)` and then setting
`config.mask_fn = <your custom mask function>`. For a pretrained weight with shape m x n, the custom mask
function must return only one mask (shape: m x n) which must be binary 0 or 1 with num_shira_parameters =
r(m + n) for linear layers. Device and dtype of mask must be same as base layer's weight's device and
dtype. Please see mask_functions.py for more details and to see the default random sparse mask
implementation.
random_seed (`int`, *optional*, defaults to `None`):
random seed for the torch generator for random_mask.
target_modules (`Union[List[str], str]`):
List of module names or regex expression of the module names to replace with SHiRA. For example, ['q', 'v']
or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. Only linear layers are supported.
fan_in_fan_out (`bool`):
Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses
`Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
init_weights (`bool`, defaults to `True`):
Initialize SHiRA weight to have zero values. If set to False, SHiRA weights are initialized to randn values
instead of zeros and this is used only for testing.
modules_to_save (`List[str]`):
List of modules apart from SHiRA layers to be set as trainable and saved in the final checkpoint.
"""
r: int = field(
default=32,
metadata={
"help": (
"For a given target module, the number of SHiRA parameters is computed as r(m+n), where the original "
"tensor dimensions are m x n. This means the number of SHiRA parameters is the same as that for a LoRA adapter. "
"SHiRA is a high rank adapter. Setting this r parameter does not restrict the rank to this value."
)
},
)
mask_type: Literal["random"] = field(
default="random",
metadata={
"help": (
"Type of mask function. Defaults to a random sparse mask. "
"An optional user-defined mask_fn to compute the mask value can also be supplied by instantiating `config = ShiraConfig(...)` and then setting "
"`config.mask_fn = <your custom mask function>`. For a pretrained weight with shape m x n, the custom mask function must return only one mask (shape: m x n) "
"which must be binary 0 or 1 with num_shira_parameters = r(m + n) for linear layers. Device and dtype of mask must be same as base layer's weight's device and dtype. "
"Please see mask_functions.py for more details and to see the default random sparse mask implementation."
)
},
)
random_seed: Optional[int] = field(
default=None, metadata={"help": "random seed for the torch generator for random_mask"}
)
target_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={
"help": (
"List of module names or regex expression of the module names to replace with SHiRA."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
"Only linear layers are supported."
)
},
)
fan_in_fan_out: bool = field(
default=False,
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
)
init_weights: bool = field(
default=True,
metadata={
"help": "Initialize SHiRA weight to have zero values. If set to False, SHiRA weights are initialized to randn values instead of zeros and this is used only for testing."
},
)
modules_to_save: Optional[list[str]] = field(
default=None,
metadata={
"help": (
"List of modules apart from SHiRA layers to be set as trainable and saved in the final checkpoint. For"
" example, in Sequence Classification or Token Classification tasks, the final layer"
" `classifier/score` are randomly initialized and as such need to be trainable and saved."
)
},
)
def __post_init__(self):
super().__post_init__()
self.peft_type = PeftType.SHIRA
self.target_modules = (
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
)
if self.mask_type == "random":
self.mask_fn = random_mask
else:
if not self.inference_mode:
warnings.warn(
f"Argument {self.mask_type=} is not recognized, please supply your own masking function by calling `config.mask_fn = my_mask_fn`."
)
self.mask_fn = None

View File

@ -0,0 +1,215 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import warnings
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
class ShiraLayer(BaseTunerLayer):
# List all names of layers that may contain trainable adapter weights
adapter_layer_names = ("shira_weight",)
# All names of other adapter-related parameters
other_param_names = ("r", "scaling", "shira_indices")
def __init__(self, base_layer: nn.Module, **kwargs):
self.base_layer = base_layer
self.r = {}
self.scaling = {}
self.shira_weight = nn.ParameterDict({})
self.shira_indices = {}
self.weight_shape = base_layer.weight.shape # Assumes SHiRA is on some layer with "weight" parameter
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
raise NotImplementedError("Only nn.Linear layers supported currently")
self.in_features = in_features
self.out_features = out_features
self.kwargs = kwargs
def update_layer(
self,
adapter_name,
mask,
r,
init_weights: bool = True,
):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
self.scaling[adapter_name] = (
1.0 # Default scale during training. Can be set to any (non-negative) value during inference.
)
# The number of shira weights in this layer is determined by r such that the total number of weights is the same as a LoRA Layer (for direct comparisons)
num_shira_weight = r * (self.in_features + self.out_features)
if num_shira_weight > self.in_features * self.out_features:
raise ValueError(
f"The set rank {r} results in more shira params than the total number of params in the base layer {self.in_features * self.out_features} and this is not allowed."
)
# Actual trainable parameters
# We have used a vector parameter with fixed indices that we use inside a torch.sparse_coo_tensor in get_delta_weight function.
# Directly using a torch.sparse_coo_tensor as a parameter could have been possible but we ran into some issues similar to:
# https://github.com/pytorch/pytorch/issues/79542.
shira_init_weight = torch.zeros(num_shira_weight) if init_weights else torch.randn(num_shira_weight)
self.shira_weight[adapter_name] = nn.Parameter(
shira_init_weight.to(self.base_layer.weight.dtype).to(self.base_layer.weight.device),
requires_grad=True,
)
if mask is not None:
# Compute the shira_indices from the mask. Make sure the mask is formed using r*(self.in_features + self.out_features) and not some other K.
mask_indices = torch.where(mask == 1.0)
self.shira_indices[adapter_name] = torch.cat(
[mask_indices[0].unsqueeze(0), mask_indices[1].unsqueeze(0)], 0
).to(torch.int)
self.shira_indices[adapter_name] = self.shira_indices[adapter_name].to(self.base_layer.weight.device)
if self.shira_indices[adapter_name].shape[1] != self.shira_weight[adapter_name].shape[0]:
raise ValueError(
f"The SHiRA indices and weights are not the same dimensions for adapter {adapter_name} in layer {self.base_layer}"
)
self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters)
def reset_shira_parameters(self, adapter_name):
nn.init.zeros_(self.shira_weight[adapter_name])
def set_scale(self, adapter, scale):
if adapter not in self.scaling:
# Ignore the case where the adapter is not in the layer
return
self.scaling[adapter] = scale
class Linear(nn.Module, ShiraLayer):
# SHiRA implemented in a dense layer
def __init__(
self,
base_layer,
mask,
adapter_name: str,
r: int = 0,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stored weight like (fan_in, fan_out)
init_weights: bool = True,
**kwargs,
) -> None:
super().__init__()
ShiraLayer.__init__(self, base_layer, **kwargs)
self.fan_in_fan_out = fan_in_fan_out
if self.base_layer is not self.get_base_layer():
raise ValueError("SHiRA does not support nested base layers")
self._active_adapter = adapter_name
self.update_layer(adapter_name, mask, r, init_weights=init_weights)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self.shira_weight.keys():
base_layer = self.get_base_layer()
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weights = base_layer.weight.data.clone()
orig_weights += self.get_delta_weight(active_adapter)
if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = orig_weights
else:
base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.shira_weight.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
def get_delta_weight(self, adapter) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.
Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
# In multi-gpu environment, the indices are at the wrong gpu. This is needed to correct this.
self.shira_indices[adapter] = self.shira_indices[adapter].to(self.shira_weight[adapter].device)
return torch.sparse_coo_tensor(
self.shira_indices[adapter], self.shira_weight[adapter] * self.scaling[adapter], self.weight_shape
)
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
new_weight = copy.deepcopy(self.base_layer.weight.data)
for active_adapter in self.active_adapters:
if active_adapter not in self.shira_weight.keys():
continue
new_weight += self.get_delta_weight(active_adapter)
result = F.linear(x, new_weight, bias=self.base_layer.bias)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "shira." + rep

View File

@ -0,0 +1,72 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module is intended to store mask functions for use inside SHiRA construction. The mask functions are required to
have a specific signature as shown below.
Required positional arguments:
base_layer - This is the linear layer where the shira adapter will be attached. r - This parameter is used to
determine the number of parameters in the
shira adapter in a way that is consistent with LoRA sizing. SHiRA is a high rank adapter. Setting this
parameter does not restrict the adapter rank.
Keyword arguments can be provided as needed by the particular mask function implementation.
Return:
mask - this is a torch.tensor of the same shape as base_layer.weight that contains 0s and 1s with the same
dtype and device as base_layer.weight
If you would like to attach SHiRA adapters to a model using PEFT methods (such as get_peft_model()), using more
arguments than the provided positional arguments, you can create the mask function reference like the following:
```
def create_mask_function_reference(**my_kwargs):
def mask_fn(base_layer, r):
... your implementation here that might use my_kwargs ...
return mask
return mask_fn
```
Then, you can create your peft model with custom SHiRA mask as follows:
```
model = ...
my_kwargs = ...
mask_fn = create_mask_function_reference(**my_kwargs)
peft_config = ShiraConfig(r=4, mask_type='my_custom_mask')
peft_config.mask_fn = mask_fn
peft_model = get_peft_model(model, peft_config)
```
Complete training examples are provided in the examples/shira/ directory.
"""
from typing import Optional
import torch
import torch.nn as nn
def random_mask(base_layer: nn.Module, r: int, random_seed: Optional[int] = None, **kwargs) -> torch.tensor:
shape = base_layer.weight.shape
num_shira_weights = r * (shape[0] + shape[1])
random_generator = torch.Generator()
if random_seed is not None:
random_generator.manual_seed(random_seed)
idx = (torch.randperm(base_layer.weight.numel(), generator=random_generator)[:num_shira_weights]).to(
base_layer.weight.device
)
val = torch.ones_like(idx.type(base_layer.weight.dtype))
mask = torch.zeros_like(base_layer.weight.view(1, -1))
mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(shape)
return mask

View File

@ -0,0 +1,340 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from dataclasses import asdict
from enum import Enum
from typing import Optional
import torch
import torch.nn as nn
from tqdm import tqdm
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
from peft.utils import (
TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
_get_submodules,
)
from .config import ShiraConfig
from .layer import Linear, ShiraLayer
class ShiraModel(BaseTuner):
"""
Creates a Sparse High Rank Adapter (SHiRA) Model from a pretrained model.
Args:
model ([`~transformers.PreTrainedModel`]): The model to be adapted.
config ([`ShiraConfig`]): The configuration of the SHiRA model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
Returns:
`torch.nn.Module`: The SHiRA model.
Example:
```py
>>> from transformers import AutoModelForCausalLM
>>> from peft import ShiraConfig, get_peft_model
>>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> config = ShiraConfig(r=32)
>>> model = get_peft_model(base_model, config)
```
**Attributes**:
- **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`ShiraConfig`]): The configuration of the SHiRA model.
"""
prefix: str = "shira_"
def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
def _check_new_adapter_config(self, config: ShiraConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
"""
for existing_config in self.peft_config.values():
if existing_config is config:
# skip the current config
continue
@staticmethod
def _check_target_module_exists(shira_config, key):
return check_target_module_exists(shira_config, key)
def _create_and_replace(
self,
shira_config,
adapter_name,
target,
target_name,
parent,
current_key,
**optional_kwargs,
):
if current_key is None:
raise ValueError("Current Key shouldn't be `None`")
bias = hasattr(target, "bias") and target.bias is not None
kwargs = {}
kwargs["bias"] = bias
if shira_config.mask_type == "random":
kwargs["random_seed"] = shira_config.random_seed
for k, v in optional_kwargs.items():
kwargs[k] = v
if isinstance(target, Linear):
mask = (
shira_config.mask_fn(target.base_layer, shira_config.r, **kwargs)
if shira_config.mask_fn is not None
else None
)
target.update_layer(
adapter_name,
mask,
shira_config.r,
init_weights=shira_config.init_weights,
)
else:
new_module = self._create_new_module(shira_config, adapter_name, target, **kwargs)
if adapter_name not in self.active_adapter:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)
@staticmethod
def _replace_module(parent, child_name, new_module, child):
setattr(parent, child_name, new_module)
# It's not necessary to set requires_grad here, as that is handled by
# _mark_only_adapters_as_trainable
# child layer wraps the original module, unpack it
if hasattr(child, "base_layer"):
child = child.base_layer
if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias
if getattr(child, "state", None) is not None:
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device)
meta = torch.device("meta")
# dispatch to correct device
for name, module in new_module.named_modules():
if "shira_" in name:
if not any(p.device == meta for p in module.parameters()):
module.to(child.weight.device)
def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
for n, p in model.named_parameters():
if self.prefix not in n:
p.requires_grad = False
@staticmethod
def _create_new_module(shira_config, adapter_name, target, **kwargs):
fan_in_fan_out = shira_config.fan_in_fan_out
_ = kwargs.pop("bias", False)
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if isinstance(target_base_layer, torch.nn.Linear):
if fan_in_fan_out:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
fan_in_fan_out = shira_config.fan_in_fan_out = False
else:
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`."
)
mask = (
shira_config.mask_fn(target_base_layer, shira_config.r, **kwargs)
if shira_config.mask_fn is not None
else None
)
new_module = Linear(
target,
mask,
adapter_name,
shira_config.r,
fan_in_fan_out,
init_weights=shira_config.init_weights,
**kwargs,
)
return new_module
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
if name == "model": # see #1892: prevent infinite recursion if class is not initialized
raise
return getattr(self.model, name)
def get_peft_config_as_dict(self, inference: bool = False):
config_dict = {}
for key, value in self.peft_config.items():
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
if inference:
config["inference_mode"] = True
config_dict[key] = config
return config
def _set_adapter_layers(self, enabled=True):
for module in self.model.modules():
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
module.enable_adapters(enabled)
def enable_adapter_layers(self):
self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name):
for module in self.model.modules():
if isinstance(module, ShiraLayer):
if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge()
module.set_adapter(adapter_name)
self.active_adapter = adapter_name
@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = set(
TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config
def _unload_and_optionally_merge(
self,
merge=True,
progressbar: bool = False,
safe_merge: bool = False,
adapter_names: Optional[list[str]] = None,
):
# we cannot use self.prefix as we want to include non-trainable shira parameters
key_list = [key for key, _ in self.model.named_modules() if "shira" not in key]
desc = "Unloading " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self.model, key)
except AttributeError:
continue
if hasattr(target, "base_layer"):
if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
return self.model
def delete_adapter(self, adapter_name: str):
"""
Deletes an existing adapter.
Args:
adapter_name (str): Name of the adapter to be deleted.
"""
if adapter_name not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]
# we cannot use self.prefix as we want to include non-trainable shira parameters
key_list = [key for key, _ in self.model.named_modules() if "shira" not in key]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, ShiraLayer):
target.delete_adapter(adapter_name)
if new_adapter is None:
new_adapter = target.active_adapter[:]
self.active_adapter = new_adapter or []
def merge_and_unload(
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
):
r"""
This method merges the Shira layers into the base model. This is needed if someone wants to use the base model
as a standalone model.
Args:
progressbar (`bool`):
whether to show a progressbar indicating the unload and merge process
safe_merge (`bool`):
whether to activate the safe merging check to check if there is any potential Nan in the adapter
weights
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
Example:
```py
>>> from transformers import AutoModelForCausalLM
>>> from peft import ShiraConfig, get_peft_model
>>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> config = ShiraConfig(r=32)
>>> model = get_peft_model(base_model, config)
>>> ## [Train the adapter] ##
>>> merged_model = model.merge_and_unload()
```
"""
return self._unload_and_optionally_merge(
progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
)
def unload(self):
"""
Gets back the base model by removing all the Shira modules without merging. This gives back the original base
model.
"""
return self._unload_and_optionally_merge(merge=False)

View File

@ -27,6 +27,7 @@ from .other import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING,
WEIGHTS_NAME,
@ -69,6 +70,7 @@ __all__ = [
"TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
"TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING",
"WEIGHTS_NAME",

View File

@ -255,6 +255,43 @@ TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING = {
"qwen3": ["q_proj", "v_proj"],
}
TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING = {
"t5": ["q", "v"],
"mt5": ["q", "v"],
"bart": ["q_proj", "v_proj"],
"gpt2": ["c_attn"],
"bloom": ["query_key_value"],
"blip-2": ["q", "v", "q_proj", "v_proj"],
"opt": ["q_proj", "v_proj"],
"gptj": ["q_proj", "v_proj"],
"gpt_neox": ["query_key_value"],
"gpt_neo": ["q_proj", "v_proj"],
"bert": ["query", "value"],
"roberta": ["query", "value"],
"xlm-roberta": ["query", "value"],
"electra": ["query", "value"],
"deberta-v2": ["query_proj", "value_proj"],
"deberta": ["in_proj"],
"layoutlm": ["query", "value"],
"llama": ["q_proj", "v_proj"],
"chatglm": ["query_key_value"],
"gpt_bigcode": ["c_attn"],
"mpt": ["Wqkv"],
"RefinedWebModel": ["query_key_value"],
"RefinedWeb": ["query_key_value"],
"falcon": ["query_key_value"],
"btlm": ["c_proj", "c_attn"],
"codegen": ["qkv_proj"],
"mistral": ["q_proj", "v_proj"],
"mixtral": ["q_proj", "v_proj"],
"stablelm": ["q_proj", "v_proj"],
"phi": ["q_proj", "v_proj"],
"gemma": ["q_proj", "v_proj"],
"gemma2": ["q_proj", "v_proj"],
"gemma3_text": ["q_proj", "v_proj"],
"qwen2": ["q_proj", "v_proj"],
}
TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING = {
"t5": ["q", "v"],
"mt5": ["q", "v"],

View File

@ -50,6 +50,7 @@ from .constants import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING,
WEIGHTS_NAME,
@ -79,6 +80,7 @@ __all__ = [
"TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
"TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING",
"WEIGHTS_NAME",

View File

@ -41,6 +41,7 @@ class PeftType(str, enum.Enum):
- HRA
- BONE
- RANDLORA
- SHIRA
- C3A
"""
@ -67,6 +68,7 @@ class PeftType(str, enum.Enum):
BONE = "BONE"
RANDLORA = "RANDLORA"
TRAINABLE_TOKENS = "TRAINABLE_TOKENS"
SHIRA = "SHIRA"
C3A = "C3A"

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import os
import platform
import re
import warnings
from typing import Optional
@ -152,6 +153,25 @@ def get_peft_model_state_dict(
prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name)
to_return["prompt_embeddings"] = prompt_embeddings
elif config.peft_type == PeftType.SHIRA:
shira_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
to_return = {k: state_dict[k] for k in state_dict if shira_prefix in k}
if platform.system() == "Windows":
warnings.warn(
"Windows has issues saving integers into safetensors. Hence, we convert shira_indices to float32 "
"before saving on Windows OS. The shira_indices will always be converted to integers when loading."
)
for name, module in model.named_modules():
if hasattr(module, "shira_indices"):
for k, v in module.shira_indices.items():
# Windows has some issues with saving integers into safetensors. Tests fail with some kind of
# PermissionError. This results in failed tests, so we are converting indices to float32 before
# saving and then converting them back to int when loading. This is happening only for Windows,
# not for Linux and Mac-OS.
to_return[f"{name}.shira_indices.{k}"] = (
v.to(torch.float32) if platform.system() == "Windows" else v
)
elif config.peft_type == PeftType.VERA:
vera_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
to_return = {k: state_dict[k] for k in state_dict if vera_prefix in k}
@ -406,6 +426,22 @@ def set_peft_model_state_dict(
rank_pattern = config.rank_pattern
if rank_pattern is not None:
model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
elif config.peft_type == PeftType.SHIRA:
if platform.system() == "Windows":
warnings.warn(
"Windows has issues saving integers into safetensors. Hence, we had converted shira_indices "
"to float32 before saving on Windows OS. The shira_indices will always be converted to integers "
"when loading."
)
for name, module in model.named_modules():
if hasattr(module, "shira_indices"):
# for k, v in module.shira_indices.items():
if f"{name}.shira_indices.{adapter_name}" in peft_model_state_dict:
shira_indices_values = peft_model_state_dict.pop(f"{name}.shira_indices.{adapter_name}")
# Convert shira_indices to int in case they were saved on a Windows OS and are being loaded
# on a Linux or a Mac-OS system. If they were saved in Linux or Mac-OS, they are already
# integers and the following will not affect anything.
module.shira_indices[adapter_name] = shira_indices_values.to(torch.int)
elif config.peft_type == PeftType.VERA:
if config.save_projection and "base_model.vera_A" not in peft_model_state_dict:
raise ValueError(

View File

@ -46,6 +46,7 @@ from peft import (
OFTConfig,
PeftModel,
RandLoraConfig,
ShiraConfig,
TaskType,
TrainableTokensConfig,
VBLoRAConfig,
@ -523,6 +524,24 @@ TEST_CASES = [
{"target_modules": ["conv2d"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 3},
),
########
# SHiRA #
########
("Vanilla MLP 1 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": "lin0", "init_weights": False}),
("Vanilla MLP 2 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": ["lin0"], "init_weights": False}),
("Vanilla MLP 3 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": ["lin1"], "init_weights": False}),
(
"Vanilla MLP 4 SHiRA",
"MLP",
ShiraConfig,
{"r": 1, "target_modules": ["lin0", "lin1"], "random_seed": 56, "init_weights": False},
),
(
"Vanilla MLP 5 SHiRA",
"MLP",
ShiraConfig,
{"r": 1, "target_modules": ["lin0"], "init_weights": False},
),
########
# VeRA #
########
("Vanilla MLP 1 VeRA", "MLP", VeraConfig, {"target_modules": "lin0"}),
@ -752,6 +771,20 @@ MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [
{"n_frequency": 10, "target_modules": ["lin0"]},
{"n_frequency": 10, "target_modules": ["lin1"]},
),
(
"SHiRA Same",
"shira",
ShiraConfig,
{"r": 1, "target_modules": ["lin0"], "init_weights": False},
{"r": 1, "target_modules": ["lin0"], "init_weights": False},
),
(
"SHiRA Different",
"shira",
ShiraConfig,
{"r": 1, "target_modules": ["lin0"], "init_weights": False},
{"r": 1, "target_modules": ["lin1"], "init_weights": False},
),
# Note: Currently, we cannot target lin0 and lin1 with different adapters when using VeRA. The reason is that the
# first adapter being created will result in a vera_A or vera_B shape that is too small for the next adapter
# (remember that VeRA shares these parameters across all layers), which results in an error.
@ -841,6 +874,7 @@ PREFIXES = {
FourierFTConfig: "fourierft_",
C3AConfig: "c3a_",
HRAConfig: "hra_",
ShiraConfig: "shira_",
VBLoRAConfig: "vblora_",
BoneConfig: "bone_",
TrainableTokensConfig: "trainable_tokens_",
@ -1663,6 +1697,12 @@ class TestPeftCustomModel(PeftCommonTester):
config_kwargs = config_kwargs.copy()
# override the default value and make PEFT operation a no-op
config_kwargs["init_weights"] = True
if issubclass(config_cls, (ShiraConfig,)):
# for SHiRA, setting this to default value of True will turn the PEFT operation into a no-op
# because SHiRA is always initialized to zeros. Configs declared in the test file had set init_weights
# to False (to make sure all other tests have a randn SHiRA initialization). Setting it back to True here
# as required by this test.
config_kwargs["init_weights"] = True
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
@ -2100,7 +2140,9 @@ class TestPeftCustomModel(PeftCommonTester):
assert "default" in model.base_model.classifier.modules_to_save
assert "other" in model.base_model.classifier.modules_to_save
@pytest.mark.parametrize("config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig])
@pytest.mark.parametrize(
"config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig, ShiraConfig]
)
def test_multiple_adapters_mixed_modules_to_save(self, config_cls):
# See issue 1574
# Check that we can have a model where one adapter has modules_to_save and the other doesn't. It should be
@ -2110,6 +2152,8 @@ class TestPeftCustomModel(PeftCommonTester):
if config_cls == BoneConfig:
config_cls = partial(config_cls, r=2)
if config_cls == ShiraConfig:
config_cls = partial(config_cls, r=1)
config0 = config_cls(target_modules=["lin0"], modules_to_save=["lin1"])
config1 = config_cls(target_modules=["lin0"])
@ -2128,7 +2172,9 @@ class TestPeftCustomModel(PeftCommonTester):
model.set_adapter("other")
model(**inputs)
@pytest.mark.parametrize("config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig])
@pytest.mark.parametrize(
"config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig, ShiraConfig]
)
def test_multiple_adapters_mixed_modules_to_save_order_switched(self, config_cls):
# See issue 1574
# Same test as test_multiple_adapters_mixed_modules_to_save, but this time the 2nd adapter has modules_to_save.
@ -2137,6 +2183,8 @@ class TestPeftCustomModel(PeftCommonTester):
if config_cls == BoneConfig:
config_cls = partial(config_cls, r=2)
if config_cls == ShiraConfig:
config_cls = partial(config_cls, r=1)
config0 = config_cls(target_modules=["lin0"])
config1 = config_cls(target_modules=["lin0"], modules_to_save=["lin1"])

View File

@ -40,6 +40,7 @@ from peft import (
PromptEncoderConfig,
PromptTuningConfig,
PromptTuningInit,
ShiraConfig,
VBLoRAConfig,
VeraConfig,
get_peft_model,
@ -180,6 +181,15 @@ ALL_CONFIGS = [
"num_virtual_tokens": 10,
},
),
(
ShiraConfig,
{
"r": 1,
"task_type": "CAUSAL_LM",
"target_modules": None,
"init_weights": False,
},
),
(
VBLoRAConfig,
{
@ -215,8 +225,15 @@ ALL_CONFIGS = [
def _skip_if_not_conv1d_supported(model_id, config_cls):
if "GPT2LMHeadModel" in model_id and config_cls in [BOFTConfig, BoneConfig, HRAConfig, OFTConfig, C3AConfig]:
pytest.skip("Skipping BOFT/HRA/OFT/Bone/C3A for GPT2LMHeadModel")
if "GPT2LMHeadModel" in model_id and config_cls in [
BOFTConfig,
BoneConfig,
HRAConfig,
OFTConfig,
ShiraConfig,
C3AConfig,
]:
pytest.skip("Skipping BOFT/HRA/OFT/Bone/SHiRA/C3A for GPT2LMHeadModel")
def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls):
@ -457,6 +474,7 @@ class TestDecoderModels(PeftCommonTester):
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_unload_adapter(self, model_id, config_cls, config_kwargs):
_skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls)
_skip_if_not_conv1d_supported(model_id, config_cls)
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_unload_adapter(model_id, config_cls, config_kwargs.copy())

View File

@ -30,6 +30,7 @@ from peft import (
PrefixTuningConfig,
PromptEncoderConfig,
PromptTuningConfig,
ShiraConfig,
TaskType,
VBLoRAConfig,
VeraConfig,
@ -145,6 +146,15 @@ ALL_CONFIGS = [
"task_type": "SEQ_2_SEQ_LM",
},
),
(
ShiraConfig,
{
"r": 1,
"task_type": "SEQ_2_SEQ_LM",
"target_modules": None,
"init_weights": False,
},
),
(
VBLoRAConfig,
{

View File

@ -29,6 +29,7 @@ from peft import (
PromptEncoderConfig,
PromptLearningConfig,
PromptTuningConfig,
ShiraConfig,
VBLoRAConfig,
VeraConfig,
)
@ -145,6 +146,15 @@ ALL_CONFIGS = [
"num_virtual_tokens": 10,
},
),
(
ShiraConfig,
{
"r": 1,
"task_type": "FEATURE_EXTRACTION",
"target_modules": None,
"init_weights": False,
},
),
(
VBLoRAConfig,
{

View File

@ -29,6 +29,7 @@ from peft import (
PromptEncoderConfig,
PromptTuningConfig,
PromptTuningInit,
ShiraConfig,
VBLoRAConfig,
VeraConfig,
)
@ -143,6 +144,15 @@ ALL_CONFIGS = [
"num_virtual_tokens": 10,
},
),
(
ShiraConfig,
{
"r": 1,
"task_type": "SEQ_CLS",
"target_modules": None,
"init_weights": False,
},
),
(
VBLoRAConfig,
{

278
tests/test_shira.py Normal file
View File

@ -0,0 +1,278 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This test file is for tests specific to SHiRA.
import os
import pytest
import torch
from accelerate.utils.imports import is_bf16_available
from torch import nn
from peft import PeftModel, ShiraConfig, get_peft_model
def custom_random_mask_function_with_custom_kwargs(custom_arg):
def mask_fn(base_layer, r):
"""
This mask function is similar to the random_mask provided in src/peft/tuners/shira/mask_functions.py except the
seed is derived from custom_kwargs. Please use this as an example to create your own custom sparse masks that
may use custom_kwargs. Remember, for a pretrained weight with shape m, n, mask_fn must return only one mask
(shape: m, n) which must be binary 0 or 1 with num_shira_parameters = r(m+n) for linear layers. Device and
dtype of mask must be same as base layer's weight's device and dtype.
"""
new_seed = custom_arg
shape = base_layer.weight.shape
num_shira_weights = r * (shape[0] + shape[1])
random_generator = torch.Generator()
random_generator.manual_seed(new_seed)
idx = (torch.randperm(base_layer.weight.numel(), generator=random_generator)[:num_shira_weights]).to(
base_layer.weight.device
)
val = torch.ones_like(idx.type(base_layer.weight.dtype))
mask = torch.zeros_like(base_layer.weight.view(1, -1))
mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(shape)
return mask
return mask_fn
class MLP(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.relu = nn.ReLU()
self.lin0 = nn.Linear(10, 20, bias=bias)
self.lin1 = nn.Linear(20, 40, bias=bias) # lin1 and lin2 have same shape
self.lin2 = nn.Linear(40, 30, bias=bias)
self.lin3 = nn.Linear(30, 10, bias=bias)
self.sm = nn.LogSoftmax(dim=-1)
def forward(self, X):
X = self.lin0(X)
X = self.relu(X)
X = self.lin1(X)
X = self.relu(X)
X = self.lin2(X)
X = self.relu(X)
X = self.lin3(X)
X = self.sm(X)
return X
class TestShira:
@pytest.fixture
def mlp(self):
torch.manual_seed(0)
model = MLP()
return model
def test_mlp_single_adapter_shapes(self, mlp):
# torch.manual_seed(0)
r = 2
config = ShiraConfig(r=r, target_modules=["lin1", "lin2"])
# creates a default SHiRA adapter
peft_model = get_peft_model(mlp, config)
shira_weight1_size = peft_model.base_model.model.lin1.shira_weight["default"].shape[0]
shira_weight2_size = peft_model.base_model.model.lin2.shira_weight["default"].shape[0]
shira_indices1_size = peft_model.base_model.model.lin1.shira_indices["default"].shape[1]
shira_indices2_size = peft_model.base_model.model.lin2.shira_indices["default"].shape[1]
base_weight1_size = peft_model.base_model.model.lin1.base_layer.weight.shape
base_weight2_size = peft_model.base_model.model.lin2.base_layer.weight.shape
delta_weight1_shape = peft_model.base_model.model.lin1.get_delta_weight("default").shape
delta_weight2_shape = peft_model.base_model.model.lin2.get_delta_weight("default").shape
assert shira_weight1_size == r * (base_weight1_size[0] + base_weight1_size[1])
assert shira_weight2_size == r * (base_weight2_size[0] + base_weight2_size[1])
assert shira_weight1_size == shira_indices1_size
assert shira_weight2_size == shira_indices2_size
assert delta_weight1_shape == base_weight1_size
assert delta_weight2_shape == base_weight2_size
return peft_model
def test_multiple_adapters_save_load(self, mlp, tmp_path):
# check saving and loading works with multiple adapters
# note, the random seeds in the below two configs are not the default values.
# so it will lead to different random sparse masks between saving and loading.
# our goal is to make sure that loaded indices are exactly the same as the saved indices regardless of what initial random mask gets generated.
# we will also make sure that parameters are saved and loaded correctly, and the output remains the same.
config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=56)
# creates a default SHiRA adapter
peft_model = get_peft_model(mlp, config, adapter_name="first")
config2 = ShiraConfig(r=3, target_modules=["lin1", "lin2", "lin3"], random_seed=67)
peft_model.add_adapter("second", config2)
assert torch.all(peft_model.base_model.model.lin1.shira_weight["first"] == 0)
assert torch.all(peft_model.base_model.model.lin2.shira_weight["first"] == 0)
assert torch.all(peft_model.base_model.model.lin1.shira_weight["second"] == 0)
assert torch.all(peft_model.base_model.model.lin2.shira_weight["second"] == 0)
assert torch.all(peft_model.base_model.model.lin3.shira_weight["second"] == 0)
shira_assign_val1_f = torch.randn_like(peft_model.base_model.model.lin1.shira_weight["first"])
peft_model.base_model.model.lin1.shira_weight["first"] = shira_assign_val1_f
shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
shira_assign_val2_f = torch.randn_like(peft_model.base_model.model.lin2.shira_weight["first"])
peft_model.base_model.model.lin2.shira_weight["first"] = shira_assign_val2_f
shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]
shira_assign_val1_s = torch.randn_like(peft_model.base_model.model.lin1.shira_weight["second"])
peft_model.base_model.model.lin1.shira_weight["second"] = shira_assign_val1_s
shira_indices1_s = peft_model.base_model.model.lin1.shira_indices["second"]
shira_assign_val2_s = torch.randn_like(peft_model.base_model.model.lin2.shira_weight["second"])
peft_model.base_model.model.lin2.shira_weight["second"] = shira_assign_val2_s
shira_indices2_s = peft_model.base_model.model.lin2.shira_indices["second"]
shira_assign_val3_s = torch.randn_like(peft_model.base_model.model.lin3.shira_weight["second"])
peft_model.base_model.model.lin3.shira_weight["second"] = shira_assign_val3_s
shira_indices3_s = peft_model.base_model.model.lin3.shira_indices["second"]
input = torch.randn(5, 10)
peft_model.set_adapter("first")
output_first = peft_model(input)
peft_model.set_adapter("second")
output_second = peft_model(input)
# sanity check
assert not torch.allclose(output_first, output_second, atol=1e-3, rtol=1e-3)
save_path = os.path.join(tmp_path, "shira")
peft_model.save_pretrained(save_path)
assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
assert os.path.exists(os.path.join(save_path, "second", "adapter_config.json"))
del peft_model
torch.manual_seed(0)
mlp = MLP()
peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")
peft_model.load_adapter(os.path.join(save_path, "second"), "second")
peft_model.set_adapter("first")
output_first_loaded = peft_model(input)
peft_model.set_adapter("second")
output_second_loaded = peft_model(input)
assert torch.allclose(output_first, output_first_loaded)
assert torch.allclose(output_second, output_second_loaded)
assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])
assert torch.all(shira_assign_val1_s == peft_model.base_model.model.lin1.shira_weight["second"])
assert torch.all(shira_assign_val2_s == peft_model.base_model.model.lin2.shira_weight["second"])
assert torch.all(shira_assign_val3_s == peft_model.base_model.model.lin3.shira_weight["second"])
assert torch.all(shira_indices1_s == peft_model.base_model.model.lin1.shira_indices["second"])
assert torch.all(shira_indices2_s == peft_model.base_model.model.lin2.shira_indices["second"])
assert torch.all(shira_indices3_s == peft_model.base_model.model.lin3.shira_indices["second"])
return peft_model
def test_save_load_custom_mask_function(self, mlp, tmp_path):
# we want to see if saving and loading works when a custom mask is involved
config = ShiraConfig(r=2, mask_type="custom", target_modules=["lin1", "lin2"], init_weights=False)
custom_arg = 120
custom_mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg)
config.mask_fn = custom_mask_fn
# create a custom mask SHiRA adapter
peft_model = get_peft_model(mlp, config, adapter_name="first")
shira_assign_val1_f = peft_model.base_model.model.lin1.shira_weight["first"]
shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
shira_assign_val2_f = peft_model.base_model.model.lin2.shira_weight["first"]
shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]
input = torch.randn(5, 10)
peft_model.set_adapter("first")
output_first = peft_model(input)
save_path = os.path.join(tmp_path, "shira")
peft_model.save_pretrained(save_path)
assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
del peft_model
torch.manual_seed(0)
mlp = MLP()
peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")
peft_model.set_adapter("first")
output_first_loaded = peft_model(input)
assert torch.allclose(output_first, output_first_loaded)
assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])
return peft_model
def test_save_load_default_random_mask_with_seed_function(self, mlp, tmp_path):
# we want to see if saving and loading works when a random mask is involved but the random seed is fixed.
config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=567, init_weights=False)
# create a custom mask SHiRA adapter
peft_model = get_peft_model(mlp, config, adapter_name="first")
shira_assign_val1_f = peft_model.base_model.model.lin1.shira_weight["first"]
shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
shira_assign_val2_f = peft_model.base_model.model.lin2.shira_weight["first"]
shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]
input = torch.randn(5, 10)
peft_model.set_adapter("first")
output_first = peft_model(input)
save_path = os.path.join(tmp_path, "shira")
peft_model.save_pretrained(save_path)
assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
del peft_model
torch.manual_seed(0)
mlp = MLP()
peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")
peft_model.set_adapter("first")
output_first_loaded = peft_model(input)
assert torch.allclose(output_first, output_first_loaded)
assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])
return peft_model
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
def test_shira_dtypes(self, dtype):
if dtype == torch.bfloat16:
# skip if bf16 is not supported on hardware, see #1872
if not is_bf16_available():
pytest.skip("bfloat16 not supported on this system, skipping the test")
model = MLP().to(dtype)
config = ShiraConfig(r=2, target_modules=["lin1", "lin2"])
peft_model = get_peft_model(model, config)
inputs = torch.randn(5, 10).to(dtype)
output = peft_model(inputs) # should not raise
assert output.dtype == dtype

View File

@ -1608,6 +1608,7 @@ class PeftCommonTester:
"HRA",
"VBLORA",
"RANDLORA",
"SHIRA",
"BONE",
"C3A",
):

View File

@ -26,6 +26,7 @@ from peft import (
IA3Config,
LoraConfig,
PromptLearningConfig,
ShiraConfig,
VBLoRAConfig,
)
from peft.import_utils import (
@ -228,6 +229,8 @@ def set_init_weights_false(config_cls, kwargs):
if issubclass(config_cls, PromptLearningConfig):
return kwargs
if issubclass(config_cls, ShiraConfig):
return kwargs
if config_cls == VBLoRAConfig:
return kwargs