Merge branch 'main' into smangrul/fix-fsdp-auto-wrap

This commit is contained in:
Sourab Mangrulkar
2024-04-30 19:46:21 +05:30
21 changed files with 2032 additions and 7 deletions

View File

@ -14,6 +14,7 @@ jobs:
commit_sha: ${{ github.sha }}
package: peft
notebook_folder: peft_docs
custom_container: huggingface/transformers-doc-builder
secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }}
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

View File

@ -14,3 +14,4 @@ jobs:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: peft
custom_container: huggingface/transformers-doc-builder

View File

@ -102,6 +102,8 @@
title: Prefix tuning
- local: package_reference/prompt_tuning
title: Prompt tuning
- local: package_reference/layernorm_tuning
title: Layernorm tuning
- local: package_reference/vera
title: VeRA
title: Adapters

View File

@ -0,0 +1,34 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# LayerNorm Tuning
LayerNorm Tuning ([LN Tuning](https://huggingface.co/papers/2312.11420)) is a PEFT method that only fine-tunes the parameters of the LayerNorm layers in a model.
The paper evaluates the method on large language models and shows that it achieves strong performance while significantly reducing the number of trainable parameters and GPU memory usage.
However, the method is not limited to language models and can be applied to any model that uses LayerNorm layers.
In this implementation, all LayerNorm layers inside the model are fine-tuned by default, but the method can also target other layer types such as `MLP` or `Attention` layers; this is done by specifying `target_modules` in the `LNTuningConfig`.
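A minimal usage sketch, mirroring the example in the [`LNTuningModel`] docstring (the checkpoint name is only illustrative):

```py
from transformers import AutoModelForCausalLM
from peft import LNTuningConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# with no target_modules, the LayerNorm layers are selected based on the known model architecture
peft_config = LNTuningConfig(task_type=TaskType.CAUSAL_LM)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
```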
The abstract from the paper is:
*This paper introduces an efficient strategy to transform Large Language Models (LLMs) into Multi-Modal Large Language Models (MLLMs). By conceptualizing this transformation as a domain adaptation process, i.e., transitioning from text understanding to embracing multiple modalities, we intriguingly note that, within each attention block, tuning LayerNorm suffices to yield strong performance. Moreover, when benchmarked against other tuning approaches like full parameter finetuning or LoRA, its benefits on efficiency are substantial. For example, when compared to LoRA on a 13B model scale, performance can be enhanced by an average of over 20% across five multi-modal tasks, and meanwhile, results in a significant reduction of trainable parameters by 41.9% and a decrease in GPU memory usage by 17.6%. On top of this LayerNorm strategy, we showcase that selectively tuning only with conversational data can improve efficiency further. Beyond these empirical outcomes, we provide a comprehensive analysis to explore the role of LayerNorm in adapting LLMs to the multi-modal domain and improving the expressive power of the model.*
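Because each adapter stores a full copy of the tuned LayerNorm layers, merging simply swaps those copies into the base model. A short sketch, assuming the standard PEFT save and merge API (the output path is only illustrative):

```py
# save only the tuned LayerNorm weights (the "ln_tuning_" parameters)
model.save_pretrained("ln-tuning-adapter")

# swap the tuned LayerNorm copies into the base model and return the underlying transformers model
model = model.merge_and_unload()
```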
## LNTuningConfig
[[autodoc]] tuners.ln_tuning.config.LNTuningConfig
## LNTuningModel
[[autodoc]] tuners.ln_tuning.model.LNTuningModel

File diff suppressed because it is too large

View File

@ -77,6 +77,8 @@ from .tuners import (
OFTModel,
PolyConfig,
PolyModel,
LNTuningConfig,
LNTuningModel,
VeraConfig,
VeraModel,
)

View File

@ -37,6 +37,8 @@ from .tuners import (
BOFTModel,
IA3Config,
IA3Model,
LNTuningConfig,
LNTuningModel,
LoHaConfig,
LoHaModel,
LoKrConfig,
@ -85,6 +87,7 @@ PEFT_TYPE_TO_CONFIG_MAPPING: dict[str, type[PeftConfig]] = {
"MULTITASK_PROMPT_TUNING": MultitaskPromptTuningConfig,
"OFT": OFTConfig,
"POLY": PolyConfig,
"LN_TUNING": LNTuningConfig,
"VERA": VeraConfig,
}
@ -97,6 +100,7 @@ PEFT_TYPE_TO_TUNER_MAPPING: dict[str, type[_BaseTuner]] = {
"IA3": IA3Model,
"OFT": OFTModel,
"POLY": PolyModel,
"LN_TUNING": LNTuningModel,
"VERA": VeraModel,
}

View File

@ -44,6 +44,7 @@ from .tuners import (
AdaptionPromptModel,
BOFTModel,
IA3Model,
LNTuningModel,
LoHaModel,
LoKrModel,
LoraModel,
@ -88,6 +89,7 @@ PEFT_TYPE_TO_MODEL_MAPPING = {
PeftType.IA3: IA3Model,
PeftType.OFT: OFTModel,
PeftType.POLY: PolyModel,
PeftType.LN_TUNING: LNTuningModel,
PeftType.VERA: VeraModel,
}

View File

@ -31,4 +31,5 @@ from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTu
from .oft import OFTConfig, OFTModel
from .mixed import MixedModel
from .poly import PolyConfig, PolyModel
from .ln_tuning import LNTuningConfig, LNTuningModel
from .vera import VeraConfig, VeraModel

View File

@ -0,0 +1,19 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config import LNTuningConfig
from .model import LNTuningModel
__all__ = ["LNTuningConfig", "LNTuningModel"]

View File

@ -0,0 +1,61 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Union
from peft.config import PeftConfig
from peft.utils import PeftType
@dataclass
class LNTuningConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a :class:`~peft.tuners.LNTuningModel`.
Args:
target_modules (`Optional[Union[List[str], str]]`):
List of module names or regex expression of the module names to replace with LNTuning. For example,
'.*decoder.*' or '.*encoder.*'. If this is not specified, modules will be chosen according to the model
architecture. If the architecture is not known, an error will be raised -- in this case, you should specify
the target modules manually.
modules_to_save (`Optional[Union[List[str], str]]`):
List of modules to be set as trainable and saved in the final checkpoint. For example, in Sequence
Classification or Token Classification tasks, the final layer `classifier/score` is randomly initialized
and as such needs to be trainable and saved.
"""
target_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={
"help": (
"List of module names or regex expression of the module names to replace with LNTuning."
"For example, '.*decoder.*' or '.*encoder.*'. "
"If not specified, modules will be chosen according to the model architecture, If the architecture is "
"not known, an error will be raised -- in this case, you shoud specify the target modules manually."
),
},
)
modules_to_save: Optional[Union[list[str], str]] = field(
default=None,
metadata={
"help": "List of modules to be set as trainable and saved in the final checkpoint. "
"For example, in Sequence Classification or Token Classification tasks, "
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
},
)
def __post_init__(self):
self.peft_type = PeftType.LN_TUNING

View File

@ -0,0 +1,117 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from copy import deepcopy
from typing import List, Optional
import torch
import torch.nn as nn
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
class LNTuningLayer(nn.Module, BaseTunerLayer):
"""
Wraps a base layer (typically a LayerNorm) and keeps a trainable per-adapter copy of it.
"""
adapter_layer_names = ("ln_tuning_layers",)
def __init__(self, base_layer: nn.Module, adapter_name: str):
super().__init__()
self.base_layer = base_layer
self.ln_tuning_layers = nn.ModuleDict({})
self.update_layer(self.base_layer, adapter_name)
self._active_adapter = adapter_name
self.merged_adapters = []
def update_layer(self, layer: nn.Module, adapter_name: str):
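# each adapter keeps its own trainable deep copy of the wrapped layer; the base layer itself is left untouched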
self.ln_tuning_layers[adapter_name] = deepcopy(layer)
def enable_adapters(self, enabled: bool) -> None:
"""Toggle the enabling and disabling of adapters
Takes care of setting the requires_grad flag for the adapter weights.
Args:
enabled (bool): True to enable adapters, False to disable adapters
"""
if enabled:
self.set_adapter(self.active_adapters)
self._disable_adapters = False
else:
if self.merged:
self.unmerge()
# disable grads on all adapter layers
for layer_name in self.adapter_layer_names:
layer = getattr(self, layer_name)
layer.requires_grad_(False)
self._disable_adapters = True
def merge(self, adapter_names: Optional[List[str]] = None):
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
if len(adapter_names) > 1:
raise ValueError(
f"Trying to merge {len(adapter_names)} adapters, but LN "
f"tuning does not allow merging more than one adapter at a time"
)
merged_adapters = set(self.merged_adapters)
if merged_adapters:
warnings.warn(f"Already merged with {merged_adapters}. Unmerging first.")
self.unmerge()
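# merging is a simple swap: the adapter's tuned layer becomes the base layer and the original base layer is stored in its place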
self.base_layer, self.ln_tuning_layers[adapter_names[0]] = (
self.ln_tuning_layers[adapter_names[0]],
self.base_layer,
)
self.merged_adapters.append(adapter_names[0])
def unmerge(self):
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
# popping one element is sufficient because LN
# tuning does not allow merging more than one adapter at a time.
merged_name = self.merged_adapters.pop()
self.base_layer, self.ln_tuning_layers[merged_name] = (
self.ln_tuning_layers[merged_name],
self.base_layer,
)
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
if len(self.active_adapters) != 1:
raise ValueError(
f"Trying to run forward with {len(self.active_adapters)} active "
f"adapters, but LN tuning does not allow inference with more than one adapter at a time"
)
active_adapter = self.active_adapters[0]
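# route the input through the active adapter's copy of the layer instead of the frozen base layer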
result = self.ln_tuning_layers[active_adapter](x, *args, **kwargs)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "ln_tuning." + rep

View File

@ -0,0 +1,201 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from typing import Optional
from torch import nn
from torch.nn.modules import Module
from tqdm import tqdm
from peft.config import PeftConfig
from peft.tuners.tuners_utils import BaseTuner, _get_submodules, check_target_module_exists
from peft.utils import TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, ModulesToSaveWrapper
from .layer import LNTuningLayer
class LNTuningModel(BaseTuner):
"""
Creates a LayerNorm tuning (LN Tuning) model from a pretrained transformer model.
The method is described in detail in https://arxiv.org/abs/2312.11420.
Args:
model ([`torch.nn.Module`]): The model to be adapted.
config ([`LNTuningConfig`]): The configuration of the LN tuning model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
Returns:
`torch.nn.Module`: The adapted model with the LayerNorm layers set up for tuning.
Example:
```py
>>> from transformers import AutoModelForCausalLM
>>> from peft import get_peft_model, TaskType, LNTuningConfig
>>> peft_config = LNTuningConfig(
... task_type=TaskType.CAUSAL_LM,
... )
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> model = get_peft_model(model, peft_config)
>>> model.print_trainable_parameters()
```
**Attributes**:
- **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
- **peft_config** ([`LNTuningConfig`]): The configuration of the LN tuning model.
"""
prefix: str = "ln_tuning_"
def __init__(self, model, config, adapter_name) -> None:
# self.adapter_name = adapter_name
super().__init__(model, config, adapter_name)
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
return getattr(self.model, name)
# TODO: here we need to handle modules_to_save rather than target_modules
@staticmethod
def _prepare_adapter_config(peft_config: PeftConfig, model_config: dict) -> PeftConfig:
if peft_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
peft_config.target_modules = set(
TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
return peft_config
def _create_and_replace(
self,
peft_config: PeftConfig,
adapter_name: str,
target: Module,
target_name: str,
parent: Module,
current_key: str,
) -> None:
# wrap the original module in an LNTuningLayer (or register a new adapter on an existing one) and swap it into the parent
new_module = self._create_new_module(peft_config, target, adapter_name)
if adapter_name != self.active_adapter:
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)
def _create_new_module(
self,
peft_config: PeftConfig,
target: Module,
adapter_name: str,
) -> Module:
if not isinstance(target, LNTuningLayer):
new_module = LNTuningLayer(target, adapter_name)
else:
new_module = target
new_module.update_layer(target.base_layer, adapter_name)
return new_module
def _replace_module(self, parent: Module, child_name: str, new_module: Module, child: Module) -> None:
setattr(parent, child_name, new_module)
if hasattr(child, "base_layer"):
child = child.base_layer
if getattr(child, "state", None) is not None:
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device)
for name, module in new_module.named_modules():
weight = child.qweight if hasattr(child, "qweight") else child.weight
module.to(weight.device)
def _mark_only_adapters_as_trainable(self, model: Module):
for n, p in model.named_parameters():
if self.prefix not in n:
p.requires_grad = False
else:
p.requires_grad = True
def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool:
return check_target_module_exists(peft_config, key)
def _set_adapter_layers(self, enabled: bool) -> None:
for module in self.model.modules():
if isinstance(module, (LNTuningLayer, ModulesToSaveWrapper)):
module.enable_adapters(enabled)
def enable_adapter_layers(self) -> None:
"""Enable all adapters.
Call this if you have previously disabled all adapters and want to re-enable them.
"""
self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self) -> None:
"""Disable all adapters.
When disabling all adapters, the model output corresponds to the output of the base model.
"""
self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: str) -> None:
for module in self.model.modules():
if isinstance(module, LNTuningLayer):
if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge()
module.set_adapter(adapter_name)
self.active_adapter = adapter_name
def _unload_and_optionally_merge(
self,
merge=True,
progressbar: bool = False,
safe_merge: bool = False,
adapter_names: Optional[list[str]] = None,
):
self._unloading_checks(adapter_names)
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
desc = "Unloading adapters " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self.model, key)
except AttributeError:
continue
if hasattr(target, "base_layer"):
if merge:
target.merge(adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
return self.model
def unload(self):
return self._unload_and_optionally_merge(merge=False)
def merge_and_unload(
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> nn.Module:
return self._unload_and_optionally_merge(merge=True, progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names)

View File

@ -26,6 +26,7 @@ from .other import (
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING,
CONFIG_NAME,
WEIGHTS_NAME,

View File

@ -44,6 +44,28 @@ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
"gpt_bigcode": starcoder_model_postprocess_past_key_value,
}
TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING = {
"llama": ["input_layernorm", "post_attention_layernorm", "norm"],
"bloom": ["input_layernorm", "post_attention_layernorm", "ln_f"],
"llava": [
"multi_modal_projector",
"input_layernorm",
"post_attention_layernorm",
"norm",
"embed_tokens",
"lm_head",
],
"t5": ["layer_norm", "final_layer_norm"],
"mt5": ["layer_norm", "final_layer_norm"],
"bart": ["self_attn_layer_norm", "encoder_attn_layer_norm", "final_layer_norm"],
"gpt2": ["ln_1", "ln_2", "ln_f"],
"blip-2": ["layernorm", "LayerNorm", "final_layer_norm", "self_attn_layer_norm"],
"gptj": ["ln_1", "ln_f"],
"falcon": ["input_layernorm", "post_attention_layernorm", "ln_f"],
"mistral": ["input_layernorm", "post_attention_layernorm", "norm"],
"phi": ["input_layernorm", "final_layernorm"],
"gemma": ["input_layernorm", "post_attention_layernorm", "norm"],
}
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
"t5": ["q", "v"],

View File

@ -35,6 +35,7 @@ from .constants import (
TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING,
@ -53,6 +54,7 @@ __all__ = [
"TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
"TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING",
"TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING",
"WEIGHTS_NAME",
"INCLUDE_LINEAR_LAYERS_SHORTHAND",

View File

@ -30,11 +30,14 @@ class PeftType(str, enum.Enum):
- PREFIX_TUNING
- LORA
- ADALORA
- BOFT
- ADAPTION_PROMPT
- IA3
- LOHA
- LOKR
- OFT
- POLY
- LN_TUNING
"""
PROMPT_TUNING = "PROMPT_TUNING"
@ -50,6 +53,7 @@ class PeftType(str, enum.Enum):
LOKR = "LOKR"
OFT = "OFT"
POLY = "POLY"
LN_TUNING = "LN_TUNING"
VERA = "VERA"

View File

@ -146,6 +146,9 @@ def get_peft_model_state_dict(
elif config.peft_type == PeftType.POLY:
to_return = {k: state_dict[k] for k in state_dict if "poly_" in k}
elif config.peft_type == PeftType.LN_TUNING:
to_return = {k: state_dict[k] for k in state_dict if "ln_tuning_" in k}
elif config.peft_type == PeftType.VERA:
to_return = {k: state_dict[k] for k in state_dict if "vera_lambda_" in k}
if config.save_projection:
@ -289,6 +292,7 @@ def set_peft_model_state_dict(
PeftType.IA3,
PeftType.OFT,
PeftType.POLY,
PeftType.LN_TUNING,
PeftType.BOFT,
PeftType.VERA,
):
@ -302,6 +306,7 @@ def set_peft_model_state_dict(
PeftType.OFT: "oft_",
PeftType.POLY: "poly_",
PeftType.BOFT: "boft_",
PeftType.LN_TUNING: "ln_tuning_",
PeftType.VERA: "vera_lambda_",
}[config.peft_type]
for k, v in state_dict.items():

View File

@ -10,6 +10,11 @@ if os.environ.get("PEFT_DEBUG_WITH_TORCH_COMPILE") == "1":
import peft
from peft.mapping import get_peft_model as get_peft_model_original
# TODO: Experimental dynamo feature that should allow correct compilation of more PEFT modules. This should be
# removed once PyTorch has found a better solution, as this incurs a performance penalty.
# https://github.com/pytorch/pytorch/issues/124717#issuecomment-2083235776
torch._dynamo.config.guard_nn_modules = True
def get_peft_model_new(*args, **kwargs):
"""Make get_peft_model() return a compiled model."""
peft_model = get_peft_model_original(*args, **kwargs)

View File

@ -35,6 +35,7 @@ from peft import (
AdaLoraConfig,
BOFTConfig,
IA3Config,
LNTuningConfig,
LoHaConfig,
LoKrConfig,
LoraConfig,
@ -243,6 +244,19 @@ TEST_CASES = [
("Conv2d 3 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True}),
("Conv2d 4 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "block_share": True}),
("Conv2d 5 OFT", "Conv2d", OFTConfig, {"target_modules": ["conv2d"], "coft": True, "block_share": True}),
#############
# LN Tuning #
#############
("LayerNorm 1 LNTuning", "MLP_LayerNorm", LNTuningConfig, {"target_modules": "layernorm0"}),
("LayerNorm 2 LNTuning", "MLP_LayerNorm", LNTuningConfig, {"target_modules": ["layernorm0"]}),
(
"LayerNorm 3 LNTuning",
"MLP_LayerNorm",
LNTuningConfig,
{"target_modules": ["layernorm0"], "modules_to_save": ["layernorm1"]},
),
("Linear 4 LNTuning", "MLP_LayerNorm", LNTuningConfig, {"target_modules": "lin0"}),
("Linear 5 LNTuning", "MLP_LayerNorm", LNTuningConfig, {"target_modules": ["lin0"]}),
########
# BOFT #
########
@ -404,6 +418,7 @@ PREFIXES = {
LoKrConfig: "lokr_",
OFTConfig: "oft_",
BOFTConfig: "boft_",
LNTuningConfig: "ln_tuning_",
VeraConfig: "vera_lambda_",
}
@ -427,6 +442,29 @@ class MLP(nn.Module):
return X
class MLP_LayerNorm(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.layernorm0 = nn.LayerNorm(10, 10)
self.lin0 = nn.Linear(10, 20, bias=bias)
self.relu = nn.ReLU()
self.drop = nn.Dropout(0.5)
self.layernorm1 = nn.LayerNorm(20, 20)
self.lin1 = nn.Linear(20, 2, bias=bias)
self.sm = nn.LogSoftmax(dim=-1)
def forward(self, X):
X = X.float()
X = self.layernorm0(X)
X = self.lin0(X)
X = self.relu(X)
X = self.drop(X)
X = self.layernorm1(X)
X = self.lin1(X)
X = self.sm(X)
return X
class MLP2(nn.Module):
def __init__(self, bias=True):
super().__init__()
@ -592,6 +630,9 @@ class MockTransformerWrapper:
if model_id == "Conv2d":
return ModelConv2D().to(torch_dtype)
if model_id == "MLP_LayerNorm":
return MLP_LayerNorm().to(torch_dtype)
if model_id == "MLP2":
return MLP2().to(torch_dtype)
@ -645,6 +686,8 @@ class PeftCustomModelTester(unittest.TestCase, PeftCommonTester):
config_kwargs["init_lora_weights"] = False
elif issubclass(config_cls, IA3Config):
config_kwargs["init_ia3_weights"] = False
elif issubclass(config_cls, LNTuningConfig):
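# LNTuning does not take init_weights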
pass
else:
config_kwargs["init_weights"] = False
self._test_merge_layers(model_id, config_cls, config_kwargs)
@ -676,6 +719,9 @@ class PeftCustomModelTester(unittest.TestCase, PeftCommonTester):
config_kwargs["init_lora_weights"] = False
elif issubclass(config_cls, IA3Config):
config_kwargs["init_ia3_weights"] = False
elif issubclass(config_cls, LNTuningConfig):
# LNTuning does not take init_weights
pass
else:
config_kwargs["init_weights"] = False
self._test_safe_merge(model_id, config_cls, config_kwargs)
@ -825,6 +871,9 @@ class PeftCustomModelTester(unittest.TestCase, PeftCommonTester):
model.train()
# EmbConv1D is slow to learn for some reason
lr = 0.01 if model_id != "EmbConv1D" else 1.0
if issubclass(config_cls, LNTuningConfig):
# LayerNorm tuning is slow to learn
lr = 1.0
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
@ -846,7 +895,12 @@ class PeftCustomModelTester(unittest.TestCase, PeftCommonTester):
# check that after leaving the disable_adapter context, everything is enabled again
outputs_enabled_after_disable = model(**X)
assert not torch.allclose(outputs_before, outputs_after)
if self.torch_device == "cpu":
# LayerNorm runs in float32 on CPU, so the differences in outputs are smaller
rtol, atol = 1e-8, 1e-8
else:
rtol, atol = 1e-5, 1e-8
assert not torch.allclose(outputs_before, outputs_after, rtol=rtol, atol=atol)
assert torch.allclose(outputs_before, outputs_disabled)
assert torch.allclose(outputs_after, outputs_enabled_after_disable)
@ -864,9 +918,14 @@ class PeftCustomModelTester(unittest.TestCase, PeftCommonTester):
outputs_before = model(**X)
model.train()
lr = 0.01
# Adam optimizer since SGD isn't great for small models with IA3 + Conv1D
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
if issubclass(config_cls, LNTuningConfig):
# LayerNorm tuning is slow to learn
lr = 1.0
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
else:
# Adam optimizer since SGD isn't great for small models with IA3 + Conv1D
lr = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# train at least 3 steps for all parameters to be updated (probably this is required because of symmetry
# breaking of some LoRA layers that are initialized with constants)
@ -2544,6 +2603,97 @@ class RequiresGradTester(unittest.TestCase):
"base_model.model.lin1.boft_s.adapter1",
)
def test_requires_grad_lntuning_different_targets(self):
config0 = LNTuningConfig(
target_modules=["layernorm0"],
)
peft_model = get_peft_model(MLP_LayerNorm(), config0)
config1 = LNTuningConfig(
target_modules=["layernorm1"],
inference_mode=True,
)
peft_model.add_adapter("adapter1", config1)
# active adapter is still "default"
self.check_requires_grad(
peft_model,
"base_model.model.layernorm0.ln_tuning_layers.default.weight",
"base_model.model.layernorm0.ln_tuning_layers.default.bias",
)
# set config0 as active, should not change anything
peft_model.set_adapter("default")
self.check_requires_grad(
peft_model,
"base_model.model.layernorm0.ln_tuning_layers.default.weight",
"base_model.model.layernorm0.ln_tuning_layers.default.bias",
)
# change the active adapter to adapter1
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.layernorm1.ln_tuning_layers.adapter1.weight",
"base_model.model.layernorm1.ln_tuning_layers.adapter1.bias",
)
# disable all adapters
with peft_model.disable_adapter():
self.check_requires_grad(peft_model)
# after context is exited, return to the previous state
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.layernorm1.ln_tuning_layers.adapter1.weight",
"base_model.model.layernorm1.ln_tuning_layers.adapter1.bias",
)
def test_requires_grad_lntuning_same_targets(self):
config0 = LNTuningConfig(
target_modules=["layernorm0"],
)
peft_model = get_peft_model(MLP_LayerNorm(), config0)
config1 = LNTuningConfig(target_modules=["layernorm0"], inference_mode=True)
peft_model.add_adapter("adapter1", config1)
# active adapter is still "default"
self.check_requires_grad(
peft_model,
"base_model.model.layernorm0.ln_tuning_layers.default.weight",
"base_model.model.layernorm0.ln_tuning_layers.default.bias",
)
# set config0 as active, should not change anything
peft_model.set_adapter("default")
self.check_requires_grad(
peft_model,
"base_model.model.layernorm0.ln_tuning_layers.default.weight",
"base_model.model.layernorm0.ln_tuning_layers.default.bias",
)
# change the active adapter to adapter1
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.layernorm0.ln_tuning_layers.adapter1.weight",
"base_model.model.layernorm0.ln_tuning_layers.adapter1.bias",
)
# disable all adapters
with peft_model.disable_adapter():
self.check_requires_grad(peft_model)
# after context is exited, return to the previous state
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.layernorm0.ln_tuning_layers.adapter1.weight",
"base_model.model.layernorm0.ln_tuning_layers.adapter1.bias",
)
def test_requires_grad_vera_different_targets(self):
# Test two different VeRA adapters that target different modules. Most notably, ensure that vera_A and vera_B
# don't require grads.

View File

@ -29,6 +29,7 @@ from peft import (
AdaLoraConfig,
BOFTConfig,
IA3Config,
LNTuningConfig,
LoHaConfig,
LoKrConfig,
LoraConfig,
@ -688,8 +689,10 @@ class PeftCommonTester:
model = get_peft_model(model, config).eval()
logits_peft = model(**inputs)[0]
# sanity check that the logits are different
assert not torch.allclose(logits_base, logits_peft, atol=1e-6, rtol=1e-6)
# LN tuning initialization cannot be configured to change the outputs (unlike init_lora_weights=False)
if not issubclass(config_cls, LNTuningConfig):
# sanity check that the logits are different
assert not torch.allclose(logits_base, logits_peft, atol=1e-6, rtol=1e-6)
model_unloaded = model.merge_and_unload(safe_merge=True)
logits_unloaded = model_unloaded(**inputs)[0]