FEAT Add SHiRA Adapters (#2584)
Implements: Sparse High Rank Adapters. Paper: https://arxiv.org/abs/2406.13175
@@ -126,6 +126,8 @@
       title: Trainable Tokens
     - local: package_reference/randlora
       title: RandLora
+    - local: package_reference/shira
+      title: SHiRA
     - local: package_reference/c3a
       title: C3A
35 docs/source/package_reference/shira.md Normal file
@@ -0,0 +1,35 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Sparse High Rank Adapters

Sparse High Rank Adapters, or [SHiRA](https://arxiv.org/abs/2406.13175), is an alternative type of adapter that has been found to have significant advantages over low rank adapters. Specifically, SHiRA achieves better accuracy than LoRA on a variety of vision and language tasks. It also offers simpler and higher quality multi-adapter fusion by significantly reducing concept loss, a common problem with low rank adapters. SHiRA directly finetunes a small number of the base model's parameters to adapt the model to any downstream task.

SHiRA currently has the following constraint:

- Only `nn.Linear` layers are supported.

The abstract from the paper is:

> Low Rank Adaptation (LoRA) has gained massive attention in the recent generative AI research. One of the main advantages of LoRA is its ability to be fused with pretrained models, adding no overhead during inference. However, from a mobile deployment standpoint, we can either avoid inference overhead in the fused mode but lose the ability to switch adapters rapidly, or suffer significant (up to 30% higher) inference latency while enabling rapid switching in the unfused mode. LoRA also exhibits concept-loss when multiple adapters are used concurrently. In this paper, we propose Sparse High Rank Adapters (SHiRA), a new paradigm which incurs no inference overhead, enables rapid switching, and significantly reduces concept-loss. Specifically, SHiRA can be trained by directly tuning only 1-2% of the base model weights while leaving others unchanged. This results in a highly sparse adapter which can be switched directly in the fused mode. We further provide theoretical and empirical insights on how high sparsity in SHiRA can aid multi-adapter fusion by reducing concept loss. Our extensive experiments on LVMs and LLMs demonstrate that finetuning only a small fraction of the parameters in the base model significantly outperforms LoRA while enabling both rapid switching and multi-adapter fusion. Finally, we provide a latency- and memory-efficient SHiRA implementation based on Parameter-Efficient Finetuning (PEFT) Library which trains at nearly the same speed as LoRA while consuming up to 16% lower peak GPU memory, thus making SHiRA easy to adopt for practical use cases. To demonstrate rapid switching benefits during inference, we show that loading SHiRA on a base model can be 5x-16x faster than LoRA fusion on a CPU.
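A minimal usage sketch (the base model name is only illustrative; see the SHiRA example README for a full training loop):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import ShiraConfig, get_peft_model

# Any causal LM whose attention projections are nn.Linear layers works; the model name is illustrative.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)

# r controls the adapter budget: each adapted m x n linear layer gets r * (m + n)
# trainable sparse parameters, the same count as a rank-r LoRA adapter.
config = ShiraConfig(r=32)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```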
## ShiraConfig

[[autodoc]] tuners.shira.config.ShiraConfig

## ShiraModel

[[autodoc]] tuners.shira.model.ShiraModel
73 examples/shira_finetuning/README.md Normal file
@@ -0,0 +1,73 @@
# Sparse High Rank Adapters

## Introduction

Sparse High Rank Adapters, or [SHiRA](https://arxiv.org/abs/2406.13175), is an alternative type of adapter that has been found to have significant advantages over low rank adapters. Specifically, SHiRA achieves better accuracy than LoRA on a variety of vision and language tasks. It also offers simpler and higher quality multi-adapter fusion by significantly reducing concept loss, a common problem with low rank adapters. SHiRA directly finetunes a small number of the base model's parameters to adapt the model to any downstream task.

## Quick start
```python
import torch
from peft import ShiraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
dataset = load_dataset("imdb", split="train[:1%]")
shira_config = ShiraConfig(
    r=32,
)
peft_model = get_peft_model(model, shira_config)
training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("shira-opt-350m")
```

For more options and more detailed example code, refer to the SHiRA finetuning script. Run it with:
```bash
python3 examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m
```

If you want to run DDP with [accelerate](https://huggingface.co/docs/accelerate/en/index), run `accelerate config` to set up your DDP configuration, then run:
```bash
accelerate launch examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m
```
Add `--device_map cpu` if you want to finetune on CPU.

If you want to train SHiRA with a custom sparse mask function that requires custom keyword arguments, see the `custom_random_mask_function_with_custom_kwargs` function defined in the `shira_finetuning.py` script (a similar standalone sketch is shown below). Enable it with the `--use_custom_random_mask_function_with_custom_kwargs` argument; without it, SHiRA defaults to a random sparse mask. Run the code as follows:
```bash
python3 examples/shira_finetuning/shira_finetuning.py --base_model facebook/opt-350m --use_custom_random_mask_function_with_custom_kwargs
```
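
For reference, here is a minimal sketch of such a custom mask factory (adapted from the example script; the seed value and the factory name are only illustrative). The returned mask must have the same shape, dtype, and device as the base layer's weight and contain exactly `r * (m + n)` ones:

```python
import torch

from peft import ShiraConfig


def make_seeded_random_mask_fn(seed: int):
    # Returns a mask_fn(base_layer, r) suitable for assigning to `config.mask_fn`.
    def mask_fn(base_layer, r):
        weight = base_layer.weight
        m, n = weight.shape
        num_shira_weights = r * (m + n)  # same parameter budget as a rank-r LoRA

        generator = torch.Generator().manual_seed(seed)
        idx = torch.randperm(weight.numel(), generator=generator)[:num_shira_weights].to(weight.device)

        # Binary mask with ones at the selected flat indices, reshaped back to the weight shape.
        mask = torch.zeros_like(weight.view(1, -1))
        mask.scatter_(1, idx.unsqueeze(0), 1.0)
        return mask.view(m, n)

    return mask_fn


# Any mask_type other than "random" tells SHiRA to expect a user-supplied config.mask_fn.
config = ShiraConfig(r=32, mask_type="custom")
config.mask_fn = make_seeded_random_mask_fn(120)
# peft_model = get_peft_model(model, config)
```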

## Use the model
You can load and use the model like any other 🤗 PEFT model:
```python
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
shira_model = PeftModel.from_pretrained(model, "shira-opt-350m")
```
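
Because SHiRA adapters are sparse updates applied directly to the base weights, they can also be merged into the base model for inference without any architectural change. A minimal sketch, continuing from the snippet above:

```python
# Fold the sparse SHiRA delta into the base weights and drop the adapter wrappers.
merged_model = shira_model.merge_and_unload()
merged_model.save_pretrained("shira-opt-350m-merged")
```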

## Citation
```
@inproceedings{NEURIPS2024_18c0102c,
    author = {Bhardwaj, Kartikeya and Pandey, Nilesh Prasad and Priyadarshi, Sweta and Ganapathy, Viswanath and Kadambi, Shreya and Esteves, Rafael and Borse, Shubhankar and Whatmough, Paul and Garrepalli, Risheek and Van Baalen, Mart and Teague, Harris and Nagel, Markus},
    booktitle = {Advances in Neural Information Processing Systems},
    editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang},
    pages = {13685--13715},
    publisher = {Curran Associates, Inc.},
    title = {Sparse High Rank Adapters},
    url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/18c0102cb7f1a02c14f0929089b2e576-Paper-Conference.pdf},
    volume = {37},
    year = {2024}
}
```
217 examples/shira_finetuning/shira_finetuning.py Normal file
@@ -0,0 +1,217 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from typing import Optional

import torch
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

from peft import (
    PeftModel,
    ShiraConfig,
    get_peft_model,
)


def train(
    base_model: str = "path/to/model",
    data_path: str = "yahma/alpaca-cleaned",
    output_dir: str = "shira",
    batch_size: int = 16,
    num_epochs: int = 1,
    learning_rate: float = 3e-4,
    cutoff_len: int = 256,
    val_set_size: int = 16,
    eval_step: int = 100,
    save_step: int = 100,
    device_map: str = "auto",
    shira_r: int = 32,
    shira_target_modules: list[str] = None,
    torch_dtype: str = "float16",
    seed: Optional[int] = None,
    use_custom_random_mask_function_with_custom_kwargs: Optional[bool] = False,
):
    # Set device_map to the right place when enabling DDP.
    world_size = int(os.environ.get("WORLD_SIZE", 0)) or int(os.environ.get("PMI_SIZE", 0))
    if world_size > 1 and device_map != "cpu":
        from accelerate import Accelerator

        device_map = {"": Accelerator().process_index}
    # Set seed
    if seed is not None:
        set_seed(seed)
    model_kwargs = {"torch_dtype": getattr(torch, torch_dtype), "device_map": device_map}
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    # For some tokenizers with no pad token, like llama
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize(prompt, add_eos_token=True):
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=cutoff_len,
            padding=False,
            return_tensors=None,
        )
        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < cutoff_len
            and add_eos_token
        ):
            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()

        return result

    def generate_and_tokenize_prompt(example):
        full_prompt = generate_prompt(example)
        tokenized_full_prompt = tokenize(full_prompt)
        return tokenized_full_prompt

    def custom_random_mask_function_with_custom_kwargs(custom_arg):
        def mask_fn(base_layer, r):
            """
            This mask function is similar to the random_mask provided in src/peft/tuners/shira/mask_functions.py
            except the seed is derived from custom_kwargs. Please use this as an example to create your own custom
            sparse masks that may use custom_kwargs. Remember, for a pretrained weight with shape m, n, mask_fn must
            return only one mask (shape: m, n) which must be binary 0 or 1 with num_shira_parameters = r(m+n) for
            linear layers. Device and dtype of mask must be same as base layer's weight's device and dtype.
            """
            new_seed = custom_arg
            shape = base_layer.weight.shape
            num_shira_weights = r * (shape[0] + shape[1])
            random_generator = torch.Generator()
            random_generator.manual_seed(new_seed)

            idx = (torch.randperm(base_layer.weight.numel(), generator=random_generator)[:num_shira_weights]).to(
                base_layer.weight.device
            )
            val = torch.ones_like(idx.type(base_layer.weight.dtype))
            mask = torch.zeros_like(base_layer.weight.view(1, -1))
            mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(shape)

            return mask

        return mask_fn

    mask_type = "random" if not use_custom_random_mask_function_with_custom_kwargs else "custom"
    config = ShiraConfig(
        r=shira_r,
        mask_type=mask_type,
        target_modules=shira_target_modules,
        task_type="CAUSAL_LM",
    )
    if use_custom_random_mask_function_with_custom_kwargs:
        custom_arg = 120
        custom_mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg)
        config.mask_fn = custom_mask_fn

    model = get_peft_model(model, config)

    data = load_dataset(data_path)

    train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
    train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
    val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=batch_size,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            logging_steps=100,
            optim="adamw_torch",
            eval_strategy="steps",
            save_strategy="steps",
            eval_steps=eval_step,
            save_steps=save_step,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True,
            ddp_find_unused_parameters=False if world_size > 1 else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )
    trainer.train()
    model.save_pretrained(output_dir)

    # Delete the model and load it again from the checkpoint.
    del model
    model = AutoModelForCausalLM.from_pretrained(base_model, **model_kwargs)
    model = PeftModel.from_pretrained(model, output_dir)


def generate_prompt(example):
    return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{example["instruction"]}

### Response:
{example["output"]}"""


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, default="path/to/model")
    parser.add_argument("--data_path", type=str, default="yahma/alpaca-cleaned")
    parser.add_argument("--output_dir", type=str, default="shira")
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=3e-4)
    parser.add_argument("--cutoff_len", type=int, default=256)
    parser.add_argument("--val_set_size", type=int, default=16)
    parser.add_argument("--eval_step", type=int, default=100)
    parser.add_argument("--save_step", type=int, default=100)
    parser.add_argument("--device_map", type=str, default="auto")
    parser.add_argument("--shira_r", type=int, default=32)
    parser.add_argument("--shira_target_modules", type=str, default=None)
    parser.add_argument("--torch_dtype", type=str, default="float16")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--use_custom_random_mask_function_with_custom_kwargs", action="store_true")

    args = parser.parse_args()

    train(
        base_model=args.base_model,
        data_path=args.data_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        cutoff_len=args.cutoff_len,
        val_set_size=args.val_set_size,
        eval_step=args.eval_step,
        save_step=args.save_step,
        device_map=args.device_map,
        shira_r=args.shira_r,
        shira_target_modules=args.shira_target_modules,
        torch_dtype=args.torch_dtype,
        seed=args.seed,
        use_custom_random_mask_function_with_custom_kwargs=args.use_custom_random_mask_function_with_custom_kwargs,
    )
@@ -0,0 +1,15 @@
{
  "auto_mapping": null,
  "base_model_name_or_path": null,
  "fan_in_fan_out": false,
  "inference_mode": false,
  "init_weights": true,
  "mask_type": "random",
  "modules_to_save": null,
  "peft_type": "SHIRA",
  "r": 32,
  "random_seed": 42,
  "revision": null,
  "target_modules": null,
  "task_type": null
}
@@ -0,0 +1,6 @@
{
  "optimizer_kwargs": {
    "lr": 3e-4
  }
}
@@ -91,6 +91,8 @@ from .tuners import (
     PromptTuningInit,
     RandLoraConfig,
     RandLoraModel,
+    ShiraConfig,
+    ShiraModel,
     TrainableTokensConfig,
     TrainableTokensModel,
     VBLoRAConfig,
@@ -186,6 +188,8 @@ __all__ = [
     "PromptTuningInit",
     "RandLoraConfig",
     "RandLoraModel",
+    "ShiraConfig",
+    "ShiraModel",
     "TaskType",
     "TrainableTokensConfig",
     "TrainableTokensModel",
@@ -41,6 +41,7 @@ from .poly import PolyConfig, PolyModel
 from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
 from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
 from .randlora import RandLoraConfig, RandLoraModel
+from .shira import ShiraConfig, ShiraModel
 from .trainable_tokens import TrainableTokensConfig, TrainableTokensModel
 from .vblora import VBLoRAConfig, VBLoRAModel
 from .vera import VeraConfig, VeraModel
@@ -95,6 +96,8 @@ __all__ = [
     "PromptTuningInit",
     "RandLoraConfig",
     "RandLoraModel",
+    "ShiraConfig",
+    "ShiraModel",
     "TrainableTokensConfig",
     "TrainableTokensModel",
     "VBLoRAConfig",
@@ -19,7 +19,7 @@ from typing import Any, Optional, Union
 from torch import nn
 from tqdm import tqdm
 
-from peft.tuners import adalora, loha, lokr, lora, oft
+from peft.tuners import adalora, loha, lokr, lora, oft, shira
 from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
 from peft.utils import (
     TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
@@ -31,10 +31,25 @@ from peft.utils import (
 
 
 # Collection of constants used for all tuners
-COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.OFT)
-PREFIXES = [lora.LoraModel.prefix, lokr.LoKrModel.prefix, loha.LoHaModel.prefix, oft.OFTModel.prefix]
-Configs = Union[lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig, oft.OFTConfig]
-Layers = (lora.layer.LoraLayer, loha.layer.LoHaLayer, lokr.layer.LoKrLayer, adalora.layer.AdaLoraLayer, oft.OFTLayer)
+COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA, PeftType.OFT, PeftType.SHIRA)
+PREFIXES = [
+    lora.LoraModel.prefix,
+    lokr.LoKrModel.prefix,
+    loha.LoHaModel.prefix,
+    oft.OFTModel.prefix,
+    shira.ShiraModel.prefix,
+]
+Configs = Union[
+    lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig, oft.OFTConfig, shira.ShiraConfig
+]
+Layers = (
+    lora.layer.LoraLayer,
+    loha.layer.LoHaLayer,
+    lokr.layer.LoKrLayer,
+    adalora.layer.AdaLoraLayer,
+    oft.OFTLayer,
+    shira.ShiraLayer,
+)
 
 
 class MixedModel(BaseTuner):
@@ -96,6 +111,8 @@ class MixedModel(BaseTuner):
             lokr.LoKrModel._create_and_replace(self, config, *args, **kwargs)
         elif isinstance(config, oft.OFTConfig):
             oft.OFTModel._create_and_replace(self, config, *args, **kwargs)
+        elif isinstance(config, shira.ShiraConfig):
+            shira.ShiraModel._create_and_replace(self, config, *args, **kwargs)
         else:
             raise ValueError(f"Unsupported config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.")
 
@@ -174,6 +191,8 @@ class MixedModel(BaseTuner):
             new_module = lokr.LoKrModel._create_new_module(config, adapter_name, target, **kwargs)
         elif isinstance(config, oft.OFTConfig):
             new_module = oft.OFTModel._create_new_module(config, adapter_name, target, **kwargs)
+        elif isinstance(config, shira.ShiraConfig):
+            new_module = shira.ShiraModel._create_new_module(config, adapter_name, target, **kwargs)
         else:
             raise ValueError(f"Unknown config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.")
         return new_module
27 src/peft/tuners/shira/__init__.py Normal file
@@ -0,0 +1,27 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from peft.utils import register_peft_method

from .config import ShiraConfig
from .layer import Linear, ShiraLayer
from .model import ShiraModel


__all__ = ["Linear", "ShiraConfig", "ShiraLayer", "ShiraModel"]


register_peft_method(
    name="shira", config_cls=ShiraConfig, model_cls=ShiraModel, prefix="shira_", is_mixed_compatible=True
)
129 src/peft/tuners/shira/config.py Normal file
@@ -0,0 +1,129 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from peft.config import PeftConfig
|
||||
from peft.utils import PeftType
|
||||
|
||||
from .mask_functions import random_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShiraConfig(PeftConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`ShiraModel`].
|
||||
|
||||
Args:
|
||||
r (`int`, *optional*, defaults to `32`):
|
||||
For a given target module, the number of SHiRA parameters is computed as r(m+n), where the original tensor
|
||||
dimensions are m x n. This means the number of SHiRA parameters is the same as that for a LoRA adapter.
|
||||
SHiRA is a high rank adapter. Setting this r parameter does not restrict the rank to this value.
|
||||
mask_type (`str`, defaults to `random`):
|
||||
Type of mask function. Defaults to a random sparse mask. An optional user-defined mask_fn to compute the
|
||||
mask value can also be supplied by instantiating `config = ShiraConfig(...)` and then setting
|
||||
`config.mask_fn = <your custom mask function>`. For a pretrained weight with shape m x n, the custom mask
|
||||
function must return only one mask (shape: m x n) which must be binary 0 or 1 with num_shira_parameters =
|
||||
r(m + n) for linear layers. Device and dtype of mask must be same as base layer's weight's device and
|
||||
dtype. Please see mask_functions.py for more details and to see the default random sparse mask
|
||||
implementation.
|
||||
random_seed (`int`, *optional*, defaults to `None`):
|
||||
random seed for the torch generator for random_mask.
|
||||
target_modules (`Union[List[str], str]`):
|
||||
List of module names or regex expression of the module names to replace with SHiRA. For example, ['q', 'v']
|
||||
or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. Only linear layers are supported.
|
||||
fan_in_fan_out (`bool`):
|
||||
Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses
|
||||
`Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
|
||||
init_weights (`bool`, defaults to `True`):
|
||||
Initialize SHiRA weight to have zero values. If set to False, SHiRA weights are initialized to randn values
|
||||
instead of zeros and this is used only for testing.
|
||||
modules_to_save (`List[str]`):
|
||||
List of modules apart from SHiRA layers to be set as trainable and saved in the final checkpoint.
|
||||
"""
|
||||
|
||||
r: int = field(
|
||||
default=32,
|
||||
metadata={
|
||||
"help": (
|
||||
"For a given target module, the number of SHiRA parameters is computed as r(m+n), where the original "
|
||||
"tensor dimensions are m x n. This means the number of SHiRA parameters is the same as that for a LoRA adapter. "
|
||||
"SHiRA is a high rank adapter. Setting this r parameter does not restrict the rank to this value."
|
||||
)
|
||||
},
|
||||
)
|
||||
mask_type: Literal["random"] = field(
|
||||
default="random",
|
||||
metadata={
|
||||
"help": (
|
||||
"Type of mask function. Defaults to a random sparse mask. "
|
||||
"An optional user-defined mask_fn to compute the mask value can also be supplied by instantiating `config = ShiraConfig(...)` and then setting "
|
||||
"`config.mask_fn = <your custom mask function>`. For a pretrained weight with shape m x n, the custom mask function must return only one mask (shape: m x n) "
|
||||
"which must be binary 0 or 1 with num_shira_parameters = r(m + n) for linear layers. Device and dtype of mask must be same as base layer's weight's device and dtype. "
|
||||
"Please see mask_functions.py for more details and to see the default random sparse mask implementation."
|
||||
)
|
||||
},
|
||||
)
|
||||
random_seed: Optional[int] = field(
|
||||
default=None, metadata={"help": "random seed for the torch generator for random_mask"}
|
||||
)
|
||||
target_modules: Optional[Union[list[str], str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"List of module names or regex expression of the module names to replace with SHiRA."
|
||||
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'. "
|
||||
"Only linear layers are supported."
|
||||
)
|
||||
},
|
||||
)
|
||||
fan_in_fan_out: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Set this to True if the layer to replace stores weight like (fan_in, fan_out)"},
|
||||
)
|
||||
init_weights: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Initialize SHiRA weight to have zero values. If set to False, SHiRA weights are initialized to randn values instead of zeros and this is used only for testing."
|
||||
},
|
||||
)
|
||||
modules_to_save: Optional[list[str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"List of modules apart from SHiRA layers to be set as trainable and saved in the final checkpoint. For"
|
||||
" example, in Sequence Classification or Token Classification tasks, the final layer"
|
||||
" `classifier/score` are randomly initialized and as such need to be trainable and saved."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.peft_type = PeftType.SHIRA
|
||||
self.target_modules = (
|
||||
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
|
||||
)
|
||||
if self.mask_type == "random":
|
||||
self.mask_fn = random_mask
|
||||
else:
|
||||
if not self.inference_mode:
|
||||
warnings.warn(
|
||||
f"Argument {self.mask_type=} is not recognized, please supply your own masking function by calling `config.mask_fn = my_mask_fn`."
|
||||
)
|
||||
self.mask_fn = None
|
215 src/peft/tuners/shira/layer.py Normal file
@@ -0,0 +1,215 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||
|
||||
|
||||
class ShiraLayer(BaseTunerLayer):
|
||||
# List all names of layers that may contain trainable adapter weights
|
||||
adapter_layer_names = ("shira_weight",)
|
||||
# All names of other adapter-related parameters
|
||||
other_param_names = ("r", "scaling", "shira_indices")
|
||||
|
||||
def __init__(self, base_layer: nn.Module, **kwargs):
|
||||
self.base_layer = base_layer
|
||||
self.r = {}
|
||||
self.scaling = {}
|
||||
self.shira_weight = nn.ParameterDict({})
|
||||
self.shira_indices = {}
|
||||
self.weight_shape = base_layer.weight.shape # Assumes SHiRA is on some layer with "weight" parameter
|
||||
|
||||
# Mark the weight as unmerged
|
||||
self._disable_adapters = False
|
||||
self.merged_adapters = []
|
||||
|
||||
base_layer = self.get_base_layer()
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||
else:
|
||||
raise NotImplementedError("Only nn.Linear layers supported currently")
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self.kwargs = kwargs
|
||||
|
||||
def update_layer(
|
||||
self,
|
||||
adapter_name,
|
||||
mask,
|
||||
r,
|
||||
init_weights: bool = True,
|
||||
):
|
||||
if r <= 0:
|
||||
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
||||
self.r[adapter_name] = r
|
||||
self.scaling[adapter_name] = (
|
||||
1.0 # Default scale during training. Can be set to any (non-negative) value during inference.
|
||||
)
|
||||
# The number of shira weights in this layer is determined by r such that the total number of weights is the same as a LoRA Layer (for direct comparisons)
|
||||
num_shira_weight = r * (self.in_features + self.out_features)
|
||||
if num_shira_weight > self.in_features * self.out_features:
|
||||
raise ValueError(
|
||||
f"The set rank {r} results in more shira params than the total number of params in the base layer {self.in_features * self.out_features} and this is not allowed."
|
||||
)
|
||||
|
||||
# Actual trainable parameters
|
||||
# We have used a vector parameter with fixed indices that we use inside a torch.sparse_coo_tensor in get_delta_weight function.
|
||||
# Directly using a torch.sparse_coo_tensor as a parameter could have been possible but we ran into some issues similar to:
|
||||
# https://github.com/pytorch/pytorch/issues/79542.
|
||||
shira_init_weight = torch.zeros(num_shira_weight) if init_weights else torch.randn(num_shira_weight)
|
||||
self.shira_weight[adapter_name] = nn.Parameter(
|
||||
shira_init_weight.to(self.base_layer.weight.dtype).to(self.base_layer.weight.device),
|
||||
requires_grad=True,
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
# Compute the shira_indices from the mask. Make sure the mask is formed using r*(self.in_features + self.out_features) and not some other K.
|
||||
mask_indices = torch.where(mask == 1.0)
|
||||
self.shira_indices[adapter_name] = torch.cat(
|
||||
[mask_indices[0].unsqueeze(0), mask_indices[1].unsqueeze(0)], 0
|
||||
).to(torch.int)
|
||||
self.shira_indices[adapter_name] = self.shira_indices[adapter_name].to(self.base_layer.weight.device)
|
||||
|
||||
if self.shira_indices[adapter_name].shape[1] != self.shira_weight[adapter_name].shape[0]:
|
||||
raise ValueError(
|
||||
f"The SHiRA indices and weights are not the same dimensions for adapter {adapter_name} in layer {self.base_layer}"
|
||||
)
|
||||
|
||||
self._move_adapter_to_device_of_base_layer(adapter_name)
|
||||
self.set_adapter(self.active_adapters)
|
||||
|
||||
def reset_shira_parameters(self, adapter_name):
|
||||
nn.init.zeros_(self.shira_weight[adapter_name])
|
||||
|
||||
def set_scale(self, adapter, scale):
|
||||
if adapter not in self.scaling:
|
||||
# Ignore the case where the adapter is not in the layer
|
||||
return
|
||||
self.scaling[adapter] = scale
|
||||
|
||||
|
||||
class Linear(nn.Module, ShiraLayer):
|
||||
# SHiRA implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
mask,
|
||||
adapter_name: str,
|
||||
r: int = 0,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stored weight like (fan_in, fan_out)
|
||||
init_weights: bool = True,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
ShiraLayer.__init__(self, base_layer, **kwargs)
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
if self.base_layer is not self.get_base_layer():
|
||||
raise ValueError("SHiRA does not support nested base layers")
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(adapter_name, mask, r, init_weights=init_weights)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
"""
|
||||
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
# no adapter to merge
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self.shira_weight.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
if safe_merge:
|
||||
# Note that safe_merge will be slower than the normal merge
|
||||
# because of the copy operation.
|
||||
orig_weights = base_layer.weight.data.clone()
|
||||
|
||||
orig_weights += self.get_delta_weight(active_adapter)
|
||||
|
||||
if not torch.isfinite(orig_weights).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
base_layer.weight.data = orig_weights
|
||||
else:
|
||||
base_layer.weight.data += self.get_delta_weight(active_adapter)
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter in self.shira_weight.keys():
|
||||
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
|
||||
|
||||
def get_delta_weight(self, adapter) -> torch.Tensor:
|
||||
"""
|
||||
Compute the delta weight for the given adapter.
|
||||
|
||||
Args:
|
||||
adapter (str):
|
||||
The name of the adapter for which the delta weight should be computed.
|
||||
"""
|
||||
|
||||
# In multi-gpu environment, the indices are at the wrong gpu. This is needed to correct this.
|
||||
self.shira_indices[adapter] = self.shira_indices[adapter].to(self.shira_weight[adapter].device)
|
||||
return torch.sparse_coo_tensor(
|
||||
self.shira_indices[adapter], self.shira_weight[adapter] * self.scaling[adapter], self.weight_shape
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
new_weight = copy.deepcopy(self.base_layer.weight.data)
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.shira_weight.keys():
|
||||
continue
|
||||
new_weight += self.get_delta_weight(active_adapter)
|
||||
|
||||
result = F.linear(x, new_weight, bias=self.base_layer.bias)
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "shira." + rep
|
72 src/peft/tuners/shira/mask_functions.py Normal file
@@ -0,0 +1,72 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module is intended to store mask functions for use inside SHiRA construction. The mask functions are required to
have a specific signature as shown below.

Required positional arguments:
    base_layer - This is the linear layer where the shira adapter will be attached.
    r - This parameter is used to determine the number of parameters in the shira adapter in a way that is consistent
        with LoRA sizing. SHiRA is a high rank adapter. Setting this parameter does not restrict the adapter rank.
Keyword arguments can be provided as needed by the particular mask function implementation.

Return:
    mask - this is a torch.tensor of the same shape as base_layer.weight that contains 0s and 1s with the same
        dtype and device as base_layer.weight

If you would like to attach SHiRA adapters to a model using PEFT methods (such as get_peft_model()), using more
arguments than the provided positional arguments, you can create the mask function reference like the following:

```
def create_mask_function_reference(**my_kwargs):
    def mask_fn(base_layer, r):
        ... your implementation here that might use my_kwargs ...
        return mask
    return mask_fn
```
Then, you can create your peft model with a custom SHiRA mask as follows:
```
model = ...
my_kwargs = ...
mask_fn = create_mask_function_reference(**my_kwargs)
peft_config = ShiraConfig(r=4, mask_type='my_custom_mask')
peft_config.mask_fn = mask_fn
peft_model = get_peft_model(model, peft_config)
```

Complete training examples are provided in the examples/shira_finetuning/ directory.
"""

from typing import Optional

import torch
import torch.nn as nn


def random_mask(base_layer: nn.Module, r: int, random_seed: Optional[int] = None, **kwargs) -> torch.tensor:
    shape = base_layer.weight.shape
    num_shira_weights = r * (shape[0] + shape[1])
    random_generator = torch.Generator()
    if random_seed is not None:
        random_generator.manual_seed(random_seed)
    idx = (torch.randperm(base_layer.weight.numel(), generator=random_generator)[:num_shira_weights]).to(
        base_layer.weight.device
    )
    val = torch.ones_like(idx.type(base_layer.weight.dtype))
    mask = torch.zeros_like(base_layer.weight.view(1, -1))
    mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(shape)

    return mask
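
To illustrate the mask contract described in the module docstring above, here is a minimal sanity check of the default `random_mask` (layer sizes are only illustrative):

```python
import torch.nn as nn

from peft.tuners.shira.mask_functions import random_mask

layer = nn.Linear(128, 64)  # weight shape: (64, 128)
mask = random_mask(layer, r=4, random_seed=0)

assert mask.shape == layer.weight.shape
assert mask.dtype == layer.weight.dtype
# Exactly r * (m + n) entries are selected, matching a rank-4 LoRA parameter budget.
assert int(mask.sum().item()) == 4 * (64 + 128)
```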
340 src/peft/tuners/shira/model.py Normal file
@@ -0,0 +1,340 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from dataclasses import asdict
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
|
||||
from peft.utils import (
|
||||
TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING,
|
||||
ModulesToSaveWrapper,
|
||||
_get_submodules,
|
||||
)
|
||||
|
||||
from .config import ShiraConfig
|
||||
from .layer import Linear, ShiraLayer
|
||||
|
||||
|
||||
class ShiraModel(BaseTuner):
|
||||
"""
|
||||
Creates a Sparse High Rank Adapter (SHiRA) Model from a pretrained model.
|
||||
|
||||
Args:
|
||||
model ([`~transformers.PreTrainedModel`]): The model to be adapted.
|
||||
config ([`ShiraConfig`]): The configuration of the SHiRA model.
|
||||
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`: The SHiRA model.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoModelForCausalLM
|
||||
>>> from peft import ShiraConfig, get_peft_model
|
||||
|
||||
>>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
>>> config = ShiraConfig(r=32)
|
||||
>>> model = get_peft_model(base_model, config)
|
||||
```
|
||||
|
||||
**Attributes**:
|
||||
- **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
|
||||
- **peft_config** ([`ShiraConfig`]): The configuration of the SHiRA model.
|
||||
"""
|
||||
|
||||
prefix: str = "shira_"
|
||||
|
||||
def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
|
||||
super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def _check_new_adapter_config(self, config: ShiraConfig) -> None:
|
||||
"""
|
||||
A helper method to check the config when a new adapter is being added.
|
||||
|
||||
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
|
||||
|
||||
"""
|
||||
for existing_config in self.peft_config.values():
|
||||
if existing_config is config:
|
||||
# skip the current config
|
||||
continue
|
||||
|
||||
@staticmethod
|
||||
def _check_target_module_exists(shira_config, key):
|
||||
return check_target_module_exists(shira_config, key)
|
||||
|
||||
def _create_and_replace(
|
||||
self,
|
||||
shira_config,
|
||||
adapter_name,
|
||||
target,
|
||||
target_name,
|
||||
parent,
|
||||
current_key,
|
||||
**optional_kwargs,
|
||||
):
|
||||
if current_key is None:
|
||||
raise ValueError("Current Key shouldn't be `None`")
|
||||
|
||||
bias = hasattr(target, "bias") and target.bias is not None
|
||||
kwargs = {}
|
||||
kwargs["bias"] = bias
|
||||
if shira_config.mask_type == "random":
|
||||
kwargs["random_seed"] = shira_config.random_seed
|
||||
|
||||
for k, v in optional_kwargs.items():
|
||||
kwargs[k] = v
|
||||
|
||||
if isinstance(target, Linear):
|
||||
mask = (
|
||||
shira_config.mask_fn(target.base_layer, shira_config.r, **kwargs)
|
||||
if shira_config.mask_fn is not None
|
||||
else None
|
||||
)
|
||||
target.update_layer(
|
||||
adapter_name,
|
||||
mask,
|
||||
shira_config.r,
|
||||
init_weights=shira_config.init_weights,
|
||||
)
|
||||
else:
|
||||
new_module = self._create_new_module(shira_config, adapter_name, target, **kwargs)
|
||||
if adapter_name not in self.active_adapter:
|
||||
# adding an additional adapter: it is not automatically trainable
|
||||
new_module.requires_grad_(False)
|
||||
self._replace_module(parent, target_name, new_module, target)
|
||||
|
||||
@staticmethod
|
||||
def _replace_module(parent, child_name, new_module, child):
|
||||
setattr(parent, child_name, new_module)
|
||||
# It's not necessary to set requires_grad here, as that is handled by
|
||||
# _mark_only_adapters_as_trainable
|
||||
|
||||
# child layer wraps the original module, unpack it
|
||||
if hasattr(child, "base_layer"):
|
||||
child = child.base_layer
|
||||
|
||||
if not hasattr(new_module, "base_layer"):
|
||||
new_module.weight = child.weight
|
||||
if hasattr(child, "bias"):
|
||||
new_module.bias = child.bias
|
||||
|
||||
if getattr(child, "state", None) is not None:
|
||||
if hasattr(new_module, "base_layer"):
|
||||
new_module.base_layer.state = child.state
|
||||
else:
|
||||
new_module.state = child.state
|
||||
new_module.to(child.weight.device)
|
||||
|
||||
meta = torch.device("meta")
|
||||
# dispatch to correct device
|
||||
for name, module in new_module.named_modules():
|
||||
if "shira_" in name:
|
||||
if not any(p.device == meta for p in module.parameters()):
|
||||
module.to(child.weight.device)
|
||||
|
||||
def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
|
||||
for n, p in model.named_parameters():
|
||||
if self.prefix not in n:
|
||||
p.requires_grad = False
|
||||
|
||||
@staticmethod
|
||||
def _create_new_module(shira_config, adapter_name, target, **kwargs):
|
||||
fan_in_fan_out = shira_config.fan_in_fan_out
|
||||
|
||||
_ = kwargs.pop("bias", False)
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
if isinstance(target_base_layer, torch.nn.Linear):
|
||||
if fan_in_fan_out:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
fan_in_fan_out = shira_config.fan_in_fan_out = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Target module {target} is not supported. Currently, only the following modules are supported: "
|
||||
"`torch.nn.Linear`."
|
||||
)
|
||||
|
||||
mask = (
|
||||
shira_config.mask_fn(target_base_layer, shira_config.r, **kwargs)
|
||||
if shira_config.mask_fn is not None
|
||||
else None
|
||||
)
|
||||
|
||||
new_module = Linear(
|
||||
target,
|
||||
mask,
|
||||
adapter_name,
|
||||
shira_config.r,
|
||||
fan_in_fan_out,
|
||||
init_weights=shira_config.init_weights,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return new_module
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Forward missing attributes to the wrapped module."""
|
||||
try:
|
||||
return super().__getattr__(name) # defer to nn.Module's logic
|
||||
except AttributeError:
|
||||
if name == "model": # see #1892: prevent infinite recursion if class is not initialized
|
||||
raise
|
||||
return getattr(self.model, name)
|
||||
|
||||
def get_peft_config_as_dict(self, inference: bool = False):
|
||||
config_dict = {}
|
||||
for key, value in self.peft_config.items():
|
||||
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
|
||||
if inference:
|
||||
config["inference_mode"] = True
|
||||
config_dict[key] = config
|
||||
return config
|
||||
|
||||
def _set_adapter_layers(self, enabled=True):
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
|
||||
module.enable_adapters(enabled)
|
||||
|
||||
def enable_adapter_layers(self):
|
||||
self._set_adapter_layers(enabled=True)
|
||||
|
||||
def disable_adapter_layers(self):
|
||||
self._set_adapter_layers(enabled=False)
|
||||
|
||||
def set_adapter(self, adapter_name):
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, ShiraLayer):
|
||||
if module.merged:
|
||||
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
|
||||
module.unmerge()
|
||||
module.set_adapter(adapter_name)
|
||||
self.active_adapter = adapter_name
|
||||
|
||||
@staticmethod
|
||||
def _prepare_adapter_config(peft_config, model_config):
|
||||
if peft_config.target_modules is None:
|
||||
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING:
|
||||
raise ValueError("Please specify `target_modules` in `peft_config`")
|
||||
peft_config.target_modules = set(
|
||||
TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING[model_config["model_type"]]
|
||||
)
|
||||
return peft_config
|
||||
|
||||
def _unload_and_optionally_merge(
|
||||
self,
|
||||
merge=True,
|
||||
progressbar: bool = False,
|
||||
safe_merge: bool = False,
|
||||
adapter_names: Optional[list[str]] = None,
|
||||
):
|
||||
# we cannot use self.prefix as we want to include non-trainable shira parameters
|
||||
key_list = [key for key, _ in self.model.named_modules() if "shira" not in key]
|
||||
desc = "Unloading " + ("and merging " if merge else "") + "model"
|
||||
for key in tqdm(key_list, disable=not progressbar, desc=desc):
|
||||
try:
|
||||
parent, target, target_name = _get_submodules(self.model, key)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if hasattr(target, "base_layer"):
|
||||
if merge:
|
||||
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
|
||||
|
||||
self._replace_module(parent, target_name, target.get_base_layer(), target)
|
||||
elif isinstance(target, ModulesToSaveWrapper):
|
||||
# save any additional trainable modules part of `modules_to_save`
|
||||
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
|
||||
|
||||
return self.model
|
||||
|
||||
def delete_adapter(self, adapter_name: str):
|
||||
"""
|
||||
Deletes an existing adapter.
|
||||
|
||||
Args:
|
||||
adapter_name (str): Name of the adapter to be deleted.
|
||||
"""
|
||||
if adapter_name not in list(self.peft_config.keys()):
|
||||
raise ValueError(f"Adapter {adapter_name} does not exist")
|
||||
del self.peft_config[adapter_name]
|
||||
|
||||
# we cannot use self.prefix as we want to include non-trainable shira parameters
|
||||
key_list = [key for key, _ in self.model.named_modules() if "shira" not in key]
|
||||
new_adapter = None
|
||||
for key in key_list:
|
||||
_, target, _ = _get_submodules(self.model, key)
|
||||
if isinstance(target, ShiraLayer):
|
||||
target.delete_adapter(adapter_name)
|
||||
if new_adapter is None:
|
||||
new_adapter = target.active_adapter[:]
|
||||
|
||||
self.active_adapter = new_adapter or []
|
||||
|
||||
def merge_and_unload(
|
||||
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
|
||||
):
|
||||
r"""
|
||||
This method merges the Shira layers into the base model. This is needed if someone wants to use the base model
|
||||
as a standalone model.
|
||||
|
||||
Args:
|
||||
progressbar (`bool`):
|
||||
whether to show a progressbar indicating the unload and merge process
|
||||
safe_merge (`bool`):
|
||||
whether to activate the safe merging check to check if there is any potential Nan in the adapter
|
||||
weights
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
|
||||
to `None`.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
>>> from transformers import AutoModelForCausalLM
|
||||
>>> from peft import ShiraConfig, get_peft_model
|
||||
|
||||
>>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
|
||||
>>> config = ShiraConfig(r=32)
|
||||
>>> model = get_peft_model(base_model, config)
|
||||
>>> ## [Train the adapter] ##
|
||||
>>> merged_model = model.merge_and_unload()
|
||||
```
|
||||
"""
|
||||
return self._unload_and_optionally_merge(
|
||||
progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
def unload(self):
|
||||
"""
|
||||
Gets back the base model by removing all the Shira modules without merging. This gives back the original base
|
||||
model.
|
||||
"""
|
||||
return self._unload_and_optionally_merge(merge=False)
|
@ -27,6 +27,7 @@ from .other import (
|
||||
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
|
||||
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
|
||||
TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING,
|
||||
TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING,
|
||||
TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING,
|
||||
TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING,
|
||||
WEIGHTS_NAME,
|
||||
@ -69,6 +70,7 @@ __all__ = [
|
||||
"TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING",
|
||||
"TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
|
||||
"TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING",
|
||||
"TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING",
|
||||
"TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING",
|
||||
"TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING",
|
||||
"WEIGHTS_NAME",
|
||||
|
@@ -255,6 +255,43 @@ TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING = {
    "qwen3": ["q_proj", "v_proj"],
}

TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING = {
    "t5": ["q", "v"],
    "mt5": ["q", "v"],
    "bart": ["q_proj", "v_proj"],
    "gpt2": ["c_attn"],
    "bloom": ["query_key_value"],
    "blip-2": ["q", "v", "q_proj", "v_proj"],
    "opt": ["q_proj", "v_proj"],
    "gptj": ["q_proj", "v_proj"],
    "gpt_neox": ["query_key_value"],
    "gpt_neo": ["q_proj", "v_proj"],
    "bert": ["query", "value"],
    "roberta": ["query", "value"],
    "xlm-roberta": ["query", "value"],
    "electra": ["query", "value"],
    "deberta-v2": ["query_proj", "value_proj"],
    "deberta": ["in_proj"],
    "layoutlm": ["query", "value"],
    "llama": ["q_proj", "v_proj"],
    "chatglm": ["query_key_value"],
    "gpt_bigcode": ["c_attn"],
    "mpt": ["Wqkv"],
    "RefinedWebModel": ["query_key_value"],
    "RefinedWeb": ["query_key_value"],
    "falcon": ["query_key_value"],
    "btlm": ["c_proj", "c_attn"],
    "codegen": ["qkv_proj"],
    "mistral": ["q_proj", "v_proj"],
    "mixtral": ["q_proj", "v_proj"],
    "stablelm": ["q_proj", "v_proj"],
    "phi": ["q_proj", "v_proj"],
    "gemma": ["q_proj", "v_proj"],
    "gemma2": ["q_proj", "v_proj"],
    "gemma3_text": ["q_proj", "v_proj"],
    "qwen2": ["q_proj", "v_proj"],
}

TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING = {
    "t5": ["q", "v"],
    "mt5": ["q", "v"],
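This mapping supplies the default `target_modules` when a `ShiraConfig` leaves them unset: PEFT looks up the base model's `model_type` and targets the listed attention projections. A minimal sketch of that fallback, assuming the `facebook/opt-125m` checkpoint already used in the docstring example above (its `"opt"` entry maps to `q_proj`/`v_proj`); the assertion at the end is illustrative, not code from this change:

```py
from transformers import AutoModelForCausalLM

from peft import ShiraConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# target_modules is left unset, so the "opt" entry of the mapping above is used as the default.
peft_model = get_peft_model(base_model, ShiraConfig(r=32))

# The trainable SHiRA parameters should sit only on q_proj/v_proj layers.
shira_params = [name for name, _ in peft_model.named_parameters() if "shira" in name]
assert shira_params and all(("q_proj" in name) or ("v_proj" in name) for name in shira_params)
```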
@@ -50,6 +50,7 @@ from .constants import (
    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
    TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING,
    WEIGHTS_NAME,
@@ -79,6 +80,7 @@ __all__ = [
    "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING",
    "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
    "TRANSFORMERS_MODELS_TO_RANDLORA_TARGET_MODULES_MAPPING",
    "TRANSFORMERS_MODELS_TO_SHIRA_TARGET_MODULES_MAPPING",
    "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING",
    "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING",
    "WEIGHTS_NAME",
@@ -41,6 +41,7 @@ class PeftType(str, enum.Enum):
    - HRA
    - BONE
    - RANDLORA
    - SHIRA
    - C3A
    """

@@ -67,6 +68,7 @@ class PeftType(str, enum.Enum):
    BONE = "BONE"
    RANDLORA = "RANDLORA"
    TRAINABLE_TOKENS = "TRAINABLE_TOKENS"
    SHIRA = "SHIRA"
    C3A = "C3A"
@@ -14,6 +14,7 @@
from __future__ import annotations

import os
import platform
import re
import warnings
from typing import Optional
@@ -152,6 +153,25 @@ def get_peft_model_state_dict(
        prompt_embeddings = model.get_prompt_embedding_to_save(adapter_name)
        to_return["prompt_embeddings"] = prompt_embeddings

    elif config.peft_type == PeftType.SHIRA:
        shira_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        to_return = {k: state_dict[k] for k in state_dict if shira_prefix in k}
        if platform.system() == "Windows":
            warnings.warn(
                "Windows has issues saving integers into safetensors. Hence, we convert shira_indices to float32 "
                "before saving on Windows OS. The shira_indices will always be converted to integers when loading."
            )
        for name, module in model.named_modules():
            if hasattr(module, "shira_indices"):
                for k, v in module.shira_indices.items():
                    # Windows has some issues with saving integers into safetensors. Tests fail with some kind of
                    # PermissionError. This results in failed tests, so we are converting indices to float32 before
                    # saving and then converting them back to int when loading. This is happening only for Windows,
                    # not for Linux and Mac-OS.
                    to_return[f"{name}.shira_indices.{k}"] = (
                        v.to(torch.float32) if platform.system() == "Windows" else v
                    )

    elif config.peft_type == PeftType.VERA:
        vera_prefix = PEFT_TYPE_TO_PREFIX_MAPPING[config.peft_type]
        to_return = {k: state_dict[k] for k in state_dict if vera_prefix in k}
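The Windows-specific branch above can be summarized with a small round-trip sketch (an illustration, not code from this change): integer `shira_indices` are stored as float32 only on Windows, and the loading path always casts back to integers, so the sparse locations survive a save/load cycle unchanged on every OS.

```py
import platform

import torch

indices = torch.tensor([[3, 17, 42]], dtype=torch.int64)  # hypothetical shira_indices tensor

# Saving path: cast to float32 only on Windows, mirroring the branch above.
saved = indices.to(torch.float32) if platform.system() == "Windows" else indices

# Loading path: always cast back to int; this is a no-op on Linux and macOS.
loaded = saved.to(torch.int)

assert torch.equal(loaded, indices.to(torch.int))
```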
@@ -406,6 +426,22 @@ def set_peft_model_state_dict(
        rank_pattern = config.rank_pattern
        if rank_pattern is not None:
            model.resize_modules_by_rank_pattern(rank_pattern, adapter_name)
    elif config.peft_type == PeftType.SHIRA:
        if platform.system() == "Windows":
            warnings.warn(
                "Windows has issues saving integers into safetensors. Hence, we had converted shira_indices "
                "to float32 before saving on Windows OS. The shira_indices will always be converted to integers "
                "when loading."
            )
        for name, module in model.named_modules():
            if hasattr(module, "shira_indices"):
                if f"{name}.shira_indices.{adapter_name}" in peft_model_state_dict:
                    shira_indices_values = peft_model_state_dict.pop(f"{name}.shira_indices.{adapter_name}")
                    # Convert shira_indices to int in case they were saved on a Windows OS and are being loaded
                    # on a Linux or a Mac-OS system. If they were saved in Linux or Mac-OS, they are already
                    # integers and the following will not affect anything.
                    module.shira_indices[adapter_name] = shira_indices_values.to(torch.int)
    elif config.peft_type == PeftType.VERA:
        if config.save_projection and "base_model.vera_A" not in peft_model_state_dict:
            raise ValueError(
@@ -46,6 +46,7 @@ from peft import (
    OFTConfig,
    PeftModel,
    RandLoraConfig,
    ShiraConfig,
    TaskType,
    TrainableTokensConfig,
    VBLoRAConfig,
@@ -523,6 +524,24 @@ TEST_CASES = [
        {"target_modules": ["conv2d"], "boft_block_size": 2, "boft_block_num": 0, "boft_n_butterfly_factor": 3},
    ),
    ########
    # SHiRA #
    ########
    ("Vanilla MLP 1 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": "lin0", "init_weights": False}),
    ("Vanilla MLP 2 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": ["lin0"], "init_weights": False}),
    ("Vanilla MLP 3 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": ["lin1"], "init_weights": False}),
    (
        "Vanilla MLP 4 SHiRA",
        "MLP",
        ShiraConfig,
        {"r": 1, "target_modules": ["lin0", "lin1"], "random_seed": 56, "init_weights": False},
    ),
    (
        "Vanilla MLP 5 SHiRA",
        "MLP",
        ShiraConfig,
        {"r": 1, "target_modules": ["lin0"], "init_weights": False},
    ),
    ########
    # VeRA #
    ########
    ("Vanilla MLP 1 VeRA", "MLP", VeraConfig, {"target_modules": "lin0"}),
@@ -752,6 +771,20 @@ MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [
        {"n_frequency": 10, "target_modules": ["lin0"]},
        {"n_frequency": 10, "target_modules": ["lin1"]},
    ),
    (
        "SHiRA Same",
        "shira",
        ShiraConfig,
        {"r": 1, "target_modules": ["lin0"], "init_weights": False},
        {"r": 1, "target_modules": ["lin0"], "init_weights": False},
    ),
    (
        "SHiRA Different",
        "shira",
        ShiraConfig,
        {"r": 1, "target_modules": ["lin0"], "init_weights": False},
        {"r": 1, "target_modules": ["lin1"], "init_weights": False},
    ),
    # Note: Currently, we cannot target lin0 and lin1 with different adapters when using VeRA. The reason is that the
    # first adapter being created will result in a vera_A or vera_B shape that is too small for the next adapter
    # (remember that VeRA shares these parameters across all layers), which results in an error.
@@ -841,6 +874,7 @@ PREFIXES = {
    FourierFTConfig: "fourierft_",
    C3AConfig: "c3a_",
    HRAConfig: "hra_",
    ShiraConfig: "shira_",
    VBLoRAConfig: "vblora_",
    BoneConfig: "bone_",
    TrainableTokensConfig: "trainable_tokens_",
@@ -1663,6 +1697,12 @@ class TestPeftCustomModel(PeftCommonTester):
        config_kwargs = config_kwargs.copy()
        # override the default value and make PEFT operation a no-op
        config_kwargs["init_weights"] = True
        if issubclass(config_cls, (ShiraConfig,)):
            # For SHiRA, setting this to the default value of True turns the PEFT operation into a no-op, because
            # SHiRA is always initialized to zeros. Configs declared in the test file had set init_weights to False
            # (to make sure all other tests have a randn SHiRA initialization). Setting it back to True here as
            # required by this test.
            config_kwargs["init_weights"] = True
        config = config_cls(
            base_model_name_or_path=model_id,
            **config_kwargs,
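The init_weights comment in the hunk above hinges on SHiRA's initialization: with the default `init_weights=True` the sparse adapter weights start at zero, so right after wrapping, the PEFT model's outputs match the base model's. A minimal sketch of that no-op behavior on a hypothetical toy module (the module and layer names here are illustrative, not from this change):

```py
import torch
from torch import nn

from peft import ShiraConfig, get_peft_model


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)

    def forward(self, x):
        return self.lin0(x)


torch.manual_seed(0)
base = TinyMLP()
x = torch.randn(2, 10)
reference = base(x)

# init_weights defaults to True, so the SHiRA weights are zero-initialized.
peft_model = get_peft_model(base, ShiraConfig(r=1, target_modules=["lin0"]))
assert torch.allclose(peft_model(x), reference)  # the adapter starts out as a no-op
```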
@@ -2100,7 +2140,9 @@ class TestPeftCustomModel(PeftCommonTester):
        assert "default" in model.base_model.classifier.modules_to_save
        assert "other" in model.base_model.classifier.modules_to_save

    @pytest.mark.parametrize("config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig])
    @pytest.mark.parametrize(
        "config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig, ShiraConfig]
    )
    def test_multiple_adapters_mixed_modules_to_save(self, config_cls):
        # See issue 1574
        # Check that we can have a model where one adapter has modules_to_save and the other doesn't. It should be
@@ -2110,6 +2152,8 @@ class TestPeftCustomModel(PeftCommonTester):

        if config_cls == BoneConfig:
            config_cls = partial(config_cls, r=2)
        if config_cls == ShiraConfig:
            config_cls = partial(config_cls, r=1)

        config0 = config_cls(target_modules=["lin0"], modules_to_save=["lin1"])
        config1 = config_cls(target_modules=["lin0"])
@@ -2128,7 +2172,9 @@ class TestPeftCustomModel(PeftCommonTester):
        model.set_adapter("other")
        model(**inputs)

    @pytest.mark.parametrize("config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig])
    @pytest.mark.parametrize(
        "config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig, ShiraConfig]
    )
    def test_multiple_adapters_mixed_modules_to_save_order_switched(self, config_cls):
        # See issue 1574
        # Same test as test_multiple_adapters_mixed_modules_to_save, but this time the 2nd adapter has modules_to_save.
@@ -2137,6 +2183,8 @@ class TestPeftCustomModel(PeftCommonTester):

        if config_cls == BoneConfig:
            config_cls = partial(config_cls, r=2)
        if config_cls == ShiraConfig:
            config_cls = partial(config_cls, r=1)

        config0 = config_cls(target_modules=["lin0"])
        config1 = config_cls(target_modules=["lin0"], modules_to_save=["lin1"])
@@ -40,6 +40,7 @@ from peft import (
    PromptEncoderConfig,
    PromptTuningConfig,
    PromptTuningInit,
    ShiraConfig,
    VBLoRAConfig,
    VeraConfig,
    get_peft_model,
@@ -180,6 +181,15 @@ ALL_CONFIGS = [
            "num_virtual_tokens": 10,
        },
    ),
    (
        ShiraConfig,
        {
            "r": 1,
            "task_type": "CAUSAL_LM",
            "target_modules": None,
            "init_weights": False,
        },
    ),
    (
        VBLoRAConfig,
        {
@@ -215,8 +225,15 @@ ALL_CONFIGS = [


def _skip_if_not_conv1d_supported(model_id, config_cls):
    if "GPT2LMHeadModel" in model_id and config_cls in [BOFTConfig, BoneConfig, HRAConfig, OFTConfig, C3AConfig]:
        pytest.skip("Skipping BOFT/HRA/OFT/Bone/C3A for GPT2LMHeadModel")
    if "GPT2LMHeadModel" in model_id and config_cls in [
        BOFTConfig,
        BoneConfig,
        HRAConfig,
        OFTConfig,
        ShiraConfig,
        C3AConfig,
    ]:
        pytest.skip("Skipping BOFT/HRA/OFT/Bone/SHiRA/C3A for GPT2LMHeadModel")


def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls):
@@ -457,6 +474,7 @@ class TestDecoderModels(PeftCommonTester):
    @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
    def test_unload_adapter(self, model_id, config_cls, config_kwargs):
        _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls)
        _skip_if_not_conv1d_supported(model_id, config_cls)
        config_kwargs = set_init_weights_false(config_cls, config_kwargs)
        self._test_unload_adapter(model_id, config_cls, config_kwargs.copy())
@@ -30,6 +30,7 @@ from peft import (
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
    ShiraConfig,
    TaskType,
    VBLoRAConfig,
    VeraConfig,
@@ -145,6 +146,15 @@ ALL_CONFIGS = [
            "task_type": "SEQ_2_SEQ_LM",
        },
    ),
    (
        ShiraConfig,
        {
            "r": 1,
            "task_type": "SEQ_2_SEQ_LM",
            "target_modules": None,
            "init_weights": False,
        },
    ),
    (
        VBLoRAConfig,
        {
@@ -29,6 +29,7 @@ from peft import (
    PromptEncoderConfig,
    PromptLearningConfig,
    PromptTuningConfig,
    ShiraConfig,
    VBLoRAConfig,
    VeraConfig,
)
@@ -145,6 +146,15 @@ ALL_CONFIGS = [
            "num_virtual_tokens": 10,
        },
    ),
    (
        ShiraConfig,
        {
            "r": 1,
            "task_type": "FEATURE_EXTRACTION",
            "target_modules": None,
            "init_weights": False,
        },
    ),
    (
        VBLoRAConfig,
        {
@@ -29,6 +29,7 @@ from peft import (
    PromptEncoderConfig,
    PromptTuningConfig,
    PromptTuningInit,
    ShiraConfig,
    VBLoRAConfig,
    VeraConfig,
)
@@ -143,6 +144,15 @@ ALL_CONFIGS = [
            "num_virtual_tokens": 10,
        },
    ),
    (
        ShiraConfig,
        {
            "r": 1,
            "task_type": "SEQ_CLS",
            "target_modules": None,
            "init_weights": False,
        },
    ),
    (
        VBLoRAConfig,
        {
tests/test_shira.py (new file, 278 lines)
@@ -0,0 +1,278 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This test file is for tests specific to SHiRA.

import os

import pytest
import torch
from accelerate.utils.imports import is_bf16_available
from torch import nn

from peft import PeftModel, ShiraConfig, get_peft_model


def custom_random_mask_function_with_custom_kwargs(custom_arg):
    def mask_fn(base_layer, r):
        """
        This mask function is similar to the random_mask provided in src/peft/tuners/shira/mask_functions.py except
        the seed is derived from custom_kwargs. Please use this as an example to create your own custom sparse masks
        that may use custom_kwargs. Remember, for a pretrained weight with shape m, n, mask_fn must return only one
        mask (shape: m, n) which must be binary 0 or 1 with num_shira_parameters = r(m+n) for linear layers. Device
        and dtype of mask must be same as base layer's weight's device and dtype.
        """
        new_seed = custom_arg
        shape = base_layer.weight.shape
        num_shira_weights = r * (shape[0] + shape[1])
        random_generator = torch.Generator()
        random_generator.manual_seed(new_seed)

        idx = (torch.randperm(base_layer.weight.numel(), generator=random_generator)[:num_shira_weights]).to(
            base_layer.weight.device
        )
        val = torch.ones_like(idx.type(base_layer.weight.dtype))
        mask = torch.zeros_like(base_layer.weight.view(1, -1))
        mask = mask.scatter_(1, idx.unsqueeze(0), val.unsqueeze(0)).view(shape)

        return mask

    return mask_fn
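To make the contract stated in the docstring concrete, here is a short sketch of how the helper above is wired into a config (mirroring what `test_save_load_custom_mask_function` below does) and of the properties a custom `mask_fn` must satisfy; the layer and the checks are illustrative assumptions, not part of this change:

```py
import torch
from torch import nn

from peft import ShiraConfig

layer = nn.Linear(20, 40)  # hypothetical base layer, weight shape (40, 20)
mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg=42)
mask = mask_fn(layer, r=2)

assert mask.shape == layer.weight.shape               # one mask per base weight
assert set(mask.unique().tolist()) <= {0.0, 1.0}      # binary mask
assert int(mask.sum()) == 2 * (40 + 20)               # exactly r * (m + n) ones for a linear layer

config = ShiraConfig(r=2, mask_type="custom", target_modules=["lin1", "lin2"], init_weights=False)
config.mask_fn = mask_fn                              # attach the custom mask function to the config
```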
class MLP(nn.Module):
    def __init__(self, bias=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.lin0 = nn.Linear(10, 20, bias=bias)
        self.lin1 = nn.Linear(20, 40, bias=bias)
        self.lin2 = nn.Linear(40, 30, bias=bias)
        self.lin3 = nn.Linear(30, 10, bias=bias)
        self.sm = nn.LogSoftmax(dim=-1)

    def forward(self, X):
        X = self.lin0(X)
        X = self.relu(X)
        X = self.lin1(X)
        X = self.relu(X)
        X = self.lin2(X)
        X = self.relu(X)
        X = self.lin3(X)
        X = self.sm(X)
        return X


class TestShira:
    @pytest.fixture
    def mlp(self):
        torch.manual_seed(0)
        model = MLP()
        return model
    def test_mlp_single_adapter_shapes(self, mlp):
        r = 2
        config = ShiraConfig(r=r, target_modules=["lin1", "lin2"])
        # creates a default SHiRA adapter
        peft_model = get_peft_model(mlp, config)

        shira_weight1_size = peft_model.base_model.model.lin1.shira_weight["default"].shape[0]
        shira_weight2_size = peft_model.base_model.model.lin2.shira_weight["default"].shape[0]
        shira_indices1_size = peft_model.base_model.model.lin1.shira_indices["default"].shape[1]
        shira_indices2_size = peft_model.base_model.model.lin2.shira_indices["default"].shape[1]

        base_weight1_size = peft_model.base_model.model.lin1.base_layer.weight.shape
        base_weight2_size = peft_model.base_model.model.lin2.base_layer.weight.shape

        delta_weight1_shape = peft_model.base_model.model.lin1.get_delta_weight("default").shape
        delta_weight2_shape = peft_model.base_model.model.lin2.get_delta_weight("default").shape

        assert shira_weight1_size == r * (base_weight1_size[0] + base_weight1_size[1])
        assert shira_weight2_size == r * (base_weight2_size[0] + base_weight2_size[1])

        assert shira_weight1_size == shira_indices1_size
        assert shira_weight2_size == shira_indices2_size

        assert delta_weight1_shape == base_weight1_size
        assert delta_weight2_shape == base_weight2_size

        return peft_model
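As a worked instance of the r * (m + n) budget checked above, using the MLP defined in this file with r = 2: lin1 is nn.Linear(20, 40) and lin2 is nn.Linear(40, 30), so the adapter stores 120 and 140 trainable entries respectively, one stored index per entry, while the reconstructed delta weight keeps the full dense shape of the base weight.

```py
r = 2
# lin1 = nn.Linear(20, 40) -> weight shape (40, 20); lin2 = nn.Linear(40, 30) -> weight shape (30, 40)
assert r * (40 + 20) == 120  # SHiRA entries (and indices) on lin1, versus 800 dense weights
assert r * (30 + 40) == 140  # SHiRA entries (and indices) on lin2, versus 1200 dense weights
```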
    def test_multiple_adapters_save_load(self, mlp, tmp_path):
        # check saving and loading works with multiple adapters
        # note, the random seeds in the below two configs are not the default values.
        # so it will lead to different random sparse masks between saving and loading.
        # our goal is to make sure that loaded indices are exactly the same as the saved indices regardless of what
        # initial random mask gets generated.
        # we will also make sure that parameters are saved and loaded correctly, and the output remains the same.
        config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=56)
        # creates a default SHiRA adapter
        peft_model = get_peft_model(mlp, config, adapter_name="first")
        config2 = ShiraConfig(r=3, target_modules=["lin1", "lin2", "lin3"], random_seed=67)
        peft_model.add_adapter("second", config2)

        assert torch.all(peft_model.base_model.model.lin1.shira_weight["first"] == 0)
        assert torch.all(peft_model.base_model.model.lin2.shira_weight["first"] == 0)
        assert torch.all(peft_model.base_model.model.lin1.shira_weight["second"] == 0)
        assert torch.all(peft_model.base_model.model.lin2.shira_weight["second"] == 0)
        assert torch.all(peft_model.base_model.model.lin3.shira_weight["second"] == 0)

        shira_assign_val1_f = torch.randn_like(peft_model.base_model.model.lin1.shira_weight["first"])
        peft_model.base_model.model.lin1.shira_weight["first"] = shira_assign_val1_f
        shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
        shira_assign_val2_f = torch.randn_like(peft_model.base_model.model.lin2.shira_weight["first"])
        peft_model.base_model.model.lin2.shira_weight["first"] = shira_assign_val2_f
        shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]

        shira_assign_val1_s = torch.randn_like(peft_model.base_model.model.lin1.shira_weight["second"])
        peft_model.base_model.model.lin1.shira_weight["second"] = shira_assign_val1_s
        shira_indices1_s = peft_model.base_model.model.lin1.shira_indices["second"]
        shira_assign_val2_s = torch.randn_like(peft_model.base_model.model.lin2.shira_weight["second"])
        peft_model.base_model.model.lin2.shira_weight["second"] = shira_assign_val2_s
        shira_indices2_s = peft_model.base_model.model.lin2.shira_indices["second"]
        shira_assign_val3_s = torch.randn_like(peft_model.base_model.model.lin3.shira_weight["second"])
        peft_model.base_model.model.lin3.shira_weight["second"] = shira_assign_val3_s
        shira_indices3_s = peft_model.base_model.model.lin3.shira_indices["second"]

        input = torch.randn(5, 10)
        peft_model.set_adapter("first")
        output_first = peft_model(input)
        peft_model.set_adapter("second")
        output_second = peft_model(input)

        # sanity check
        assert not torch.allclose(output_first, output_second, atol=1e-3, rtol=1e-3)

        save_path = os.path.join(tmp_path, "shira")
        peft_model.save_pretrained(save_path)
        assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
        assert os.path.exists(os.path.join(save_path, "second", "adapter_config.json"))
        del peft_model

        torch.manual_seed(0)
        mlp = MLP()
        peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")
        peft_model.load_adapter(os.path.join(save_path, "second"), "second")

        peft_model.set_adapter("first")
        output_first_loaded = peft_model(input)
        peft_model.set_adapter("second")
        output_second_loaded = peft_model(input)

        assert torch.allclose(output_first, output_first_loaded)
        assert torch.allclose(output_second, output_second_loaded)

        assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
        assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
        assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
        assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])
        assert torch.all(shira_assign_val1_s == peft_model.base_model.model.lin1.shira_weight["second"])
        assert torch.all(shira_assign_val2_s == peft_model.base_model.model.lin2.shira_weight["second"])
        assert torch.all(shira_assign_val3_s == peft_model.base_model.model.lin3.shira_weight["second"])
        assert torch.all(shira_indices1_s == peft_model.base_model.model.lin1.shira_indices["second"])
        assert torch.all(shira_indices2_s == peft_model.base_model.model.lin2.shira_indices["second"])
        assert torch.all(shira_indices3_s == peft_model.base_model.model.lin3.shira_indices["second"])

        return peft_model
    def test_save_load_custom_mask_function(self, mlp, tmp_path):
        # we want to see if saving and loading works when a custom mask is involved
        config = ShiraConfig(r=2, mask_type="custom", target_modules=["lin1", "lin2"], init_weights=False)
        custom_arg = 120
        custom_mask_fn = custom_random_mask_function_with_custom_kwargs(custom_arg)
        config.mask_fn = custom_mask_fn

        # create a custom mask SHiRA adapter
        peft_model = get_peft_model(mlp, config, adapter_name="first")

        shira_assign_val1_f = peft_model.base_model.model.lin1.shira_weight["first"]
        shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
        shira_assign_val2_f = peft_model.base_model.model.lin2.shira_weight["first"]
        shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]

        input = torch.randn(5, 10)
        peft_model.set_adapter("first")
        output_first = peft_model(input)

        save_path = os.path.join(tmp_path, "shira")
        peft_model.save_pretrained(save_path)
        assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
        del peft_model

        torch.manual_seed(0)
        mlp = MLP()
        peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")

        peft_model.set_adapter("first")
        output_first_loaded = peft_model(input)

        assert torch.allclose(output_first, output_first_loaded)

        assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
        assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
        assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
        assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])

        return peft_model
    def test_save_load_default_random_mask_with_seed_function(self, mlp, tmp_path):
        # we want to see if saving and loading works when a random mask is involved but the random seed is fixed.
        config = ShiraConfig(r=2, target_modules=["lin1", "lin2"], random_seed=567, init_weights=False)

        # create a SHiRA adapter with the default random mask and a fixed seed
        peft_model = get_peft_model(mlp, config, adapter_name="first")

        shira_assign_val1_f = peft_model.base_model.model.lin1.shira_weight["first"]
        shira_indices1_f = peft_model.base_model.model.lin1.shira_indices["first"]
        shira_assign_val2_f = peft_model.base_model.model.lin2.shira_weight["first"]
        shira_indices2_f = peft_model.base_model.model.lin2.shira_indices["first"]

        input = torch.randn(5, 10)
        peft_model.set_adapter("first")
        output_first = peft_model(input)

        save_path = os.path.join(tmp_path, "shira")
        peft_model.save_pretrained(save_path)
        assert os.path.exists(os.path.join(save_path, "first", "adapter_config.json"))
        del peft_model

        torch.manual_seed(0)
        mlp = MLP()
        peft_model = PeftModel.from_pretrained(mlp, os.path.join(save_path, "first"), adapter_name="first")

        peft_model.set_adapter("first")
        output_first_loaded = peft_model(input)

        assert torch.allclose(output_first, output_first_loaded)

        assert torch.all(shira_assign_val1_f == peft_model.base_model.model.lin1.shira_weight["first"])
        assert torch.all(shira_assign_val2_f == peft_model.base_model.model.lin2.shira_weight["first"])
        assert torch.all(shira_indices1_f == peft_model.base_model.model.lin1.shira_indices["first"])
        assert torch.all(shira_indices2_f == peft_model.base_model.model.lin2.shira_indices["first"])

        return peft_model
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16])
|
||||
def test_shira_dtypes(self, dtype):
|
||||
if dtype == torch.bfloat16:
|
||||
# skip if bf16 is not supported on hardware, see #1872
|
||||
if not is_bf16_available():
|
||||
pytest.skip("bfloat16 not supported on this system, skipping the test")
|
||||
|
||||
model = MLP().to(dtype)
|
||||
config = ShiraConfig(r=2, target_modules=["lin1", "lin2"])
|
||||
peft_model = get_peft_model(model, config)
|
||||
inputs = torch.randn(5, 10).to(dtype)
|
||||
output = peft_model(inputs) # should not raise
|
||||
assert output.dtype == dtype
|
@@ -1608,6 +1608,7 @@ class PeftCommonTester:
            "HRA",
            "VBLORA",
            "RANDLORA",
            "SHIRA",
            "BONE",
            "C3A",
        ):
@@ -26,6 +26,7 @@ from peft import (
    IA3Config,
    LoraConfig,
    PromptLearningConfig,
    ShiraConfig,
    VBLoRAConfig,
)
from peft.import_utils import (
@@ -228,6 +229,8 @@ def set_init_weights_false(config_cls, kwargs):

    if issubclass(config_cls, PromptLearningConfig):
        return kwargs
    if issubclass(config_cls, ShiraConfig):
        return kwargs
    if config_cls == VBLoRAConfig:
        return kwargs