FEAT Add MiSS as a replacement for Bone. (#2604)

Add MiSS, an evolution of Bone, from https://arxiv.org/abs/2409.15371.

MiSS will replace Bone, which is now deprecated. A script to convert Bone
checkpoints to MiSS checkpoints is included.
This commit is contained in:
J.L
2025-08-02 00:37:20 +08:00
committed by GitHub
parent a91ec33fc5
commit bb4fb50e2b
21 changed files with 1412 additions and 11 deletions

View File

@ -130,6 +130,8 @@
title: SHiRA
- local: package_reference/c3a
title: C3A
- local: package_reference/miss
title: MiSS
title: Adapters
- sections:

View File

@ -122,12 +122,16 @@ HRA constructs a chain of `r` trainable Householder reflections (HRs). Because t
The higher `r`, the more trainable parameters, resulting in a larger model capacity and better performance. Besides, due to the chain structure, the orthogonality of HR planes impacts the capacity and regularity of HRA. To achieve a trade-off between the model capacity and regularity, an orthogonality regularizer of the HR planes is added to the loss function. The weight \\(\lambda\\) can control the strength of the regularizer.
## Bone
[DiSHA](https://huggingface.co/papers/2409.15371) (Dimension-Sharding Adaptation) is a PEFT technique distinct from LoRA. By dividing the original weights into multiple subspaces that share a single matrix for weight updates, DiSHA allows the trainable matrix to be initialized to zero, eliminating the need for the complex initialization used in some LoRA variants. Bone and Bat are derivative structures of DiSHA: Bone significantly improves computational efficiency while saving memory, whereas Bat addresses the limitation of Bone's linear update by employing a non-linear update to break through its upper bound.
[MiSS](https://huggingface.co/papers/2409.15371) is the method introduced in the updated version of this paper, MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing, and supersedes Bone.
If you already have a Bone checkpoint, you can use `/scripts/convert-bone-to-miss.py` to convert it into a MiSS checkpoint and proceed with training using MiSS.
<small><a href="https://huggingface.co/papers/2409.15371">DiSHA: Dimension-Sharding Adaptation with Fast Convergence and Fast Computation</a></small>
## MiSS
[MiSS](https://huggingface.co/papers/2409.15371) (Matrix Shard Sharing) is a novel Parameter-Efficient Fine-Tuning (PEFT) method designed to address the trade-off between adaptability and efficiency in Large Language Models. Its core is a simple shard-sharing mechanism: a weight matrix is decomposed into multiple shards, and a single shared, trainable shard provides the update; the final low-rank update matrix is constructed by replicating this shared shard. As a result, MiSS adopts a low-rank structure, requires only one trainable matrix, and introduces an update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.
Intuitively, the single trainable matrix in Bone has the same shape as `lora_B`, so at the same rank `r` Bone has `in_features * r` fewer trainable parameters than LoRA.
<small><a href="https://huggingface.co/papers/2409.15371">MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing</a></small>
Note: Bat's r (b) is special and requires that weight W satisfies the conditions `in_features % r == 0` and `out_features % r == 0`. Additionally, when `in_features == out_features` and Bone-r equals LoRA-r, Bone's number of trainable parameters is only half that of LoRA.
Intuitively, the single trainable matrix in MiSS has the same shape as `lora_B`, so at the same rank `r` MiSS has `in_features * r` fewer trainable parameters than LoRA.
Although the nonlinear updates of Bat bring some performance improvements, they also increase computational overhead. Their main purpose is to provide researchers with a direction for improvement; therefore, we recommend fine-tuning with the standard Bone method instead.
Note: Bat's r (b) is special and requires that weight W satisfies the conditions `in_features % r == 0` and `out_features % r == 0`. Additionally, when `in_features == out_features` and MiSS-r equals LoRA-r, MiSS's number of trainable parameters is only half that of LoRA.
Although the nonlinear updates of Bat bring some performance improvements, they also increase computational overhead. Their main purpose is to provide researchers with a direction for improvement; therefore, we recommend fine-tuning with the standard MiSS method instead.
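To make the shard-sharing update concrete, here is a minimal PyTorch sketch of the default (linear) MiSS update for a single weight matrix, mirroring the reshape logic of the implementation; the tensor names and sizes are illustrative only and assume `in_features` is divisible by `r`.

```python
import torch

out_features, in_features, r = 8, 12, 4      # illustrative sizes, in_features % r == 0
W = torch.randn(out_features, in_features)   # frozen base weight
miss_block = torch.zeros(r, out_features)    # the single shared trainable shard (zero-initialized)

# Split W along in_features into shards of width r and add the same shared block to every shard.
shards = W.reshape(out_features, in_features // r, r).permute(1, 2, 0)  # (n_shards, r, out_features)
W_updated = (shards + miss_block).permute(2, 0, 1).reshape(out_features, in_features)

# With the zero initialization above, W_updated equals W until miss_block is trained.
assert torch.allclose(W_updated, W)
```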

View File

@ -0,0 +1,32 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# MiSS
[MiSS](https://huggingface.co/papers/2409.15371) (MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing) is a novel PEFT method that adopts a low-rank structure, requires only a single trainable matrix, and introduces a new update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.
The abstract from the paper is:
*Parameter-Efficient Fine-Tuning (PEFT) methods, particularly Low-Rank Adaptation (LoRA), effectively reduce the number of trainable parameters in Large Language Models (LLMs). However, as model scales continue to grow, the demand for computational resources remains a significant challenge. Existing LoRA variants often struggle to strike an optimal balance between adaptability (model performance and convergence speed) and efficiency (computational overhead, memory usage, and initialization time). This paper introduces MiSS(Matrix Shard Sharing ), a novel PEFT approach that addresses this trade-off through a simple shard-sharing mechanism. MiSS leverages the insight that a low-rank adaptation can be achieved by decomposing the weight matrix into multiple fragment matrices and utilizing a shared, trainable common fragment. This method constructs the low-rank update matrix through the replication of these shared, partitioned shards. We also propose a hardware-efficient and broadly applicable implementation for MiSS. Extensive experiments conducted on a range of tasks, alongside a systematic analysis of computational performance, demonstrate MiSS's superiority. The results show that MiSS significantly outperforms standard LoRA and its prominent variants in both model performance metrics and computational efficiency, including initialization speed and training throughput. By effectively balancing expressive power and resource utilization, MiSS offers a compelling solution for efficiently adapting large-scale models*.
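A minimal usage sketch is shown below; the base model checkpoint and target modules are placeholders chosen for illustration, not requirements of the method.

```python
from transformers import AutoModelForCausalLM
from peft import MissConfig, get_peft_model

# Placeholder base model; any causal LM with linear q_proj/v_proj layers works the same way.
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
config = MissConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], r=64)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```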
## MissConfig
[[autodoc]] tuners.miss.config.MissConfig
## MissModel
[[autodoc]] tuners.miss.model.MissModel

View File

@ -0,0 +1,104 @@
# MiSS: Balancing LoRA Performance and Efficiency with Simple Shard Sharing
## Introduction ([Paper](https://huggingface.co/papers/2409.15371), [code](https://github.com/JL-er/MiSS))
MiSS (Matrix Shard Sharing) is a novel PEFT method that adopts a low-rank structure, requires only a single trainable matrix, and introduces a new update mechanism distinct from LoRA, achieving an excellent balance between performance and efficiency.
## Quick Start
```python
import torch
from peft import MissConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
miss_config = MissConfig(
r = 64
)
#bat: In this mode, you can enable nonlinear updates across different shards.
# miss_config = MissConfig(
# r = 64,
# init_weights="bat"
# )
# mini: In this mode, you can set a smaller rank to use fewer trainable parameters, but it is recommended to keep `out_features % mini_r == 0`.
# miss_config = MissConfig(
# r = 64,
# init_weights="mini",
# mini_r = 8
# )
peft_model = get_peft_model(model, miss_config)
peft_model.print_trainable_parameters()
dataset = load_dataset("imdb", split="train[:1%]")
training_args = SFTConfig(dataset_text_field="text", max_seq_length=128)
trainer = SFTTrainer(
model=peft_model,
args=training_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("miss-llama-2-7b")
```
To use the fine-tuned MiSS modules, simply load them as follows:
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
peft_model = PeftModel.from_pretrained(model, "miss-llama-2-7b")
```
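If you need a standalone model for deployment, the trained MiSS weights can be merged back into the base weights. A short sketch, assuming the `peft_model` loaded above; the output directory name is arbitrary:

```python
# Merge the MiSS adapter into the base weights and save a standalone checkpoint.
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("miss-llama-2-7b-merged")
```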
## Advanced Usage
### Fine-tune
```shell
# Bat performs better than the default MiSS, but it uses more memory and is twice as slow. To use the Bat method, simply add the parameter init_weights="bat".
python miss_finetuning.py \
--base_model_name_or_path meta-llama/Llama-2-7b-hf \
--output_dir output/miss-llama-2-7b-metamath-10k \
--miss_r 64 \
--init_weights True \
--bits bf16 \
--data_path meta-math/MetaMathQA \
--dataset_split train[:100000] \
--dataset_field query response \
--bf16 True \
--num_train_epochs 1 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--logging_steps 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--tf32 True \
--report_to none
```
## Citation
```bib
@misc{kang2025balancingloraperformanceefficiency,
title={Balancing LoRA Performance and Efficiency with Simple Shard Sharing},
author={Jiale Kang and Qingyu Yin},
year={2025},
eprint={2409.15371},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2409.15371},
}
```

View File

@ -0,0 +1,107 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Literal, Optional
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import SFTConfig, SFTTrainer
from peft import MissConfig, get_peft_model
@dataclass
class ScriptArguments(SFTConfig):
# model configs
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name or path of the fp32/16 base model."}
)
bits: str = field(default="bf16", metadata={"help": "(`['bf16', 'fp16', 'fp32']`)"})
init_weights: Literal[True, "bat"] = field(
default=True,
metadata={
"help": (
"True -> MiSS efficience and balance; `bat` -> Bat, `mini` -> smaller MiSS efficience and balance"
),
},
)
miss_r: int = field(default=16)
merge_and_save: bool = field(default=False)
# dataset configs
data_path: str = field(default="imdb", metadata={"help": "Path to the training data."})
dataset_split: str = field(default="train[:1%]", metadata={"help": "(`['train', 'test', 'eval']`):"})
dataset_field: list[str] = field(default=None, metadata={"help": "Fields of dataset input and output."})
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
print(script_args)
print(f"Load pre-processed residual model in {script_args.bits} bits.")
if script_args.bits in ["nf4", "fp4", "int8"]:
print("MiSS currently does not support quantization.")
elif script_args.base_model_name_or_path is not None:
print(f"No available pre-processed model, manually initialize a MiSS using {script_args.base_model_name_or_path}.")
model = AutoModelForCausalLM.from_pretrained(
script_args.base_model_name_or_path,
torch_dtype=(
torch.float16
if script_args.bits == "fp16"
else (torch.bfloat16 if script_args.bits == "bf16" else torch.float32)
),
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name_or_path)
tokenizer.pad_token_id = tokenizer.eos_token_id
miss_config = MissConfig(
r=script_args.miss_r,
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
bias="none",
task_type="CAUSAL_LM",
init_weights=script_args.init_weights,
)
peft_model = get_peft_model(model, miss_config)
print(peft_model)
peft_model.print_trainable_parameters()
print(f"Training MiSS with trl on the {script_args.data_path}[{script_args.dataset_split}] dataset.")
dataset = load_dataset(script_args.data_path, split=script_args.dataset_split)
dataset = dataset.map(
lambda example: {
"text": f"### USER: {example[script_args.dataset_field[0]]}\n### ASSISTANT: {example[script_args.dataset_field[1]]}"
}
)
trainer = SFTTrainer(
model=peft_model,
args=script_args,
train_dataset=dataset,
tokenizer=tokenizer,
)
trainer.train()
trainer.save_state()
peft_model.save_pretrained(
os.path.join(script_args.output_dir, "miss_ft"),
)
if script_args.merge_and_save:
model = peft_model.merge_and_unload()
model.save_pretrained(os.path.join(script_args.output_dir, "miss_merged"))
tokenizer.save_pretrained(os.path.join(script_args.output_dir, "miss_merged"))

View File

@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright (c) 2025 Your Organization/Project. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Bone checkpoint to MiSS format."""
import argparse
import json
import os
from pathlib import Path
from safetensors import safe_open
from safetensors.torch import save_file
from peft.utils import CONFIG_NAME, SAFETENSORS_WEIGHTS_NAME
def convert_bone_to_miss(bone_dir: Path, miss_dir: Path) -> None:
"""Convert Bone checkpoint files to MiSS format."""
bone_config_path = bone_dir / CONFIG_NAME
miss_config_path = miss_dir / CONFIG_NAME
if not os.path.exists(miss_dir):
os.makedirs(miss_dir, exist_ok=True)
with open(bone_config_path, encoding="utf-8") as f:
config = json.load(f)
config["peft_type"] = "MISS"
with open(miss_config_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2, ensure_ascii=False)
bone_weight_path = bone_dir / SAFETENSORS_WEIGHTS_NAME
miss_weight_path = miss_dir / SAFETENSORS_WEIGHTS_NAME
new_data = {}
with safe_open(bone_weight_path, framework="pt") as f:
for old_key in f.keys():
tensor = f.get_tensor(old_key)
new_key = old_key.replace(".bone_", ".miss_")
new_data[new_key] = tensor
save_file(new_data, miss_weight_path)
print(f"Converted checkpoint saved at {miss_weight_path}")
def main() -> None:
parser = argparse.ArgumentParser(description="Convert Bone checkpoint to MiSS format.")
parser.add_argument("bone_dir", type=Path, help="Directory containing Bone checkpoint files")
parser.add_argument("miss_dir", type=Path, help="Directory to save MiSS checkpoint files")
args = parser.parse_args()
args.miss_dir.mkdir(parents=True, exist_ok=True)
convert_bone_to_miss(args.bone_dir, args.miss_dir)
if __name__ == "__main__":
main()

View File

@ -75,6 +75,8 @@ from .tuners import (
LoraConfig,
LoraModel,
LoraRuntimeConfig,
MissConfig,
MissModel,
MultitaskPromptTuningConfig,
MultitaskPromptTuningInit,
OFTConfig,
@ -161,6 +163,8 @@ __all__ = [
"LoraConfig",
"LoraModel",
"LoraRuntimeConfig",
"MissConfig",
"MissModel",
"MultitaskPromptTuningConfig",
"MultitaskPromptTuningInit",
"OFTConfig",

View File

@ -33,6 +33,7 @@ from .lora import (
get_eva_state_dict,
initialize_lora_eva_weights,
)
from .miss import MissConfig, MissModel
from .mixed import MixedModel
from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit
from .oft import OFTConfig, OFTModel
@ -78,6 +79,8 @@ __all__ = [
"LoraConfig",
"LoraModel",
"LoraRuntimeConfig",
"MissConfig",
"MissModel",
"MixedModel",
"MultitaskPromptEmbedding",
"MultitaskPromptTuningConfig",

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import warnings
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
@ -121,3 +122,8 @@ class BoneConfig(PeftConfig):
# if target_modules is a regex expression, then layers_pattern should be None
if isinstance(self.target_modules, str) and self.layers_pattern is not None:
raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")
warnings.warn(
"Bone will be removed in v0.19.0 of PEFT, use `MissConfig` instead. "
"If you already have a Bone checkpoint, you can use `/scripts/convert-bone-to-miss.py` to convert it into "
)

View File

@ -0,0 +1,24 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from peft.utils import register_peft_method
from .config import MissConfig
from .layer import MissLayer, MissLinear
from .model import MissModel
__all__ = ["MissConfig", "MissLayer", "MissLinear", "MissModel"]
register_peft_method(name="miss", config_cls=MissConfig, model_cls=MissModel)

View File

@ -0,0 +1,140 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from peft.config import PeftConfig
from peft.utils import PeftType
@dataclass
class MissConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a [`MiSSModel`].
Args:
r (`int`):
The rank of MiSS across different layers. It is best to set 'r' to an even number; otherwise, the default
initialization method will not work. The rank of MiSS corresponds to a low-rank decomposition along the
in_features dimension.
miss_dropout (`float`):
The dropout probability for MiSS layers.
mini_r (`int`):
The rank of MiSS corresponds to a low-rank decomposition along the out_features dimension. When you set
`init_weights=mini`, you need to set `mini_r`. Please make sure that `out_features` is divisible by
`mini_r`.
target_modules (`Optional[Union[List[str], str]]`):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
strings, either an exact match will be performed or it is checked if the name of the module ends with any
of the passed strings. If this is specified as 'all-linear', then all linear modules are chosen, excluding
the output layer. If this is not specified, modules will be chosen according to the model architecture. If
the architecture is not known, an error will be raised -- in this case, you should specify the target
modules manually.
exclude_modules (`Optional[Union[List[str], str]]`):
The names of the modules to not apply the adapter. When passing a string, a regex match will be performed.
When passing a list of strings, either an exact match will be performed or it is checked if the name of the
module ends with any of the passed strings.
init_weights (bool | Literal["bat", "mini"]):
Different initializations correspond to different MiSS variants. By default (balance), the most efficient
and general method in MiSS will be used. 'bat': In this mode, you can enable nonlinear updates across
different shards. 'mini': In this mode, you can set a smaller rank to use fewer trainable parameters, but
it is recommended to keep `out_features % mini_r == 0`.
layers_to_transform (`Union[List[int], int]`):
The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
that are specified in this list. If a single integer is passed, it will apply the transformations on the
layer at this index.
layers_pattern (`str`):
The layer pattern name, used only if `layers_to_transform` is different from `None`.
modules_to_save (`List[str]`):
List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
"""
r: int = field(
default=64,
metadata={
"help": "The rank of MiSS corresponds to a low-rank decomposition along the in_features dimension.",
"note": "It is best to set 'r' to an even number; otherwise, the default initialization method will not work.",
},
)
miss_dropout: float = field(default=0.0, metadata={"help": "MiSS dropout"})
mini_r: int = field(
default=1,
metadata={
"help": "The rank of MiSS corresponds to a low-rank decomposition along the out_features dimension.",
"note": "It is recommended that mini_r be divisible by out_features. When mini_r == out_features, the mini method is equivalent to the default efficient MiSS.",
},
)
target_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={
"help": "List of module names or regex expression of the module names to replace with MiSS.",
"example": "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' ",
},
)
exclude_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={"help": "List of module names or regex expression of the module names to exclude from MiSS."},
)
init_weights: bool | Literal["bat", "mini"] = field(
default=True,
metadata={
"help": (
"True -> MiSS balance; `bat` -> Bat; `mini` -> smaller rank and efficiency"
"Whether to initialize the weights of the MiSS layers with their default initialization. Don't change "
"this setting, except if you know exactly what you're doing."
),
},
)
layers_to_transform: Optional[Union[list[int], int]] = field(
default=None,
metadata={
"help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
},
)
layers_pattern: Optional[str] = field(
default=None,
metadata={
"help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern."
},
)
bias: str = field(default="none", metadata={"help": "Bias type for MiSS. Can be 'none', 'all' or 'miss_only'"})
modules_to_save: Optional[list[str]] = field(
default=None,
metadata={
"help": "List of modules apart from MiSS layers to be set as trainable and saved in the final checkpoint. "
"For example, in Sequence Classification or Token Classification tasks, "
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
},
)
def __post_init__(self):
super().__post_init__()
self.peft_type = PeftType.MISS
self.target_modules = (
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
)
self.exclude_modules = (
set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
)
# if target_modules is a regex expression, then layers_to_transform should be None
if isinstance(self.target_modules, str) and self.layers_to_transform is not None:
raise ValueError("`layers_to_transform` cannot be used when `target_modules` is a str.")
# if target_modules is a regex expression, then layers_pattern should be None
if isinstance(self.target_modules, str) and self.layers_pattern is not None:
raise ValueError("`layers_pattern` cannot be used when `target_modules` is a str.")

View File

@ -0,0 +1,390 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from typing import Any, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
class MissLayer(BaseTunerLayer):
# All names of layers that may contain (trainable) adapter weights
adapter_layer_names = ("miss_block",)
# All names of other parameters that may contain adapter-related parameters
other_param_names = ("miss_r", "miss_dropout", "miss_mini_r")
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
self.base_layer = base_layer
self.miss_r = {}
self.miss_dropout = nn.ModuleDict({})
self.miss_mini_r = {}
self.miss_block = nn.ParameterDict({})
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled = True
self.kwargs = kwargs
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
self.in_features, self.out_features = base_layer.in_features, base_layer.out_features
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")
def update_layer(
self,
adapter_name: str,
r: int,
mini_r: int,
miss_dropout,
init_weights: bool,
**kwargs,
) -> None:
"""Internal function to create miss adapter
Args:
adapter_name (`str`): Name for the adapter to add.
r (`int`): Rank for the added adapter.
init_weights (`bool`): Whether to initialize weights.
"""
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.miss_r[adapter_name] = r
self.miss_mini_r[adapter_name] = mini_r
if miss_dropout > 0.0:
miss_dropout_layer = nn.Dropout(p=miss_dropout)
else:
miss_dropout_layer = nn.Identity()
self.miss_dropout[adapter_name] = miss_dropout_layer
# Determine shape of MiSS weights
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
self.miss_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True)
else:
raise TypeError(f"MiSS is not implemented for base layers of type {type(base_layer).__name__}")
# Initialize weights
if init_weights == "bat":
if self.in_features % r != 0 or self.out_features % r != 0:
raise ValueError("The weight matrix must be fully divisible into [r, r] blocks.")
self.reset_bat_parameters(adapter_name, r)
elif init_weights == "mini":
if self.out_features % mini_r != 0:
raise ValueError(
"mini_r is divided along the out_features dimension. For optimal performance and implementation simplicity,"
"it is recommended that out_features be divisible by mini_r."
"Error: {self.out_features} % mini_r != 0"
)
self.reset_mini_parameters(adapter_name, r, mini_r)
elif init_weights:
self.reset_miss_parameters(adapter_name, r)
else:
self.reset_miss_parameters_random(adapter_name)
# Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters)
def reset_miss_parameters(self, adapter_name: str, r):
self.miss_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True)
def reset_bat_parameters(self, adapter_name: str, r):
self.miss_block[adapter_name] = nn.Parameter(torch.zeros(self.out_features // r, r, r), requires_grad=True)
def reset_mini_parameters(self, adapter_name: str, r, mini_r):
self.miss_block[adapter_name] = nn.Parameter(torch.zeros(r, mini_r), requires_grad=True)
def reset_miss_parameters_random(self, adapter_name: str):
nn.init.kaiming_uniform_(self.miss_block[adapter_name], a=math.sqrt(5))
def scale_layer(self, scale: float) -> None:
if scale == 1:
return
for active_adapter in self.active_adapters:
if active_adapter not in self.miss_block.keys():
continue
warnings.warn("Scaling operation for MiSS not supported! Automatically set scale to 1.")
def unscale_layer(self, scale=None) -> None:
for active_adapter in self.active_adapters:
if active_adapter not in self.miss_block.keys():
continue
warnings.warn("Unscaling operation for MiSS not supported! Keeping scale at 1.")
class MissLinear(nn.Module, MissLayer):
"""
MiSS implemented in a dense layer.
"""
def __init__(
self,
base_layer,
adapter_name: str,
r: int = 0,
mini_r: int = 0,
miss_dropout: float = 0.0,
init_weights: Union[bool, str] = True,
**kwargs,
) -> None:
super().__init__()
MissLayer.__init__(self, base_layer, **kwargs)
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, mini_r, miss_dropout, init_weights, **kwargs)
self.miss_fn = init_weights
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self.miss_block.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weight = base_layer.weight.data.clone()
if self.miss_fn == "bat":
delta_weight = self.get_delta_weight(active_adapter, orig_weight)
orig_weight += delta_weight
elif self.miss_fn == "mini":
delta_weight = self.get_delta_weight_miss(active_adapter, self.base_layer.weight.data)
orig_weight = delta_weight
else:
delta_weight = self.get_delta_weight_miss(active_adapter, self.base_layer.weight.data)
orig_weight = delta_weight
if not torch.isfinite(orig_weight).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = orig_weight.to(orig_dtype)
else:
if self.miss_fn == "bat":
delta_weight = self.get_delta_weight(active_adapter, self.base_layer.weight.data)
base_layer.weight.data += delta_weight.to(orig_dtype)
elif self.miss_fn == "mini":
delta_weight = self.get_delta_weight_miss(active_adapter, self.base_layer.weight.data)
base_layer.weight.data = delta_weight.to(orig_dtype)
else:
delta_weight = self.get_delta_weight_miss(active_adapter, self.base_layer.weight.data)
base_layer.weight.data = delta_weight.to(orig_dtype)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if active_adapter in self.miss_block.keys():
orig_weight = self.get_base_layer().weight.data.clone()
if self.miss_fn == "bat":
delta_weight = self.get_delta_weight(active_adapter, orig_weight, re=True)
elif self.miss_fn == "mini":
delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True)
else:
delta_weight = self.get_delta_weight_miss(active_adapter, orig_weight, re=True)
base_layer.weight.data = delta_weight.to(orig_dtype)
def get_delta_weight(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.
Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.miss_block[adapter].device
dtype = self.miss_block[adapter].dtype
# In case users want to merge the adapter weights that are in
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
weight_miss = self.miss_block[adapter]
if cast_to_fp32:
weight_miss = weight_miss.float()
orig_weight = orig_weight.to(weight_miss.dtype)
r = weight_miss.size(-1)
if re:
o = orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
one = torch.eye(weight_miss.size(-1)).to(weight_miss.device)
# inverse must be in float32, after that the dtype can be adjusted if needed
inv_I_plus_b = torch.inverse(one + weight_miss)
inv_I_plus_b = inv_I_plus_b.to(weight_miss.dtype)
w = (o - weight_miss) @ inv_I_plus_b
output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
else:
w = (
orig_weight.reshape(orig_weight.size(0) // r, r, orig_weight.size(1) // r, r).permute(2, 0, 1, 3)
@ weight_miss
+ weight_miss
)
output_tensor = w.permute(1, 2, 0, 3).reshape(*orig_weight.shape)
if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)
# cast back the weights
self.miss_block[adapter].data = weight_miss.to(dtype)
return output_tensor
def get_delta_weight_miss(self, adapter, orig_weight, re: bool = False) -> torch.Tensor:
"""
Compute the delta weight for the given adapter.
Args:
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
device = self.miss_block[adapter].device
dtype = self.miss_block[adapter].dtype
# In case users want to merge the adapter weights that are in
# (b)float16 while being on CPU, we need to cast the weights to float32, perform the merge and then cast back to
# (b)float16 because some CPUs have slow bf16/fp16 matmuls.
cast_to_fp32 = device.type == "cpu" and (dtype == torch.float16 or dtype == torch.bfloat16)
weight_miss = self.miss_block[adapter]
if cast_to_fp32:
weight_miss = weight_miss.float()
in_features = orig_weight.size(-1)
out_features = orig_weight.size(0)
r = weight_miss.size(0)
if self.miss_fn == "mini":
weight_miss = weight_miss.repeat(1, out_features // self.miss_mini_r[adapter])
if in_features % r != 0:
last_size = in_features % r
n_block = in_features // r
n_block_size = n_block * r
if re:
orig_weight[:, :n_block_size] = (
(orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) - weight_miss)
.permute(2, 0, 1)
.reshape(*orig_weight[:, :n_block_size].shape)
)
orig_weight[:, n_block_size:] = (
orig_weight[:, n_block_size:] - (weight_miss.transpose(0, 1))[:, :last_size]
)
else:
orig_weight[:, :n_block_size] = (
(orig_weight[:, :n_block_size].reshape(-1, n_block, r).permute(1, 2, 0) + weight_miss)
.permute(2, 0, 1)
.reshape(*orig_weight[:, :n_block_size].shape)
)
orig_weight[:, n_block_size:] = (
orig_weight[:, n_block_size:] + (weight_miss.transpose(0, 1))[:, :last_size]
)
output_tensor = orig_weight
else:
if re:
w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) - weight_miss
output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)
else:
w = orig_weight.reshape(-1, orig_weight.size(1) // r, r).permute(1, 2, 0) + weight_miss
output_tensor = w.permute(2, 0, 1).reshape(*orig_weight.shape)
if cast_to_fp32:
output_tensor = output_tensor.to(dtype=dtype)
# cast back the weights
self.miss_block[adapter].data = weight_miss.to(dtype)
return output_tensor
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
if self.miss_fn == "bat":
orig_weight = self.base_layer.weight.data.clone()
for active_adapter in self.active_adapters:
if active_adapter not in self.miss_block.keys():
continue
delta_weight = self.get_delta_weight(active_adapter, orig_weight)
orig_weight = orig_weight + delta_weight
x = self._cast_input_dtype(x, orig_weight.dtype)
bias = self._cast_input_dtype(self.base_layer.bias, orig_weight.dtype)
result = F.linear(input=x, weight=orig_weight, bias=bias)
else:
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.miss_block.keys():
continue
miss = self.miss_block[active_adapter]
if self.miss_fn == "mini":
miss = miss.repeat(1, self.base_layer.out_features // self.miss_mini_r[active_adapter])
dropout = self.miss_dropout[active_adapter]
r = miss.size(0)
if x.size(-1) % r != 0:
padding_size = (r - x.size(-1) % r) % r
x = F.pad(x, (0, padding_size))
x = self._cast_input_dtype(x, miss.dtype)
result = result + torch.sum(dropout(x).reshape(*x.shape[:-1], x.size(-1) // r, r), dim=-2) @ miss
result = result.to(previous_dtype)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "miss." + rep

View File

@ -0,0 +1,341 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import asdict
from enum import Enum
from typing import Optional
import torch
from torch import nn
from tqdm import tqdm
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
_get_submodules,
)
from .config import MissConfig
from .layer import MissLayer, MissLinear
class MissModel(BaseTuner):
"""
Creates a MiSS (Matrix Shard Sharing) model from a pretrained model. The method is described in
https://huggingface.co/papers/2409.15371
Args:
model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached.
config ([`MissConfig`]): The configuration of the MiSS model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
Create empty adapter weights on meta device. Useful to speed up the loading process.
Returns:
`torch.nn.Module`: The MiSS model.
Example:
```py
>>> from diffusers import StableDiffusionPipeline
>>> from peft import MissModel, MissConfig
>>> config_te = MissConfig(
... r=8,
... target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
... init_weights=True,
... )
>>> config_unet = MissConfig(
... r=8,
... target_modules=[
... "proj_in",
... "proj_out",
... "to_k",
... "to_q",
... "to_v",
... "to_out.0",
... "ff.net.0.proj",
... "ff.net.2",
... ],
... init_weights=True,
... )
>>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> model.text_encoder = MissModel(model.text_encoder, config_te, "default")
>>> model.unet = MissModel(model.unet, config_unet, "default")
```
**Attributes**:
- **model** ([`~torch.nn.Module`]) -- The model to be adapted.
- **peft_config** ([`MissConfig`]): The configuration of the MiSS model.
"""
prefix: str = "miss_"
def _check_new_adapter_config(self, config: MissConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
"""
# TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check
# does not fully correspond to the error message.
if (len(self.peft_config) > 1) and (config.bias != "none"):
raise ValueError(
f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, "
"set bias to 'none' for all adapters."
)
@staticmethod
def _check_target_module_exists(miss_config, key):
return check_target_module_exists(miss_config, key)
def _create_and_replace(
self,
miss_config,
adapter_name,
target,
target_name,
parent,
current_key,
**optional_kwargs,
):
if current_key is None:
raise ValueError("Current Key shouldn't be `None`")
bias = hasattr(target, "bias") and target.bias is not None
kwargs = {
"r": miss_config.r,
"mini_r": miss_config.mini_r,
"miss_dropout": miss_config.miss_dropout,
"init_weights": miss_config.init_weights,
}
kwargs["bias"] = bias
# If it is not a MissLayer, create a new module, else update it with new adapters
if not isinstance(target, MissLayer):
new_module = self._create_new_module(miss_config, adapter_name, target, **kwargs)
if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)
else:
target.update_layer(
adapter_name,
r=miss_config.r,
init_weights=miss_config.init_weights,
miss_dropout=miss_config.miss_dropout,
mini_r=miss_config.mini_r,
)
def _replace_module(self, parent, child_name, new_module, child):
setattr(parent, child_name, new_module)
# It's not necessary to set requires_grad here, as that is handled by
# _mark_only_adapters_as_trainable
# child layer wraps the original module, unpack it
if hasattr(child, "base_layer"):
child = child.base_layer
if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias
if getattr(child, "state", None) is not None:
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device)
meta = torch.device("meta")
# dispatch to correct device
for name, module in new_module.named_modules():
if self.prefix in name:
if not any(p.device == meta for p in module.parameters()):
module.to(child.weight.device)
def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
for n, p in model.named_parameters():
if self.prefix not in n:
p.requires_grad = False
for active_adapter in self.active_adapters:
bias = self.peft_config[active_adapter].bias
if bias == "none":
continue
if bias == "all":
for n, p in model.named_parameters():
if "bias" in n:
p.requires_grad = True
elif bias == "miss_only":
for name, m in model.named_modules():
if isinstance(m, MissLayer) and hasattr(m, "bias") and m.bias is not None:
m.bias.requires_grad = True
else:
raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")
@staticmethod
def _create_new_module(miss_config, adapter_name, target, **kwargs):
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if isinstance(target_base_layer, torch.nn.Linear):
new_module = MissLinear(target, adapter_name, **kwargs)
else:
raise ValueError(
f"Target module {target} is not supported. Currently, only `torch.nn.Linear` is supported."
)
return new_module
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
if name == "base_model":
raise
return getattr(self.model, name)
def get_peft_config_as_dict(self, inference: bool = False):
config_dict = {}
for key, value in self.peft_config.items():
config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
if inference:
config["inference_mode"] = True
config_dict[key] = config
return config_dict
def _set_adapter_layers(self, enabled=True):
for module in self.model.modules():
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
module.enable_adapters(enabled)
def enable_adapter_layers(self):
self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self):
for active_adapter in self.active_adapters:
val = self.peft_config[active_adapter].bias
if val != "none":
msg = (
f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
"output as the base model would without adaption."
)
warnings.warn(msg)
self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name):
for module in self.model.modules():
if isinstance(module, MissLayer):
if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge()
module.set_adapter(adapter_name)
self.active_adapter = adapter_name
@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = set(
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config
def _unload_and_optionally_merge(
self,
merge=True,
progressbar: bool = False,
safe_merge: bool = False,
adapter_names: Optional[list[str]] = None,
):
self._unloading_checks(adapter_names)
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
desc = "Unloading " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self.model, key)
except AttributeError:
continue
if hasattr(target, "base_layer"):
if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
return self.model
def delete_adapter(self, adapter_name: str) -> None:
"""
Deletes an existing adapter.
Args:
adapter_name (str): Name of the adapter to be deleted.
"""
if adapter_name not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, MissLayer):
target.delete_adapter(adapter_name)
if new_adapter is None:
new_adapter = target.active_adapters[:]
self.active_adapter = new_adapter or []
self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)
def merge_and_unload(
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> torch.nn.Module:
r"""
This method merges the MiSS layers into the base model. This is needed if someone wants to use the base model
as a standalone model.
Args:
progressbar (`bool`):
whether to show a progressbar indicating the unload and merge process
safe_merge (`bool`):
whether to activate the safe merging check to check if there is any potential Nan in the adapter
weights
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
return self._unload_and_optionally_merge(
progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
)
def unload(self) -> torch.nn.Module:
"""
Gets back the base model by removing all the miss modules without merging. This gives back the original base
model.
"""
return self._unload_and_optionally_merge(merge=False)

View File

@ -40,6 +40,7 @@ class PeftType(str, enum.Enum):
- FOURIERFT
- HRA
- BONE
- MISS
- RANDLORA
- SHIRA
- C3A
@ -66,6 +67,7 @@ class PeftType(str, enum.Enum):
VBLORA = "VBLORA"
CPT = "CPT"
BONE = "BONE"
MISS = "MISS"
RANDLORA = "RANDLORA"
TRAINABLE_TOKENS = "TRAINABLE_TOKENS"
SHIRA = "SHIRA"

View File

@ -43,6 +43,7 @@ from peft import (
LoHaConfig,
LoKrConfig,
LoraConfig,
MissConfig,
OFTConfig,
PeftModel,
RandLoraConfig,
@ -444,6 +445,22 @@ TEST_CASES = [
BoneConfig,
{"target_modules": ["lin0"], "modules_to_save": ["lin1"], "r": 2, "init_weights": "bat"},
),
########
# MiSS #
########
("Vanilla MLP 1 MiSS", "MLP", MissConfig, {"target_modules": "lin0", "r": 2}),
("Vanilla MLP 2 MiSS", "MLP", MissConfig, {"target_modules": ["lin0"], "r": 2}),
("Vanilla MLP 3 MiSS", "MLP", MissConfig, {"target_modules": ["lin0", "lin1"], "r": 2}),
("Vanilla MLP 5 MiSS", "MLP", MissConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"], "r": 2}),
("Vanilla MLP 1 MiSS", "MLP", MissConfig, {"target_modules": "lin0", "r": 2, "init_weights": "bat"}),
("Vanilla MLP 2 MiSS", "MLP", MissConfig, {"target_modules": ["lin0"], "r": 2, "init_weights": "bat"}),
("Vanilla MLP 3 MiSS", "MLP", MissConfig, {"target_modules": ["lin0", "lin1"], "r": 2, "init_weights": "bat"}),
(
"Vanilla MLP 5 MiSS",
"MLP",
MissConfig,
{"target_modules": ["lin0"], "modules_to_save": ["lin1"], "r": 2, "init_weights": "bat"},
),
#############
# LN Tuning #
#############
@ -853,6 +870,21 @@ MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [
{"target_modules": ["lin0"], "init_weights": False, "r": 2},
{"target_modules": ["lin1"], "init_weights": False, "r": 2},
),
(
"MiSS Same",
"miss",
MissConfig,
{"target_modules": ["lin0"], "init_weights": False, "r": 2},
{"target_modules": ["lin0"], "init_weights": False, "r": 2},
),
(
"MiSS Different",
"miss",
MissConfig,
{"target_modules": ["lin0"], "init_weights": False, "r": 2},
{"target_modules": ["lin1"], "init_weights": False, "r": 2},
),
# Not testing "mini" initialization targeting the same layer, because The matrix is initialized to all zeros in MiSS-mini mode.
(
"VBLoRA Same",
"vblora",
@ -899,6 +931,7 @@ PREFIXES = {
ShiraConfig: "shira_",
VBLoRAConfig: "vblora_",
BoneConfig: "bone_",
MissConfig: "miss_",
TrainableTokensConfig: "trainable_tokens_",
}
@ -2202,7 +2235,7 @@ class TestPeftCustomModel(PeftCommonTester):
assert "other" in model.base_model.classifier.modules_to_save
@pytest.mark.parametrize(
"config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig, ShiraConfig]
"config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig, ShiraConfig, MissConfig]
)
def test_multiple_adapters_mixed_modules_to_save(self, config_cls):
# See issue 1574
@ -2211,7 +2244,7 @@ class TestPeftCustomModel(PeftCommonTester):
if hasattr(config_cls, "feedforward_modules"): # IA³
config_cls = partial(config_cls, feedforward_modules=["lin0"])
if config_cls == BoneConfig:
if config_cls == BoneConfig or config_cls == MissConfig:
config_cls = partial(config_cls, r=2)
if config_cls == ShiraConfig:
config_cls = partial(config_cls, r=1)
@ -2242,7 +2275,7 @@ class TestPeftCustomModel(PeftCommonTester):
if hasattr(config_cls, "feedforward_modules"): # IA³
config_cls = partial(config_cls, feedforward_modules=["lin0"])
if config_cls == BoneConfig:
if config_cls == BoneConfig or config_cls == MissConfig:
config_cls = partial(config_cls, r=2)
if config_cls == ShiraConfig:
config_cls = partial(config_cls, r=1)
@ -2324,7 +2357,9 @@ class TestPeftCustomModel(PeftCommonTester):
with pytest.raises(ValueError, match=msg):
model.add_weighted_adapter(["default", "other"], weights=[1.0, 1.0], adapter_name="merged")
@pytest.mark.parametrize("config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig])
@pytest.mark.parametrize(
"config_cls", [IA3Config, LoHaConfig, LoKrConfig, LoraConfig, HRAConfig, BoneConfig, MissConfig]
)
def test_add_weighted_adapter_cat_with_rank_pattern(self, config_cls):
# Fixes a bug described in #2512, which resulted from the rank_pattern not being taken into account
config0 = LoraConfig(target_modules=["lin0", "lin1"], r=8, rank_pattern={"lin0": 2})
@ -2490,6 +2525,7 @@ class TestPeftCustomModel(PeftCommonTester):
BOFTConfig(target_modules=["lin0"], init_weights=False, boft_block_size=2),
HRAConfig(target_modules=["lin0"], init_weights=False),
BoneConfig(target_modules=["lin0"], init_weights=False, r=2),
MissConfig(target_modules=["lin0"], init_weights=False, r=2),
],
)
def test_adapter_name_makes_no_difference(self, config0):
@ -3886,6 +3922,83 @@ class TestRequiresGrad:
"base_model.model.lin0.bone_block.adapter1",
)
def test_requires_grad_miss_different_targets(self):
# test two different MiSS adapters that target different modules
config0 = MissConfig(target_modules=["lin0"], r=2)
peft_model = get_peft_model(MLP(), config0)
config1 = MissConfig(target_modules=["lin1"], r=2, inference_mode=True)
peft_model.add_adapter("adapter1", config1)
# active adapter is still "default"
self.check_requires_grad(
peft_model,
"base_model.model.lin0.miss_block.default",
)
# set config0 as active, should not change anything
peft_model.set_adapter("default")
self.check_requires_grad(
peft_model,
"base_model.model.lin0.miss_block.default",
)
# change active adapter to adapter1
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.lin1.miss_block.adapter1",
)
# disable all adapters
with peft_model.disable_adapter():
self.check_requires_grad(peft_model)
# after context is exited, return to the previous state
self.check_requires_grad(
peft_model,
"base_model.model.lin1.miss_block.adapter1",
)
def test_requires_grad_miss_same_targets(self):
# same as previous test, except that the MiSS adapters target the same layer
config0 = MissConfig(target_modules=["lin0"], r=2)
peft_model = get_peft_model(MLP(), config0)
config1 = MissConfig(target_modules=["lin0"], r=2, inference_mode=True)
peft_model.add_adapter("adapter1", config1)
# active adapter is still "default"
self.check_requires_grad(
peft_model,
"base_model.model.lin0.miss_block.default",
)
# set config0 as active, should not change anything
peft_model.set_adapter("default")
self.check_requires_grad(
peft_model,
"base_model.model.lin0.miss_block.default",
)
# change active adapter to adapter1
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.lin0.miss_block.adapter1",
)
# disable all adapters
with peft_model.disable_adapter():
self.check_requires_grad(peft_model)
# after context is exited, return to the previous state
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.lin0.miss_block.adapter1",
)
def test_requires_grad_boft_different_targets(self):
# test two different OFT adapters that target different modules
config0 = BOFTConfig(target_modules=["lin0"], boft_block_size=2)

View File

@ -36,6 +36,7 @@ from peft import (
HRAConfig,
IA3Config,
LoraConfig,
MissConfig,
OFTConfig,
PrefixTuningConfig,
PromptEncoderConfig,
@ -97,6 +98,14 @@ ALL_CONFIGS = [
"r": 2,
},
),
(
MissConfig,
{
"task_type": "CAUSAL_LM",
"target_modules": None,
"r": 2,
},
),
(
CPTConfig,
{
@ -233,8 +242,9 @@ def _skip_if_not_conv1d_supported(model_id, config_cls):
OFTConfig,
ShiraConfig,
C3AConfig,
MissConfig,
]:
pytest.skip("Skipping BOFT/HRA/OFT/Bone/SHiRA/C3A for GPT2LMHeadModel")
pytest.skip("Skipping BOFT/HRA/OFT/Bone/SHiRA/C3A/MiSS for GPT2LMHeadModel")
def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls):
@ -245,8 +255,9 @@ def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls):
OFTConfig,
BoneConfig,
C3AConfig,
MissConfig,
]:
pytest.skip("Skipping AdaLora/BOFT/HRA/OFT/Bone for GPT2LMHeadModel")
pytest.skip("Skipping AdaLora/BOFT/HRA/OFT/Bone/MiSS for GPT2LMHeadModel")
class TestDecoderModels(PeftCommonTester):

View File

@ -26,6 +26,7 @@ from peft import (
HRAConfig,
IA3Config,
LoraConfig,
MissConfig,
OFTConfig,
PrefixTuningConfig,
PromptEncoderConfig,
@ -71,6 +72,14 @@ ALL_CONFIGS = [
"task_type": "SEQ_2_SEQ_LM",
},
),
(
MissConfig,
{
"target_modules": None,
"r": 2,
"task_type": "SEQ_2_SEQ_LM",
},
),
(
FourierFTConfig,
{

View File

@ -24,6 +24,7 @@ from peft import (
HRAConfig,
IA3Config,
LoraConfig,
MissConfig,
OFTConfig,
PrefixTuningConfig,
PromptEncoderConfig,
@ -70,6 +71,14 @@ ALL_CONFIGS = [
"r": 2,
},
),
(
MissConfig,
{
"task_type": "FEATURE_EXTRACTION",
"target_modules": None,
"r": 2,
},
),
(
FourierFTConfig,
{

View File

@ -24,6 +24,7 @@ from peft import (
HRAConfig,
IA3Config,
LoraConfig,
MissConfig,
OFTConfig,
PrefixTuningConfig,
PromptEncoderConfig,
@ -70,6 +71,14 @@ ALL_CONFIGS = [
"r": 2,
},
),
(
MissConfig,
{
"task_type": "SEQ_CLS",
"target_modules": None,
"r": 2,
},
),
(
FourierFTConfig,
{

View File

@ -45,6 +45,7 @@ from peft import (
LoHaConfig,
LoKrConfig,
LoraConfig,
MissConfig,
OFTConfig,
PeftModel,
TaskType,
@ -86,6 +87,15 @@ SETTINGS = {
BoneConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"], r=2, init_weights="bat"),
{},
),
"miss": (MissConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"], r=2), {}),
"miss-bat": (
MissConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"], r=2, init_weights="bat"),
{},
),
"miss-mini": (
MissConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"], r=2, init_weights="mini"),
{},
),
}

View File

@ -43,6 +43,7 @@ from peft import (
LoHaConfig,
LoKrConfig,
LoraConfig,
MissConfig,
OFTConfig,
PeftModel,
PeftType,
@ -134,6 +135,11 @@ CONFIG_TESTING_KWARGS = (
"target_modules": None,
"r": 2,
},
# MiSS
{
"target_modules": None,
"r": 2,
},
# LoRA + trainable_tokens
{
"r": 8,
@ -175,6 +181,7 @@ CLASSES_MAPPING = {
"vblora": (VBLoRAConfig, CONFIG_TESTING_KWARGS[10]),
"oft": (OFTConfig, CONFIG_TESTING_KWARGS[11]),
"bone": (BoneConfig, CONFIG_TESTING_KWARGS[12]),
"miss": (MissConfig, CONFIG_TESTING_KWARGS[12]),
"lora+trainable_tokens": (LoraConfig, CONFIG_TESTING_KWARGS[13]),
"randlora": (RandLoraConfig, CONFIG_TESTING_KWARGS[14]),
}
@ -832,6 +839,7 @@ class PeftCommonTester:
PeftType.BOFT,
PeftType.HRA,
PeftType.BONE,
PeftType.MISS,
]
if ("gpt2" in model_id.lower()) and (config_cls == IA3Config):
@ -1444,6 +1452,7 @@ class PeftCommonTester:
PeftType.HRA,
PeftType.VBLORA,
PeftType.BONE,
PeftType.MISS,
]
# IA3 does not support deleting adapters yet, but it just needs to be added
# AdaLora does not support multiple adapters
@ -1517,6 +1526,7 @@ class PeftCommonTester:
PeftType.HRA,
PeftType.VBLORA,
PeftType.BONE,
PeftType.MISS,
]
# IA3 does not support deleting adapters yet, but it just needs to be added
# AdaLora does not support multiple adapters
@ -1615,6 +1625,7 @@ class PeftCommonTester:
"SHIRA",
"BONE",
"C3A",
"MISS",
):
with pytest.raises(AttributeError):
model = model.unload()