mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
ENH Make OFT faster and more memory efficient (#2575)
Make OFT faster and more memory efficient. This new version of OFT is not backwards compatible with older checkpoints and vice versa. To load older checkpoints, downgrade PEFT to 0.15.2 or lower.
This commit is contained in:
@ -16,9 +16,9 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# Orthogonal Finetuning (OFT and BOFT)
|
||||
|
||||
This conceptual guide gives a brief overview of [OFT](https://huggingface.co/papers/2306.07280) and [BOFT](https://huggingface.co/papers/2311.06243), a parameter-efficient fine-tuning technique that utilizes orthogonal matrix to multiplicatively transform the pretrained weight matrices.
|
||||
This conceptual guide gives a brief overview of [OFT](https://huggingface.co/papers/2306.07280), [OFTv2](https://www.arxiv.org/abs/2506.19847) and [BOFT](https://huggingface.co/papers/2311.06243), a parameter-efficient fine-tuning technique that utilizes orthogonal matrix to multiplicatively transform the pretrained weight matrices.
|
||||
|
||||
To achieve efficient fine-tuning, OFT represents the weight updates with an orthogonal transformation. The orthogonal transformation is parameterized by an orthogonal matrix multiplied to the pretrained weight matrix. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn’t receive any further adjustments. To produce the final results, both the original and the adapted weights are multiplied togethor.
|
||||
To achieve efficient fine-tuning, OFT represents the weight updates with an orthogonal transformation. The orthogonal transformation is parameterized by an orthogonal matrix multiplied to the pretrained weight matrix. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn't receive any further adjustments. To produce the final results, both the original and the adapted weights are multiplied togethor.
|
||||
|
||||
Orthogonal Butterfly (BOFT) generalizes OFT with Butterfly factorization and further improves its parameter efficiency and finetuning flexibility. In short, OFT can be viewed as a special case of BOFT. Different from LoRA that uses additive low-rank weight updates, BOFT uses multiplicative orthogonal weight updates. The comparison is shown below.
|
||||
|
||||
@ -58,13 +58,25 @@ As with other methods supported by PEFT, to fine-tune a model using OFT or BOFT,
|
||||
4. Train the `PeftModel` as you normally would train the base model.
|
||||
|
||||
|
||||
### OFT-specific parameters
|
||||
|
||||
`OFTConfig` allows you to control how OFT is applied to the base model through the following parameters:
|
||||
|
||||
- `r`: OFT rank, number of OFT blocks per injected layer. **Bigger** `r` results in more sparse update matrices with **fewer** trainable paramters. **Note**: You can only specify either `r` or `oft_block_size`, but not both simultaneously, because `r` × `oft_block_size` = layer dimension. For simplicity, we let the user speficy either `r` or `oft_block_size` and infer the other one. Default set to `r = 0`, the user is advised to set the `oft_block_size` instead for better clarity.
|
||||
- `oft_block_size`: OFT block size across different layers. **Bigger** `oft_block_size` results in more dense update matrices with **more** trainable parameters. **Note**: Please choose `oft_block_size` to be divisible by layer's input dimension (`in_features`), e.g., 4, 8, 16. You can only specify either `r` or `oft_block_size`, but not both simultaneously, because `r` × `oft_block_size` = layer dimension. For simplicity, we let the user speficy either `r` or `oft_block_size` and infer the other one. Default set to `oft_block_size = 32`.
|
||||
- `use_cayley_neumann`: Specifies whether to use the Cayley-Neumann parameterization (efficient but approximate) or the vanilla Cayley parameterization (exact but computationally expensive because of matrix inverse). We recommend to set it to `True` for better efficiency, but performance may be slightly worse because of the approximation error. Please test both settings (`True` and `False`) depending on your needs. Default is `False`.
|
||||
- `module_dropout`: The multiplicative dropout probability, by setting OFT blocks to identity during training, similar to the dropout layer in LoRA.
|
||||
- `bias`: specify if the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"oft_only"`.
|
||||
- `target_modules`: The modules (for example, attention blocks) to inject the OFT matrices.
|
||||
- `modules_to_save`: List of modules apart from OFT matrices to be set as trainable and saved in the final checkpoint. These typically include model's custom head that is randomly initialized for the fine-tuning task.
|
||||
|
||||
### BOFT-specific parameters
|
||||
|
||||
`BOFTConfig` allows you to control how OFT/BOFT is applied to the base model through the following parameters:
|
||||
`BOFTConfig` allows you to control how BOFT is applied to the base model through the following parameters:
|
||||
|
||||
- `boft_block_size`: the BOFT matrix block size across different layers, expressed in `int`. Smaller block size results in sparser update matrices with fewer trainable parameters. **Note**, please choose `boft_block_size` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
|
||||
- `boft_block_size`: the BOFT matrix block size across different layers, expressed in `int`. **Bigger** `boft_block_size` results in more dense update matrices with **more** trainable parameters. **Note**, please choose `boft_block_size` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
|
||||
specify either `boft_block_size` or `boft_block_num`, but not both simultaneously or leaving both to 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
|
||||
- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed in `int`. Fewer blocks result in sparser update matrices with fewer trainable parameters. **Note**, please choose `boft_block_num` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
|
||||
- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed in `int`. **Bigger** `boft_block_num` result in sparser update matrices with **fewer** trainable parameters. **Note**, please choose `boft_block_num` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
|
||||
specify either `boft_block_size` or `boft_block_num`, but not both simultaneously or leaving both to 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
|
||||
- `boft_n_butterfly_factor`: the number of butterfly factors. **Note**, for `boft_n_butterfly_factor=1`, BOFT is the same as vanilla OFT, for `boft_n_butterfly_factor=2`, the effective block size of OFT becomes twice as big and the number of blocks become half.
|
||||
- `bias`: specify if the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"boft_only"`.
|
||||
@ -74,6 +86,52 @@ specify either `boft_block_size` or `boft_block_num`, but not both simultaneousl
|
||||
|
||||
|
||||
|
||||
## OFT Example Usage
|
||||
|
||||
For using OFT for quantized finetuning with [TRL](https://github.com/huggingface/trl) for `SFT`, `PPO`, or `DPO` fine-tuning, follow the following outline:
|
||||
|
||||
```py
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||
from trl import SFTTrainer
|
||||
from peft import OFTConfig
|
||||
|
||||
if use_quantization:
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_storage=torch.bfloat16,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model_name",
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("model_name")
|
||||
|
||||
# Configure OFT
|
||||
peft_config = OFTConfig(
|
||||
oft_block_size=32,
|
||||
use_cayley_neumann=True,
|
||||
target_modules="all-linear",
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
train_dataset=ds['train'],
|
||||
peft_config=peft_config,
|
||||
tokenizer=tokenizer,
|
||||
args=training_arguments,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
|
||||
## BOFT Example Usage
|
||||
|
||||
For an example of the BOFT method application to various downstream tasks, please refer to the following guides:
|
||||
|
@ -12,13 +12,41 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available
|
||||
from peft.utils import register_peft_method
|
||||
|
||||
from .config import OFTConfig
|
||||
from .gptq import GPTQOFTLinear
|
||||
from .layer import Conv2d, Linear, OFTLayer
|
||||
from .model import OFTModel
|
||||
|
||||
|
||||
__all__ = ["Conv2d", "Linear", "OFTConfig", "OFTLayer", "OFTModel"]
|
||||
__all__ = [
|
||||
"Conv2d",
|
||||
"GPTQOFTLinear",
|
||||
"Linear",
|
||||
"OFTConfig",
|
||||
"OFTLayer",
|
||||
"OFTModel",
|
||||
]
|
||||
|
||||
register_peft_method(name="oft", config_cls=OFTConfig, model_cls=OFTModel)
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
if (name == "Linear8bitLt") and is_bnb_available():
|
||||
from .bnb import Linear8bitLt
|
||||
|
||||
return Linear8bitLt
|
||||
|
||||
if (name == "Linear4bit") and is_bnb_4bit_available():
|
||||
from .bnb import Linear4bit
|
||||
|
||||
return Linear4bit
|
||||
|
||||
if (name == "EetqOFTLinear") and is_eetq_available():
|
||||
from .eetq import EetqOFTLinear
|
||||
|
||||
return EetqOFTLinear
|
||||
|
||||
raise AttributeError(f"module {__name__} has no attribute {name}")
|
||||
|
105
src/peft/tuners/oft/aqlm.py
Normal file
105
src/peft/tuners/oft/aqlm.py
Normal file
@ -0,0 +1,105 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from peft.import_utils import is_aqlm_available
|
||||
from peft.tuners.oft.layer import OFTLayer
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
|
||||
if is_aqlm_available():
|
||||
from aqlm import QuantizedLinear
|
||||
|
||||
|
||||
class AqlmOFTLinear(torch.nn.Module, OFTLayer):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name: str,
|
||||
r: int = 0,
|
||||
oft_block_size: int = 32,
|
||||
module_dropout: float = 0.0,
|
||||
init_weights: bool = True,
|
||||
coft: bool = False,
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
use_cayley_neumann: bool = False,
|
||||
num_cayley_neumann_terms: int = 5,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
OFTLayer.__init__(self, base_layer)
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size=oft_block_size,
|
||||
module_dropout=module_dropout,
|
||||
init_weights=init_weights,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# note: logic differs from default Linear because merging is not supported
|
||||
if self.disable_adapters:
|
||||
return self.base_layer(x)
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
oft_R = self.oft_R[active_adapter]
|
||||
|
||||
requires_conversion = not torch.is_autocast_enabled()
|
||||
if requires_conversion:
|
||||
expected_dtype = x.dtype
|
||||
x = self._cast_input_dtype(x, oft_R.weight.dtype)
|
||||
|
||||
x = oft_R(x)
|
||||
|
||||
result = self.base_layer(x)
|
||||
if requires_conversion:
|
||||
result = result.to(expected_dtype)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
|
||||
|
||||
def dispatch_aqlm(
|
||||
target: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
**kwargs: Any,
|
||||
) -> Optional[torch.nn.Module]:
|
||||
new_module = None
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
if is_aqlm_available() and isinstance(target_base_layer, QuantizedLinear):
|
||||
new_module = AqlmOFTLinear(target, adapter_name, **kwargs)
|
||||
target.qweight = target_base_layer.codes
|
||||
|
||||
return new_module
|
119
src/peft/tuners/oft/awq.py
Normal file
119
src/peft/tuners/oft/awq.py
Normal file
@ -0,0 +1,119 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib.metadata as importlib_metadata
|
||||
from typing import Any, Optional
|
||||
|
||||
import packaging.version
|
||||
import torch
|
||||
|
||||
from peft.import_utils import is_auto_awq_available
|
||||
from peft.tuners.oft.layer import OFTLayer
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
|
||||
class AwqOFTLinear(torch.nn.Module, OFTLayer):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name,
|
||||
r: int = 0,
|
||||
oft_block_size: int = 32,
|
||||
module_dropout: float = 0.0,
|
||||
coft: bool = False,
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
init_weights: bool = True,
|
||||
use_cayley_neumann: bool = False,
|
||||
num_cayley_neumann_terms: int = 5,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
OFTLayer.__init__(self, base_layer)
|
||||
|
||||
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
|
||||
# for backwards compatibility
|
||||
self.quant_linear_module = base_layer
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size=oft_block_size,
|
||||
module_dropout=module_dropout,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
init_weights=init_weights,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.disable_adapters:
|
||||
result = self.quant_linear_module(x)
|
||||
return result
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
oft_R = self.oft_R[active_adapter]
|
||||
|
||||
requires_conversion = not torch.is_autocast_enabled()
|
||||
if requires_conversion:
|
||||
expected_dtype = x.dtype
|
||||
x = self._cast_input_dtype(x, oft_R.weight.dtype)
|
||||
|
||||
x = oft_R(x)
|
||||
if requires_conversion:
|
||||
x = x.to(expected_dtype)
|
||||
|
||||
result = self.quant_linear_module(x)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
|
||||
|
||||
def dispatch_awq(
|
||||
target: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
**kwargs: Any,
|
||||
) -> Optional[torch.nn.Module]:
|
||||
new_module = None
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
if is_auto_awq_available():
|
||||
from awq.modules.linear import WQLinear_GEMM
|
||||
|
||||
if isinstance(target_base_layer, WQLinear_GEMM):
|
||||
# Raise the error only at the dispatch level
|
||||
AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
|
||||
version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))
|
||||
|
||||
if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
|
||||
raise ImportError(
|
||||
f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
|
||||
f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT."
|
||||
)
|
||||
|
||||
new_module = AwqOFTLinear(target, adapter_name, **kwargs)
|
||||
target.qweight = target_base_layer.qweight
|
||||
|
||||
return new_module
|
388
src/peft/tuners/oft/bnb.py
Normal file
388
src/peft/tuners/oft/bnb.py
Normal file
@ -0,0 +1,388 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
|
||||
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||
from peft.utils.integrations import dequantize_bnb_weight
|
||||
|
||||
from .layer import OFTLayer
|
||||
|
||||
|
||||
if is_bnb_available():
|
||||
|
||||
class Linear8bitLt(torch.nn.Module, OFTLayer):
|
||||
# OFT implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
r: int = 8,
|
||||
oft_block_size: int = 0,
|
||||
module_dropout: float = 0.0,
|
||||
init_weights: bool = True,
|
||||
coft: bool = False,
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
use_cayley_neumann: bool = False,
|
||||
num_cayley_neumann_terms: int = 5,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
OFTLayer.__init__(self, base_layer)
|
||||
self.fan_in_fan_out = False
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size=oft_block_size,
|
||||
module_dropout=module_dropout,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
init_weights=init_weights,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
# no adapter to merge
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
|
||||
warnings.warn("Merge oft module to 8-bit linear may get different generations due to rounding errors.")
|
||||
|
||||
weight = self.get_base_layer().weight
|
||||
state = self.get_base_layer().state
|
||||
if state.SCB is None:
|
||||
state.SCB = weight.SCB
|
||||
|
||||
# Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
|
||||
# dequantization directly
|
||||
output = dequantize_bnb_weight(weight, state=state)
|
||||
oft_data = self.get_delta_weight(active_adapter)
|
||||
|
||||
output = torch.transpose(output, 0, 1)
|
||||
w_data = torch.mm(oft_data, output.to(oft_data.dtype))
|
||||
w_data = torch.transpose(w_data, 0, 1)
|
||||
w_data = output.to(oft_data.dtype).to(oft_data.device)
|
||||
|
||||
if safe_merge and not torch.isfinite(w_data).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
self.get_base_layer().weight = bnb.nn.Int8Params(
|
||||
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
|
||||
).to(weight.device)
|
||||
|
||||
state.reset_grads()
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
warnings.warn(
|
||||
"Unmerge oft module to 8-bit linear may get different generations due to rounding errors."
|
||||
)
|
||||
|
||||
weight = self.get_base_layer().weight
|
||||
state = self.get_base_layer().state
|
||||
if state.SCB is None:
|
||||
state.SCB = weight.SCB
|
||||
output = dequantize_bnb_weight(weight, state=state)
|
||||
|
||||
oft_data = self.get_delta_weight(active_adapter)
|
||||
|
||||
output = torch.transpose(output, 0, 1)
|
||||
w_data = torch.mm(oft_data.t(), output.to(oft_data.dtype))
|
||||
w_data = torch.transpose(w_data, 0, 1)
|
||||
w_data = w_data.to(oft_data.dtype).to(oft_data.device)
|
||||
|
||||
self.get_base_layer().weight = bnb.nn.Int8Params(
|
||||
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
|
||||
).to(weight.device)
|
||||
|
||||
state.reset_grads()
|
||||
|
||||
def get_delta_weight(self, adapter):
|
||||
return self.oft_R[adapter].get_weight()
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
oft_R = self.oft_R[active_adapter]
|
||||
|
||||
requires_conversion = not torch.is_autocast_enabled()
|
||||
if requires_conversion:
|
||||
expected_dtype = x.dtype
|
||||
x = self._cast_input_dtype(x, oft_R.weight.dtype)
|
||||
|
||||
x = oft_R(x)
|
||||
if requires_conversion:
|
||||
x = x.to(expected_dtype)
|
||||
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
|
||||
def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, **kwargs):
|
||||
new_module = None
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
loaded_in_8bit = kwargs.get("loaded_in_8bit", False)
|
||||
if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
|
||||
eightbit_kwargs = kwargs.copy()
|
||||
eightbit_kwargs.update(
|
||||
{
|
||||
"has_fp16_weights": target.state.has_fp16_weights,
|
||||
"threshold": target.state.threshold,
|
||||
"index": target.index,
|
||||
}
|
||||
)
|
||||
new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
|
||||
|
||||
return new_module
|
||||
|
||||
|
||||
if is_bnb_4bit_available():
|
||||
|
||||
class Linear4bit(torch.nn.Module, OFTLayer):
|
||||
# OFT implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
r: int = 8,
|
||||
oft_block_size: int = 0,
|
||||
module_dropout: float = 0.0,
|
||||
coft: bool = False,
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
init_weights: bool = True,
|
||||
use_cayley_neumann: bool = False,
|
||||
num_cayley_neumann_terms: int = 5,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
OFTLayer.__init__(self, base_layer)
|
||||
self.fan_in_fan_out = False
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size=oft_block_size,
|
||||
module_dropout=module_dropout,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
init_weights=init_weights,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
# no adapter to merge
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
|
||||
warnings.warn("Merge oft module to 4-bit linear may get different generations due to rounding errors.")
|
||||
# Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
|
||||
weight = self.get_base_layer().weight
|
||||
kwargs = weight.__dict__
|
||||
|
||||
output = dequantize_bnb_weight(weight, state=weight.quant_state)
|
||||
|
||||
oft_data = self.get_delta_weight(active_adapter)
|
||||
output = torch.transpose(output, 0, 1)
|
||||
w_data = torch.mm(oft_data, output.to(oft_data.dtype))
|
||||
w_data = torch.transpose(w_data, 0, 1)
|
||||
w_data = output.to(oft_data.dtype).to(oft_data.device)
|
||||
|
||||
if safe_merge and not torch.isfinite(w_data).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
if "bnb_quantized" in kwargs:
|
||||
kwargs["bnb_quantized"] = False
|
||||
kwargs["requires_grad"] = False
|
||||
kwargs.pop("data", None)
|
||||
# torch.compile can introduce attributes preceded by '_', remove them
|
||||
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
|
||||
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
|
||||
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
warnings.warn(
|
||||
"Unmerge oft module to 4-bit linear may get different generations due to rounding errors."
|
||||
)
|
||||
|
||||
weight = self.get_base_layer().weight
|
||||
kwargs = weight.__dict__
|
||||
output = dequantize_bnb_weight(weight, state=weight.quant_state)
|
||||
|
||||
oft_data = self.get_delta_weight(active_adapter)
|
||||
|
||||
output = torch.transpose(output, 0, 1)
|
||||
w_data = torch.mm(oft_data.t(), output.to(oft_data.dtype))
|
||||
w_data = torch.transpose(w_data, 0, 1)
|
||||
w_data = output.to(oft_data.dtype).to(oft_data.device)
|
||||
|
||||
if "bnb_quantized" in kwargs:
|
||||
kwargs["bnb_quantized"] = False
|
||||
kwargs["requires_grad"] = False
|
||||
kwargs.pop("data", None)
|
||||
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
|
||||
|
||||
def get_delta_weight(self, adapter):
|
||||
return self.oft_R[adapter].get_weight()
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
|
||||
# The reason is that in some cases, an error can occur that backprop
|
||||
# does not work on a manipulated view. This issue may be solved with
|
||||
# newer PyTorch versions but this would need extensive testing to be
|
||||
# sure.
|
||||
# result = result.clone()
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
oft_R = self.oft_R[active_adapter]
|
||||
|
||||
requires_conversion = not torch.is_autocast_enabled()
|
||||
if requires_conversion:
|
||||
expected_dtype = x.dtype
|
||||
x = self._cast_input_dtype(x, oft_R.weight.dtype)
|
||||
|
||||
x = oft_R(x)
|
||||
if requires_conversion:
|
||||
x = x.to(expected_dtype)
|
||||
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
|
||||
def dispatch_bnb_4bit(target: torch.nn.Module, adapter_name: str, **kwargs):
|
||||
new_module = None
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
loaded_in_4bit = kwargs.get("loaded_in_4bit", False)
|
||||
if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
|
||||
fourbit_kwargs = kwargs.copy()
|
||||
fourbit_kwargs.update(
|
||||
{
|
||||
"compute_dtype": target_base_layer.compute_dtype,
|
||||
"compress_statistics": target_base_layer.weight.compress_statistics,
|
||||
"quant_type": target_base_layer.weight.quant_type,
|
||||
}
|
||||
)
|
||||
new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
|
||||
|
||||
return new_module
|
@ -67,9 +67,9 @@ class OFTConfig(PeftConfig):
|
||||
Whether to share the OFT parameters between blocks or not. This is `False` by default.
|
||||
"""
|
||||
|
||||
r: int = field(default=8, metadata={"help": "OFT rank, number of OFT blocks per injected layer."})
|
||||
r: int = field(default=0, metadata={"help": "OFT rank, number of OFT blocks per injected layer."})
|
||||
oft_block_size: int = field(
|
||||
default=0,
|
||||
default=32,
|
||||
metadata={
|
||||
"help": "OFT block size across different layers.",
|
||||
"note": "You can only specify either r or oft_block_size, but not both simultaneously, because r x oft_block_size = layer dimension.",
|
||||
@ -144,6 +144,18 @@ class OFTConfig(PeftConfig):
|
||||
default=False,
|
||||
metadata={"help": "Whether to share the OFT parameters between blocks or not."},
|
||||
)
|
||||
use_cayley_neumann: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Whether to use the Cayley-Neumann Formulation of OFT or not. Set to True to improve computational efficiency but comes at costs of bigger approximation error for orthogonality."
|
||||
},
|
||||
)
|
||||
num_cayley_neumann_terms: int = field(
|
||||
default=5,
|
||||
metadata={
|
||||
"help": "Number of Cayley-Neumann terms to use. Higher number results in less approximation error for orthogonality."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
116
src/peft/tuners/oft/eetq.py
Normal file
116
src/peft/tuners/oft/eetq.py
Normal file
@ -0,0 +1,116 @@
|
||||
# Copyright 2024-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from peft.import_utils import is_eetq_available
|
||||
from peft.tuners.oft.layer import OFTLayer
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
|
||||
if is_eetq_available():
|
||||
from eetq import EetqLinear
|
||||
|
||||
class EetqOFTLinear(torch.nn.Module, OFTLayer):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name,
|
||||
r: int = 0,
|
||||
oft_block_size: int = 0,
|
||||
module_dropout: float = 0.0,
|
||||
init_weights: bool = True,
|
||||
coft: bool = False,
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
use_cayley_neumann: bool = False,
|
||||
num_cayley_neumann_terms: int = 5,
|
||||
fan_in_fan_out: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
OFTLayer.__init__(self, base_layer)
|
||||
|
||||
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
|
||||
# for backwards compatibility
|
||||
self.quant_linear_module = base_layer
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size=oft_block_size,
|
||||
module_dropout=module_dropout,
|
||||
init_weights=init_weights,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
fan_in_fan_out=fan_in_fan_out,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.disable_adapters:
|
||||
return self.quant_linear_module(x)
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
oft_R = self.oft_R[active_adapter]
|
||||
|
||||
requires_conversion = not torch.is_autocast_enabled()
|
||||
if requires_conversion:
|
||||
expected_dtype = x.dtype
|
||||
x = self._cast_input_dtype(x, oft_R.weight.dtype)
|
||||
|
||||
x = oft_R(x)
|
||||
|
||||
result = self.quant_linear_module(x)
|
||||
if requires_conversion:
|
||||
result = result.to(expected_dtype)
|
||||
return result
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
raise AttributeError("Merging LoRA layers is not supported for Eetq layers.")
|
||||
|
||||
def unmerge(self) -> None:
|
||||
raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
|
||||
|
||||
def dispatch_eetq(
|
||||
target: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
**kwargs: Any,
|
||||
) -> Optional[torch.nn.Module]:
|
||||
new_module = None
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
if is_eetq_available() and isinstance(target_base_layer, EetqLinear):
|
||||
new_module = EetqOFTLinear(target, adapter_name, **kwargs)
|
||||
target.weight = target_base_layer.weight
|
||||
|
||||
if hasattr(target, "bias"):
|
||||
target.bias = target_base_layer.bias
|
||||
|
||||
return new_module
|
118
src/peft/tuners/oft/gptq.py
Normal file
118
src/peft/tuners/oft/gptq.py
Normal file
@ -0,0 +1,118 @@
|
||||
# Copyright 2023-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from peft.import_utils import is_gptqmodel_available
|
||||
from peft.tuners.oft.layer import OFTLayer
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
from peft.utils import get_auto_gptq_quant_linear
|
||||
|
||||
|
||||
class GPTQOFTLinear(torch.nn.Module, OFTLayer):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer,
|
||||
adapter_name: str,
|
||||
r: int = 8,
|
||||
oft_block_size: int = 0,
|
||||
module_dropout: float = 0.0,
|
||||
coft: bool = False,
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
use_cayley_neumann: bool = False,
|
||||
num_cayley_neumann_terms: int = 5,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
init_weights: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
OFTLayer.__init__(self, base_layer)
|
||||
|
||||
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
|
||||
# for backwards compatibility
|
||||
self.quant_linear_module = base_layer
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size=oft_block_size,
|
||||
module_dropout=module_dropout,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
init_weights=init_weights,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
# note: logic differs from default Linear because merging is not supported
|
||||
result = self.quant_linear_module(x)
|
||||
|
||||
if self.disable_adapters:
|
||||
return self.quant_linear_module(x)
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
|
||||
oft_R = self.oft_R[active_adapter]
|
||||
|
||||
requires_conversion = not torch.is_autocast_enabled()
|
||||
if requires_conversion:
|
||||
expected_dtype = x.dtype
|
||||
x = self._cast_input_dtype(x, oft_R.weight.dtype)
|
||||
|
||||
x = oft_R(x)
|
||||
|
||||
result = self.quant_linear_module(x)
|
||||
if requires_conversion:
|
||||
result = result.to(expected_dtype)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
|
||||
|
||||
def dispatch_gptq(
|
||||
target: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
**kwargs: Any,
|
||||
) -> Optional[torch.nn.Module]:
|
||||
new_module = None
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
cfg = kwargs.get("gptq_quantization_config", None)
|
||||
|
||||
if is_gptqmodel_available():
|
||||
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
|
||||
|
||||
if isinstance(target_base_layer, BaseQuantLinear):
|
||||
new_module = GPTQOFTLinear(target, adapter_name, **kwargs)
|
||||
target.qweight = target_base_layer.qweight
|
||||
else:
|
||||
quant_linear = get_auto_gptq_quant_linear(cfg)
|
||||
|
||||
if quant_linear is not None and isinstance(target_base_layer, quant_linear):
|
||||
new_module = GPTQOFTLinear(target, adapter_name, **kwargs)
|
||||
target.qweight = target_base_layer.qweight
|
||||
|
||||
return new_module
|
186
src/peft/tuners/oft/hqq.py
Normal file
186
src/peft/tuners/oft/hqq.py
Normal file
@ -0,0 +1,186 @@
|
||||
# Copyright 2024-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from peft.import_utils import is_hqq_available
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||
|
||||
from .layer import OFTLayer
|
||||
|
||||
|
||||
if is_hqq_available():
|
||||
from hqq.core.quantize import HQQLinear
|
||||
|
||||
class HqqOFTLinear(torch.nn.Module, OFTLayer):
|
||||
# Lora implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
r: int = 8,
|
||||
oft_block_size: int = 0,
|
||||
module_dropout: float = 0.0,
|
||||
init_weights: bool = True,
|
||||
coft: bool = False,
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
use_cayley_neumann: bool = False,
|
||||
num_cayley_neumann_terms: int = 5,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
OFTLayer.__init__(self, base_layer)
|
||||
self.fan_in_fan_out = False
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size=oft_block_size,
|
||||
module_dropout=module_dropout,
|
||||
init_weights=init_weights,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
adapter_names = check_adapters_to_merge(self, adapter_names)
|
||||
if not adapter_names:
|
||||
# no adapter to merge
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter not in self.lora_A.keys():
|
||||
continue
|
||||
|
||||
layer = self.get_base_layer()
|
||||
quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
|
||||
|
||||
output = layer.dequantize()
|
||||
oft_data = self.get_delta_weight(active_adapter)
|
||||
|
||||
output = torch.transpose(output, 0, 1)
|
||||
w_data = torch.mm(oft_data, output.to(oft_data.dtype))
|
||||
w_data = torch.transpose(w_data, 0, 1)
|
||||
w_data = output.to(oft_data.dtype).to(oft_data.device)
|
||||
|
||||
if safe_merge and not torch.isfinite(w_data).all():
|
||||
raise ValueError(
|
||||
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
||||
)
|
||||
|
||||
new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
|
||||
quant_config.pop("offload_meta", None)
|
||||
new_hqq_layer.quantize(w_data, **quant_config)
|
||||
self.base_layer = new_hqq_layer
|
||||
self.merged_adapters.append(active_adapter)
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
|
||||
layer = self.get_base_layer()
|
||||
quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
|
||||
output = layer.dequantize()
|
||||
|
||||
oft_data = self.get_delta_weight(active_adapter)
|
||||
|
||||
output = torch.transpose(output, 0, 1)
|
||||
w_data = torch.mm(oft_data.t(), output.to(oft_data.dtype))
|
||||
w_data = torch.transpose(w_data, 0, 1)
|
||||
w_data = w_data.to(oft_data.dtype).to(oft_data.device)
|
||||
|
||||
new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
|
||||
quant_config.pop("offload_meta", None)
|
||||
new_hqq_layer.quantize(w_data, **quant_config)
|
||||
self.base_layer = new_hqq_layer
|
||||
|
||||
def get_delta_weight(self, adapter):
|
||||
return self.oft_R[adapter].get_weight()
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
self._check_forward_args(x, *args, **kwargs)
|
||||
adapter_names = kwargs.pop("adapter_names", None)
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
oft_R = self.oft_R[active_adapter]
|
||||
|
||||
requires_conversion = not torch.is_autocast_enabled()
|
||||
if requires_conversion:
|
||||
expected_dtype = x.dtype
|
||||
x = self._cast_input_dtype(x, oft_R.weight.dtype)
|
||||
|
||||
x = oft_R(x)
|
||||
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
if requires_conversion:
|
||||
result = result.to(expected_dtype)
|
||||
return result
|
||||
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
|
||||
|
||||
def dispatch_hqq(target: torch.nn.Module, adapter_name: str, **kwargs):
|
||||
new_module = None
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
if is_hqq_available() and isinstance(target_base_layer, HQQLinear):
|
||||
new_module = HqqOFTLinear(target_base_layer, adapter_name, **kwargs)
|
||||
|
||||
return new_module
|
78
src/peft/tuners/oft/inc.py
Normal file
78
src/peft/tuners/oft/inc.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# NOTE: PEFT tests related to INC are handled under Optimum-Habana repository:
|
||||
# - LLMs: https://github.com/huggingface/optimum-habana/blob/main/tests/test_peft_inference.py
|
||||
# - Diffusers: https://github.com/huggingface/optimum-habana/blob/main/tests/test_diffusers.py
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from peft.import_utils import is_inc_available
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
from .layer import Linear
|
||||
|
||||
|
||||
if is_inc_available():
|
||||
|
||||
class IncOFTLinear(Linear):
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(base_layer, adapter_name, **kwargs)
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
Merge the active adapter weights into the base weights
|
||||
|
||||
Args:
|
||||
safe_merge (`bool`, *optional*):
|
||||
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
|
||||
before merging the weights. This is useful if you want to check if the merge operation will produce
|
||||
NaNs. Defaults to `False`.
|
||||
adapter_names (`list[str]`, *optional*):
|
||||
The list of adapter names that should be merged. If None, all active adapters will be merged.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
raise NotImplementedError("Merging OFT with INC layers is not yet implemented")
|
||||
|
||||
def unmerge(self) -> None:
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
raise NotImplementedError("Unmerging OFT from INC layers is not yet implemented")
|
||||
|
||||
|
||||
def dispatch_inc(target: torch.nn.Module, adapter_name: str, **kwargs):
|
||||
new_module = None
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
if is_inc_available():
|
||||
from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import (
|
||||
PatchedLinear,
|
||||
)
|
||||
|
||||
if isinstance(target_base_layer, PatchedLinear):
|
||||
new_module = IncOFTLinear(target, adapter_name, **kwargs)
|
||||
|
||||
return new_module
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@ -23,6 +22,8 @@ import torch.nn.functional as F
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||
|
||||
from .config import OFTConfig
|
||||
|
||||
|
||||
class MultiplicativeDropoutLayer(nn.Module):
|
||||
"""
|
||||
@ -48,7 +49,7 @@ class MultiplicativeDropoutLayer(nn.Module):
|
||||
the number of OFT blocks, and `H` is the size of the square blocks along the last two dimensions,
|
||||
the block size in OFT.
|
||||
"""
|
||||
if self.training:
|
||||
if self.training and self.p > 0:
|
||||
# Ensure the last two dimensions are the same
|
||||
if x.shape[-1] != x.shape[-2]:
|
||||
raise ValueError("The last two dimensions of input should be the same!")
|
||||
@ -68,15 +69,249 @@ class MultiplicativeDropoutLayer(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class OFTRotationModule(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
r,
|
||||
n_elements,
|
||||
block_size,
|
||||
in_features,
|
||||
coft=False,
|
||||
eps=6e-5,
|
||||
block_share=False,
|
||||
kernel_size=(0, 0),
|
||||
use_cayley_neumann=True,
|
||||
num_cayley_neumann_terms=5,
|
||||
):
|
||||
super().__init__()
|
||||
self.r = r
|
||||
self.n_elements = n_elements
|
||||
self.block_size = block_size
|
||||
self.in_features = in_features
|
||||
self.weight = nn.Parameter(torch.empty(r, n_elements))
|
||||
self.coft = coft
|
||||
self.eps = eps
|
||||
self.block_share = block_share
|
||||
# Conv2d specific parameters
|
||||
self.kernel_size = kernel_size
|
||||
self.use_cayley_neumann = use_cayley_neumann
|
||||
self.num_cayley_neumann_terms = num_cayley_neumann_terms
|
||||
# Create indices for upper triangle (excluding diagonal)
|
||||
self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
|
||||
|
||||
def _pytorch_skew_symmetric(self, vec, block_size):
|
||||
batch_size = vec.shape[0]
|
||||
matrix = torch.zeros(batch_size, block_size, block_size, device=vec.device, dtype=vec.dtype)
|
||||
|
||||
matrix[:, self.rows, self.cols] = vec
|
||||
matrix = matrix - matrix.transpose(-2, -1)
|
||||
return matrix
|
||||
|
||||
def _pytorch_skew_symmetric_inv(self, matrix, block_size):
|
||||
batch_size = matrix.shape[0]
|
||||
|
||||
# Extract the upper triangular elements
|
||||
vec = matrix[:, self.rows, self.cols]
|
||||
return vec
|
||||
|
||||
def _cayley_batch(
|
||||
self, Q: torch.Tensor, block_size: int, use_cayley_neumann: bool = True, num_neumann_terms: int = 5
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform the Cayley parametrization on a batch of skew-symmetric matrices.
|
||||
|
||||
Args:
|
||||
data: A batch of skew-symmetric matrices of shape (b, r, c).
|
||||
"""
|
||||
|
||||
b, _ = Q.shape
|
||||
previous_dtype = Q.dtype
|
||||
|
||||
# Q_skew = SkewSymmetric.apply(Q, block_size)
|
||||
Q_skew = self._pytorch_skew_symmetric(Q, block_size)
|
||||
|
||||
if use_cayley_neumann:
|
||||
R = torch.eye(block_size, device=Q.device, dtype=Q.dtype).repeat(b, 1, 1)
|
||||
if num_neumann_terms > 1:
|
||||
R.add_(Q_skew, alpha=2.0)
|
||||
if num_neumann_terms > 2:
|
||||
Q_squared = torch.bmm(Q_skew, Q_skew)
|
||||
R.add_(Q_squared, alpha=2.0)
|
||||
|
||||
Q_power = Q_squared
|
||||
for i in range(3, num_neumann_terms):
|
||||
Q_power = torch.bmm(Q_power, Q_skew)
|
||||
R.add_(Q_power, alpha=2.0)
|
||||
else:
|
||||
id_mat = (
|
||||
torch.eye(Q_skew.shape[-1], device=Q_skew.device)
|
||||
.unsqueeze(0)
|
||||
.expand(b, Q_skew.shape[-1], Q_skew.shape[-1])
|
||||
)
|
||||
R = torch.linalg.solve(id_mat + Q_skew, id_mat - Q_skew, left=False)
|
||||
|
||||
return R.to(previous_dtype)
|
||||
|
||||
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52
|
||||
def _project_batch(self, Q, eps=1e-5):
|
||||
oft_R = self._pytorch_skew_symmetric(Q, self.block_size)
|
||||
# scaling factor for each of the smaller block matrix
|
||||
eps = eps * 1 / torch.sqrt(torch.tensor(oft_R.shape[0]))
|
||||
I = ( # noqa: E741
|
||||
torch.zeros((oft_R.size(1), oft_R.size(1)), device=oft_R.device, dtype=oft_R.dtype)
|
||||
.unsqueeze(0)
|
||||
.expand_as(oft_R)
|
||||
)
|
||||
diff = oft_R - I
|
||||
norm_diff = torch.norm(oft_R - I, dim=(1, 2), keepdim=True)
|
||||
mask = (norm_diff <= eps).bool()
|
||||
out = torch.where(mask, oft_R, I + eps * (diff / norm_diff))
|
||||
|
||||
return self._pytorch_skew_symmetric_inv(out, self.block_size)
|
||||
|
||||
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155
|
||||
def _block_diagonal(self, oft_R: torch.Tensor, rank: int) -> torch.Tensor:
|
||||
if oft_R.shape[0] == 1:
|
||||
# block share
|
||||
blocks = [oft_R[0, ...] for i in range(rank)]
|
||||
else:
|
||||
blocks = [oft_R[i, ...] for i in range(rank)]
|
||||
|
||||
# Use torch.block_diag to create the block diagonal matrix
|
||||
A = torch.block_diag(*blocks)
|
||||
|
||||
return A
|
||||
|
||||
def _unfold(self, x):
|
||||
"""
|
||||
Unfold with stride=1, padding=0 to preserve spatial dimensions. Only use kernel_size from base layer to define
|
||||
patch size.
|
||||
"""
|
||||
batch_size, in_channels, in_height, in_width = x.shape
|
||||
|
||||
if isinstance(self.kernel_size, int):
|
||||
kernel_height, kernel_width = self.kernel_size, self.kernel_size
|
||||
else:
|
||||
kernel_height, kernel_width = self.kernel_size
|
||||
|
||||
stride_h = stride_w = 1
|
||||
pad_h = pad_w = 0
|
||||
|
||||
# output dimensions
|
||||
out_height = (in_height + 2 * pad_h - kernel_height) // stride_h + 1
|
||||
out_width = (in_width + 2 * pad_w - kernel_width) // stride_w + 1
|
||||
|
||||
# Reshape input from [B, C, H, W] to [B, C, H_out, W_out, K_H, K_W]
|
||||
x_unfolded = x.unfold(2, kernel_height, stride_h).unfold(3, kernel_width, stride_w)
|
||||
x_unfolded = x_unfolded.permute(0, 2, 3, 1, 4, 5).contiguous()
|
||||
x_unfolded = x_unfolded.view(batch_size * out_height * out_width, -1)
|
||||
|
||||
return x_unfolded
|
||||
|
||||
def _fold(self, x_unfolded, orig_shape):
|
||||
"""
|
||||
Fold back to preserve spatial dimensions.
|
||||
"""
|
||||
batch_size, in_channels, in_height, in_width = orig_shape
|
||||
|
||||
if isinstance(self.kernel_size, int):
|
||||
kernel_height, kernel_width = self.kernel_size, self.kernel_size
|
||||
else:
|
||||
kernel_height, kernel_width = self.kernel_size
|
||||
|
||||
# With stride=1, padding=0:
|
||||
out_height = in_height - kernel_height + 1
|
||||
out_width = in_width - kernel_width + 1
|
||||
|
||||
# Reshape: [B*H_out*W_out, C*K_H*K_W] -> [B, H_out, W_out, C, K_H, K_W]
|
||||
x_reshaped = x_unfolded.view(batch_size, out_height, out_width, in_channels, kernel_height, kernel_width)
|
||||
|
||||
# Permute to: [B, C, H_out, W_out, K_H, K_W]
|
||||
x_reshaped = x_reshaped.permute(0, 3, 1, 2, 4, 5).contiguous()
|
||||
|
||||
# Use F.fold to reconstruct 4D tensor
|
||||
x_folded = F.fold(
|
||||
x_reshaped.view(batch_size, in_channels * kernel_height * kernel_width, out_height * out_width),
|
||||
output_size=(in_height, in_width),
|
||||
kernel_size=(kernel_height, kernel_width),
|
||||
stride=(1, 1),
|
||||
)
|
||||
|
||||
return x_folded
|
||||
|
||||
def forward(self, x):
|
||||
# This module doesn't need to implement the orthogonal transform
|
||||
# It's primarily a container for the parameter
|
||||
# The actual transformation logic stays in your OFTLayer
|
||||
|
||||
required_dtype = x.dtype
|
||||
if required_dtype != self.weight.dtype:
|
||||
x = x.to(self.weight.dtype)
|
||||
|
||||
orig_shape = x.shape
|
||||
|
||||
if self.coft:
|
||||
with torch.no_grad():
|
||||
self.weight.copy_(self._project_batch(self.weight, eps=self.eps))
|
||||
|
||||
orth_rotate = self._cayley_batch(
|
||||
self.weight, self.block_size, self.use_cayley_neumann, self.num_cayley_neumann_terms
|
||||
)
|
||||
|
||||
# Unfold the input for Conv2d layer
|
||||
if len(orig_shape) == 4:
|
||||
x = self._unfold(x)
|
||||
|
||||
folded_shape = x.shape
|
||||
rank = self.in_features // self.block_size if self.block_share else self.r
|
||||
batch_dims = x.shape[:-1]
|
||||
x_reshaped = x.reshape(*batch_dims, rank, self.block_size)
|
||||
|
||||
if self.block_share:
|
||||
orth_rotate = orth_rotate.repeat(rank, 1, 1)
|
||||
x_rotated_reshaped = torch.einsum("...rk,rkc->...rc", x_reshaped, orth_rotate)
|
||||
else:
|
||||
x_rotated_reshaped = torch.einsum("...rk,rkc->...rc", x_reshaped, orth_rotate)
|
||||
|
||||
x_rotated = x_rotated_reshaped.reshape(*folded_shape)
|
||||
|
||||
if len(orig_shape) == 4:
|
||||
x_rotated = self._fold(x_rotated, orig_shape)
|
||||
|
||||
return x_rotated.to(required_dtype)
|
||||
|
||||
def get_weight(self):
|
||||
"""
|
||||
Compute the delta weight for the given adapter.
|
||||
|
||||
Args:
|
||||
adapter (str):
|
||||
The name of the adapter for which the delta weight should be computed.
|
||||
"""
|
||||
weight = self.weight
|
||||
|
||||
if self.coft:
|
||||
with torch.no_grad():
|
||||
weight = self._project_batch(weight, eps=self.eps)
|
||||
self.weight.copy_(weight)
|
||||
|
||||
orth_rotate = self._cayley_batch(
|
||||
weight, self.block_size, self.use_cayley_neumann, self.num_cayley_neumann_terms
|
||||
)
|
||||
|
||||
rank = self.r if not self.block_share else self.in_features // self.block_size
|
||||
return self._block_diagonal(orth_rotate, rank)
|
||||
|
||||
|
||||
class OFTLayer(BaseTunerLayer):
|
||||
"""
|
||||
Implements the OFT layer.
|
||||
"""
|
||||
|
||||
# All names of layers that may contain adapter weights
|
||||
adapter_layer_names = ("oft_r", "oft_s")
|
||||
# other_param_names is defined on parent class
|
||||
other_param_names = ("r", "oft_block_size", "oft_dropout")
|
||||
# All names of layers that may contain (trainable) adapter weights
|
||||
adapter_layer_names: tuple[str, ...] = ("oft_R",)
|
||||
# All names of other parameters that may contain adapter-related parameters
|
||||
other_param_names: tuple[str, ...] = ("r", "oft_block_size", "oft_dropout")
|
||||
|
||||
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
|
||||
"""
|
||||
@ -89,15 +324,11 @@ class OFTLayer(BaseTunerLayer):
|
||||
base_layer: the pretrained model layer
|
||||
"""
|
||||
self.base_layer = base_layer
|
||||
# OFT info
|
||||
self.oft_r = nn.ParameterDict({})
|
||||
self.oft_s = nn.ParameterDict({})
|
||||
self.oft_R = nn.ModuleDict({})
|
||||
self.oft_block_size = {}
|
||||
self.r = {}
|
||||
self.oft_block_size = {}
|
||||
self.oft_dropout = nn.ModuleDict({})
|
||||
self.coft = {}
|
||||
self.eps = {}
|
||||
self.block_share = {}
|
||||
# Mark the weight as unmerged
|
||||
self._disable_adapters = False
|
||||
self.merged_adapters = []
|
||||
@ -106,20 +337,44 @@ class OFTLayer(BaseTunerLayer):
|
||||
self.kwargs = kwargs
|
||||
|
||||
base_layer = self.get_base_layer()
|
||||
|
||||
if isinstance(base_layer, nn.Linear):
|
||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||
elif isinstance(base_layer, nn.Conv2d):
|
||||
in_features, out_features = base_layer.in_channels, base_layer.out_channels
|
||||
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
|
||||
# QuantLinear
|
||||
in_features, out_features = base_layer.infeatures, base_layer.outfeatures
|
||||
elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
|
||||
# Megatron ColumnParallelLinear,RowParallelLinear
|
||||
in_features, out_features = base_layer.input_size, base_layer.output_size
|
||||
elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear":
|
||||
# AQLM QuantLinear
|
||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||
elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
|
||||
# Awq layers
|
||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||
elif base_layer.__class__.__name__ == "EetqLinear":
|
||||
# Eetq layers
|
||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||
elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
|
||||
# HQQ layers
|
||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||
else:
|
||||
raise ValueError(f"Unsupported layer type {type(base_layer)}")
|
||||
# possibly support user provided custom layer types using dynamic dispatch
|
||||
if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
|
||||
in_features, out_features = base_layer.in_features, base_layer.out_features
|
||||
else:
|
||||
in_features, out_features = None, None
|
||||
warnings.warn(
|
||||
f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning
|
||||
)
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
@property
|
||||
def _available_adapters(self) -> set[str]:
|
||||
return {*self.oft_r}
|
||||
return {*self.oft_R}
|
||||
|
||||
def set_scale(self, adapter, scale):
|
||||
if adapter not in self.scaling:
|
||||
@ -133,19 +388,31 @@ class OFTLayer(BaseTunerLayer):
|
||||
return
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_r.keys():
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
|
||||
warnings.warn("Scaling operation for OFT not supported! Automatically set scale to 1.")
|
||||
|
||||
def unscale_layer(self, scale=None) -> None:
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_r.keys():
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
|
||||
warnings.warn("Unscaling operation for OFT not supported! Keeping scale to 1.")
|
||||
|
||||
def update_layer(self, adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights):
|
||||
def update_layer(
|
||||
self,
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size,
|
||||
module_dropout,
|
||||
coft,
|
||||
eps,
|
||||
block_share,
|
||||
init_weights,
|
||||
use_cayley_neumann,
|
||||
num_cayley_neumann_terms,
|
||||
):
|
||||
"""
|
||||
Update the linear layer with trainable OFT weights. Override for other layer types.
|
||||
"""
|
||||
@ -189,20 +456,19 @@ class OFTLayer(BaseTunerLayer):
|
||||
"Something went wrong, please report this error: https://github.com/huggingface/peft/issues"
|
||||
)
|
||||
|
||||
self.coft[adapter_name] = coft
|
||||
self.block_share[adapter_name] = block_share
|
||||
self.eps[adapter_name] = eps * math.ceil(self.out_features / r) * math.ceil(self.out_features / r)
|
||||
|
||||
# Create weights with provided shape
|
||||
if block_share:
|
||||
self.oft_r[adapter_name] = nn.Parameter(
|
||||
torch.empty(1, math.ceil(self.in_features / r), math.ceil(self.in_features / r))
|
||||
)
|
||||
else:
|
||||
self.oft_r[adapter_name] = nn.Parameter(
|
||||
torch.empty(r, math.ceil(self.in_features / r), math.ceil(self.in_features / r))
|
||||
)
|
||||
self.oft_s[adapter_name] = nn.Parameter(torch.empty(int(self.out_features), 1))
|
||||
n_elements = oft_block_size * (oft_block_size - 1) // 2
|
||||
self.oft_R[adapter_name] = OFTRotationModule(
|
||||
r if not block_share else 1,
|
||||
n_elements,
|
||||
oft_block_size,
|
||||
self.in_features,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
# Initialize weights
|
||||
self.reset_oft_parameters(adapter_name, init_weights)
|
||||
@ -220,63 +486,16 @@ class OFTLayer(BaseTunerLayer):
|
||||
Reset the OFT parameters.
|
||||
"""
|
||||
if init_weights is False:
|
||||
nn.init.normal_(self.oft_r[adapter_name], mean=0.0, std=0.1)
|
||||
nn.init.normal_(self.oft_s[adapter_name], mean=1.0, std=0.1)
|
||||
nn.init.normal_(self.oft_R[adapter_name].weight, mean=0.0, std=0.1)
|
||||
return
|
||||
|
||||
if adapter_name in self.oft_r.keys():
|
||||
if adapter_name in self.oft_R.keys():
|
||||
if init_weights is True:
|
||||
# initialize oft_r to zero
|
||||
nn.init.zeros_(self.oft_r[adapter_name])
|
||||
nn.init.ones_(self.oft_s[adapter_name])
|
||||
# initialize oft_R to zero
|
||||
nn.init.zeros_(self.oft_R[adapter_name].weight)
|
||||
else:
|
||||
raise ValueError(f"Unknown initialization {init_weights=}")
|
||||
|
||||
def _cayley_batch(self, data: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Perform the Cayley parametrization on a batch of skew-symmetric matrices.
|
||||
|
||||
Args:
|
||||
data: A batch of skew-symmetric matrices of shape (b, r, c).
|
||||
"""
|
||||
b, r, c = data.shape
|
||||
# Ensure the input matrix is skew-symmetric
|
||||
skew_mat = 0.5 * (data - data.transpose(1, 2))
|
||||
id_mat = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) # noqa: E741
|
||||
|
||||
# Perform the Cayley parametrization
|
||||
Q = torch.linalg.solve(id_mat + skew_mat, id_mat - skew_mat, left=False)
|
||||
|
||||
return Q
|
||||
|
||||
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155
|
||||
def _block_diagonal(self, oft_r: torch.Tensor, rank: int) -> torch.Tensor:
|
||||
if oft_r.shape[0] == 1:
|
||||
# block share
|
||||
blocks = [oft_r[0, ...] for i in range(rank)]
|
||||
else:
|
||||
blocks = [oft_r[i, ...] for i in range(rank)]
|
||||
|
||||
# Use torch.block_diag to create the block diagonal matrix
|
||||
A = torch.block_diag(*blocks)
|
||||
|
||||
return A
|
||||
|
||||
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52
|
||||
def _project_batch(self, oft_r, eps=1e-5):
|
||||
# scaling factor for each of the smaller block matrix
|
||||
eps = eps * 1 / torch.sqrt(torch.tensor(oft_r.shape[0]))
|
||||
I = ( # noqa: E741
|
||||
torch.zeros((oft_r.size(1), oft_r.size(1)), device=oft_r.device, dtype=oft_r.dtype)
|
||||
.unsqueeze(0)
|
||||
.expand_as(oft_r)
|
||||
)
|
||||
diff = oft_r - I
|
||||
norm_diff = torch.norm(oft_r - I, dim=(1, 2), keepdim=True)
|
||||
mask = (norm_diff <= eps).bool()
|
||||
out = torch.where(mask, oft_r, I + eps * (diff / norm_diff))
|
||||
return out
|
||||
|
||||
def adjust_oft_parameters(self, in_features, params):
|
||||
"""
|
||||
Adjust the OFT parameters to be divisible by the in_features dimension.
|
||||
@ -311,6 +530,8 @@ class Linear(nn.Module, OFTLayer):
|
||||
coft: bool = False,
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
use_cayley_neumann: bool = False,
|
||||
num_cayley_neumann_terms: int = 5,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
init_weights: Union[bool, str] = True,
|
||||
is_target_conv_1d_layer: bool = False,
|
||||
@ -322,7 +543,18 @@ class Linear(nn.Module, OFTLayer):
|
||||
|
||||
self._active_adapter = adapter_name
|
||||
|
||||
self.update_layer(adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights)
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size=oft_block_size,
|
||||
module_dropout=module_dropout,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
init_weights=init_weights,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
self.is_target_conv_1d_layer = is_target_conv_1d_layer
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
@ -349,13 +581,11 @@ class Linear(nn.Module, OFTLayer):
|
||||
orig_dtype = base_layer.weight.dtype
|
||||
if safe_merge:
|
||||
# Note that safe_merge will be slower than the normal merge
|
||||
# because of the copy operation.
|
||||
orig_weights = base_layer.weight.data
|
||||
oft_mat, oft_s = self.get_delta_weight(active_adapter)
|
||||
oft_mat = self.get_delta_weight(active_adapter)
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = orig_weights * oft_s
|
||||
|
||||
if not torch.isfinite(orig_weights).all():
|
||||
raise ValueError(
|
||||
@ -364,12 +594,11 @@ class Linear(nn.Module, OFTLayer):
|
||||
|
||||
base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
|
||||
else:
|
||||
oft_mat, oft_s = self.get_delta_weight(active_adapter)
|
||||
orig_weights = base_layer.weight.data
|
||||
oft_mat = self.get_delta_weight(active_adapter)
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = orig_weights * oft_s
|
||||
|
||||
base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
|
||||
|
||||
@ -387,15 +616,15 @@ class Linear(nn.Module, OFTLayer):
|
||||
orig_dtype = base_layer.weight.dtype
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter in self.oft_r.keys():
|
||||
oft_mat, oft_s = self.get_delta_weight(active_adapter)
|
||||
if active_adapter in self.oft_R.keys():
|
||||
oft_mat = self.get_delta_weight(active_adapter)
|
||||
|
||||
orig_weights = self.get_base_layer().weight.data
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
|
||||
base_layer.weight.data = (orig_weights * (1 / oft_s)).to(orig_dtype)
|
||||
base_layer.weight.data = orig_weights.to(orig_dtype)
|
||||
|
||||
def get_delta_weight(self, adapter_name) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@ -405,21 +634,8 @@ class Linear(nn.Module, OFTLayer):
|
||||
adapter (str):
|
||||
The name of the adapter for which the delta weight should be computed.
|
||||
"""
|
||||
oft_r = self.oft_r[adapter_name]
|
||||
oft_s = self.oft_s[adapter_name]
|
||||
|
||||
rank = self.r[adapter_name]
|
||||
coft = self.coft[adapter_name]
|
||||
eps = self.eps[adapter_name]
|
||||
|
||||
if coft:
|
||||
with torch.no_grad():
|
||||
oft_r.copy_(self._project_batch(oft_r, eps=eps))
|
||||
|
||||
orth_rotate = self._cayley_batch(oft_r)
|
||||
weight = self._block_diagonal(orth_rotate, rank)
|
||||
|
||||
return weight, oft_s
|
||||
return self.oft_R[adapter_name].get_weight()
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
@ -431,42 +647,15 @@ class Linear(nn.Module, OFTLayer):
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
oft_rotation = torch.eye(self.in_features, device=x.device)
|
||||
oft_scale = torch.ones((int(self.out_features), 1), device=x.device)
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_r.keys():
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
oft_r = self.oft_r[active_adapter]
|
||||
oft_s = self.oft_s[active_adapter]
|
||||
dropout = self.oft_dropout[active_adapter]
|
||||
oft_R = self.oft_R[active_adapter]
|
||||
|
||||
rank = self.r[active_adapter]
|
||||
coft = self.coft[active_adapter]
|
||||
eps = self.eps[active_adapter]
|
||||
x = self._cast_input_dtype(x, oft_R.weight.dtype)
|
||||
x = oft_R(x)
|
||||
|
||||
if coft:
|
||||
with torch.no_grad():
|
||||
oft_r.copy_(self._project_batch(oft_r, eps=eps))
|
||||
|
||||
orth_rotate = self._cayley_batch(oft_r)
|
||||
orth_rotate = dropout(orth_rotate)
|
||||
oft_mat = self._block_diagonal(orth_rotate, rank)
|
||||
|
||||
oft_rotation = oft_mat @ oft_rotation
|
||||
oft_scale = oft_s * oft_scale
|
||||
|
||||
x = x.to(self.get_base_layer().weight.data.dtype)
|
||||
|
||||
orig_weight = self.get_base_layer().weight.data
|
||||
orig_weight = torch.transpose(orig_weight, 0, 1)
|
||||
rotated_weight = torch.mm(oft_rotation, orig_weight.to(oft_rotation.dtype))
|
||||
rotated_weight = torch.transpose(rotated_weight, 0, 1)
|
||||
scaled_rotated_weight = rotated_weight * oft_scale
|
||||
|
||||
x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
|
||||
bias = self._cast_input_dtype(self.get_base_layer().bias, scaled_rotated_weight.dtype)
|
||||
result = F.linear(input=x, weight=scaled_rotated_weight, bias=bias)
|
||||
result = self.base_layer(x.to(previous_dtype), *args, **kwargs)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
return result
|
||||
@ -491,6 +680,8 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
eps: float = 6e-5,
|
||||
block_share: bool = False,
|
||||
init_weights: Union[bool, str] = True,
|
||||
use_cayley_neumann: bool = False,
|
||||
num_cayley_neumann_terms: int = 5,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -500,9 +691,32 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
self._active_adapter = adapter_name
|
||||
|
||||
# Create adapter and set it active
|
||||
self.update_layer(adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights)
|
||||
self.update_layer(
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size=oft_block_size,
|
||||
module_dropout=module_dropout,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
init_weights=init_weights,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
def update_layer(self, adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights):
|
||||
def update_layer(
|
||||
self,
|
||||
adapter_name,
|
||||
r,
|
||||
oft_block_size,
|
||||
module_dropout,
|
||||
coft,
|
||||
eps,
|
||||
block_share,
|
||||
init_weights,
|
||||
use_cayley_neumann,
|
||||
num_cayley_neumann_terms,
|
||||
):
|
||||
"""
|
||||
Update the conv2d layer with trainable OFT weights.
|
||||
"""
|
||||
@ -515,6 +729,9 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
|
||||
# layer information from the base layer
|
||||
base_layer = self.get_base_layer()
|
||||
if base_layer.dilation[0] > 1:
|
||||
raise ValueError("Conv2d with dilation > 1 is not supported by OFT.")
|
||||
|
||||
conv_filter_dim = self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
|
||||
|
||||
if r == 0 and oft_block_size != 0:
|
||||
@ -536,20 +753,20 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
"Something went wrong, please report this error: https://github.com/huggingface/peft/issues"
|
||||
)
|
||||
|
||||
self.coft[adapter_name] = coft
|
||||
self.block_share[adapter_name] = block_share
|
||||
self.eps[adapter_name] = eps * math.ceil(self.out_features / r) * math.ceil(self.out_features / r)
|
||||
|
||||
# Create weights with provided shape
|
||||
if block_share:
|
||||
self.oft_r[adapter_name] = nn.Parameter(
|
||||
torch.empty(1, math.ceil(conv_filter_dim / r), math.ceil(conv_filter_dim / r))
|
||||
)
|
||||
else:
|
||||
self.oft_r[adapter_name] = nn.Parameter(
|
||||
torch.empty(r, math.ceil(conv_filter_dim / r), math.ceil(conv_filter_dim / r))
|
||||
)
|
||||
self.oft_s[adapter_name] = nn.Parameter(torch.empty(int(self.out_features), 1))
|
||||
n_elements = oft_block_size * (oft_block_size - 1) // 2
|
||||
self.oft_R[adapter_name] = OFTRotationModule(
|
||||
r if not block_share else 1,
|
||||
n_elements,
|
||||
oft_block_size,
|
||||
conv_filter_dim,
|
||||
coft=coft,
|
||||
eps=eps,
|
||||
block_share=block_share,
|
||||
kernel_size=base_layer.kernel_size,
|
||||
use_cayley_neumann=use_cayley_neumann,
|
||||
num_cayley_neumann_terms=num_cayley_neumann_terms,
|
||||
)
|
||||
|
||||
# Initialize weights
|
||||
self.reset_oft_parameters(adapter_name, init_weights)
|
||||
@ -581,14 +798,14 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
return
|
||||
|
||||
for active_adapter in adapter_names:
|
||||
if active_adapter in self.oft_r.keys():
|
||||
if active_adapter in self.oft_R.keys():
|
||||
base_layer = self.get_base_layer()
|
||||
orig_dtype = base_layer.weight.dtype
|
||||
if safe_merge:
|
||||
# Note that safe_merge will be slower than the normal merge
|
||||
# because of the copy operation.
|
||||
orig_weights = base_layer.weight.data.clone()
|
||||
oft_mat, oft_s = self.get_delta_weight(active_adapter)
|
||||
oft_mat = self.get_delta_weight(active_adapter)
|
||||
|
||||
orig_weights = orig_weights.view(
|
||||
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
|
||||
@ -596,14 +813,13 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = orig_weights * oft_s
|
||||
orig_weights = orig_weights.view(
|
||||
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
|
||||
)
|
||||
|
||||
base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
|
||||
else:
|
||||
oft_mat, oft_s = self.get_delta_weight(active_adapter)
|
||||
oft_mat = self.get_delta_weight(active_adapter)
|
||||
|
||||
orig_weights = base_layer.weight.data.clone()
|
||||
orig_weights = orig_weights.view(
|
||||
@ -612,7 +828,6 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = orig_weights * oft_s
|
||||
orig_weights = orig_weights.view(
|
||||
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
|
||||
)
|
||||
@ -633,8 +848,8 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
orig_dtype = base_layer.weight.dtype
|
||||
while len(self.merged_adapters) > 0:
|
||||
active_adapter = self.merged_adapters.pop()
|
||||
if active_adapter in self.oft_r.keys():
|
||||
oft_mat, oft_s = self.get_delta_weight(active_adapter)
|
||||
if active_adapter in self.oft_R.keys():
|
||||
oft_mat = self.get_delta_weight(active_adapter)
|
||||
|
||||
orig_weights = self.get_base_layer().weight.data.clone()
|
||||
orig_weights = orig_weights.view(
|
||||
@ -644,7 +859,6 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
orig_weights = orig_weights * (1 / oft_s)
|
||||
orig_weights = orig_weights.view(
|
||||
self.out_features,
|
||||
self.in_features,
|
||||
@ -662,21 +876,8 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
adapter (str):
|
||||
The name of the adapter for which the delta weight should be computed.
|
||||
"""
|
||||
oft_r = self.oft_r[adapter_name]
|
||||
oft_s = self.oft_s[adapter_name]
|
||||
|
||||
rank = self.r[adapter_name]
|
||||
coft = self.coft[adapter_name]
|
||||
eps = self.eps[adapter_name]
|
||||
|
||||
if coft:
|
||||
with torch.no_grad():
|
||||
oft_r.copy_(self._project_batch(oft_r, eps=eps))
|
||||
|
||||
orth_rotate = self._cayley_batch(oft_r)
|
||||
weight = self._block_diagonal(orth_rotate, rank)
|
||||
|
||||
return weight, oft_s
|
||||
return self.oft_R[adapter_name].get_weight()
|
||||
|
||||
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
@ -688,62 +889,15 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
oft_rotation = torch.eye(
|
||||
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
|
||||
device=x.device,
|
||||
)
|
||||
oft_scale = torch.ones((int(self.out_features), 1), device=x.device)
|
||||
|
||||
for active_adapter in self.active_adapters:
|
||||
if active_adapter not in self.oft_r.keys():
|
||||
if active_adapter not in self.oft_R.keys():
|
||||
continue
|
||||
oft_r = self.oft_r[active_adapter]
|
||||
oft_s = self.oft_s[active_adapter]
|
||||
dropout = self.oft_dropout[active_adapter]
|
||||
|
||||
rank = self.r[active_adapter]
|
||||
coft = self.coft[active_adapter]
|
||||
eps = self.eps[active_adapter]
|
||||
oft_R = self.oft_R[active_adapter]
|
||||
x = self._cast_input_dtype(x, oft_R.weight.dtype)
|
||||
x = oft_R(x)
|
||||
|
||||
if coft:
|
||||
with torch.no_grad():
|
||||
oft_r.copy_(self._project_batch(oft_r, eps=eps))
|
||||
|
||||
orth_rotate = self._cayley_batch(oft_r)
|
||||
orth_rotate = dropout(orth_rotate)
|
||||
oft_mat = self._block_diagonal(orth_rotate, rank)
|
||||
|
||||
oft_rotation = oft_mat @ oft_rotation
|
||||
oft_scale = oft_s * oft_scale
|
||||
|
||||
orig_weights = self.base_layer.weight.data
|
||||
orig_weights = orig_weights.view(
|
||||
self.out_features,
|
||||
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
|
||||
)
|
||||
orig_weights = torch.transpose(orig_weights, 0, 1)
|
||||
oft_rotation = oft_rotation.to(previous_dtype)
|
||||
orig_weights = orig_weights.to(previous_dtype)
|
||||
rotated_weight = torch.mm(oft_rotation, orig_weights)
|
||||
rotated_weight = torch.transpose(rotated_weight, 0, 1)
|
||||
|
||||
scaled_rotated_weight = rotated_weight * oft_scale
|
||||
|
||||
scaled_rotated_weight = scaled_rotated_weight.view(
|
||||
self.out_features,
|
||||
self.in_features,
|
||||
self.get_base_layer().kernel_size[0],
|
||||
self.get_base_layer().kernel_size[0],
|
||||
)
|
||||
x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
|
||||
bias = self._cast_input_dtype(self.get_base_layer().bias, scaled_rotated_weight.dtype)
|
||||
result = F.conv2d(
|
||||
input=x,
|
||||
weight=scaled_rotated_weight,
|
||||
bias=bias,
|
||||
padding=self.get_base_layer().padding[0],
|
||||
stride=self.get_base_layer().stride[0],
|
||||
)
|
||||
result = self.base_layer(x.to(previous_dtype), *args, **kwargs)
|
||||
|
||||
result = result.to(previous_dtype)
|
||||
return result
|
||||
@ -751,3 +905,30 @@ class Conv2d(nn.Module, OFTLayer):
|
||||
def __repr__(self) -> str:
|
||||
rep = super().__repr__()
|
||||
return "oft." + rep
|
||||
|
||||
|
||||
def dispatch_default(
|
||||
target: torch.nn.Module,
|
||||
adapter_name: str,
|
||||
oft_config: OFTConfig,
|
||||
**kwargs,
|
||||
) -> Optional[torch.nn.Module]:
|
||||
new_module = None
|
||||
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
|
||||
if isinstance(target_base_layer, torch.nn.Conv2d):
|
||||
new_module = Conv2d(target, adapter_name, **kwargs)
|
||||
elif isinstance(target_base_layer, torch.nn.Linear):
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = oft_config.fan_in_fan_out = False
|
||||
new_module = Linear(target, adapter_name, **kwargs)
|
||||
|
||||
return new_module
|
||||
|
@ -21,6 +21,7 @@ import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
|
||||
from peft.tuners.tuners_utils import (
|
||||
BaseTuner,
|
||||
BaseTunerLayer,
|
||||
@ -31,10 +32,17 @@ from peft.utils import (
|
||||
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
|
||||
ModulesToSaveWrapper,
|
||||
_get_submodules,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
from .aqlm import dispatch_aqlm
|
||||
from .awq import dispatch_awq
|
||||
from .config import OFTConfig
|
||||
from .layer import Conv2d, Linear, OFTLayer
|
||||
from .eetq import dispatch_eetq
|
||||
from .gptq import dispatch_gptq
|
||||
from .hqq import dispatch_hqq
|
||||
from .inc import dispatch_inc
|
||||
from .layer import OFTLayer, dispatch_default
|
||||
|
||||
|
||||
class OFTModel(BaseTuner):
|
||||
@ -126,7 +134,6 @@ class OFTModel(BaseTuner):
|
||||
if current_key is None:
|
||||
raise ValueError("Current Key shouldn't be `None`")
|
||||
|
||||
bias = hasattr(target, "bias") and target.bias is not None
|
||||
kwargs = {
|
||||
"r": oft_config.r,
|
||||
"oft_block_size": oft_config.oft_block_size,
|
||||
@ -134,14 +141,24 @@ class OFTModel(BaseTuner):
|
||||
"coft": oft_config.coft,
|
||||
"eps": oft_config.eps,
|
||||
"block_share": oft_config.block_share,
|
||||
"use_cayley_neumann": oft_config.use_cayley_neumann,
|
||||
"num_cayley_neumann_terms": oft_config.num_cayley_neumann_terms,
|
||||
"fan_in_fan_out": oft_config.fan_in_fan_out,
|
||||
"init_weights": oft_config.init_weights,
|
||||
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
|
||||
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
|
||||
}
|
||||
kwargs["bias"] = bias
|
||||
|
||||
quant_methods = ["gptq", "aqlm", "awq"]
|
||||
for quant_method in quant_methods:
|
||||
quantization_config = get_quantization_config(self.model, method=quant_method)
|
||||
if quantization_config is not None:
|
||||
kwargs[f"{quant_method}_quantization_config"] = quantization_config
|
||||
|
||||
# If it is not a OFTLayer, create a new module, else update it with new adapters
|
||||
if not isinstance(target, OFTLayer):
|
||||
new_module = self._create_new_module(oft_config, adapter_name, target, **kwargs)
|
||||
device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
|
||||
new_module = self._create_new_module(oft_config, adapter_name, target, device_map=device_map, **kwargs)
|
||||
if adapter_name not in self.active_adapters:
|
||||
# adding an additional adapter: it is not automatically trainable
|
||||
new_module.requires_grad_(False)
|
||||
@ -155,6 +172,8 @@ class OFTModel(BaseTuner):
|
||||
coft=oft_config.coft,
|
||||
eps=oft_config.eps,
|
||||
block_share=oft_config.block_share,
|
||||
use_cayley_neumann=oft_config.use_cayley_neumann,
|
||||
num_cayley_neumann_terms=oft_config.num_cayley_neumann_terms,
|
||||
init_weights=oft_config.init_weights,
|
||||
)
|
||||
|
||||
@ -167,24 +186,22 @@ class OFTModel(BaseTuner):
|
||||
if hasattr(child, "base_layer"):
|
||||
child = child.base_layer
|
||||
|
||||
if not hasattr(new_module, "base_layer"):
|
||||
new_module.weight = child.weight
|
||||
if hasattr(child, "bias"):
|
||||
new_module.bias = child.bias
|
||||
|
||||
if getattr(child, "state", None) is not None:
|
||||
if hasattr(new_module, "base_layer"):
|
||||
new_module.base_layer.state = child.state
|
||||
else:
|
||||
new_module.state = child.state
|
||||
new_module.to(child.weight.device)
|
||||
|
||||
meta = torch.device("meta")
|
||||
# dispatch to correct device
|
||||
for name, module in new_module.named_modules():
|
||||
if self.prefix in name:
|
||||
if (self.prefix in name) or ("ranknum" in name):
|
||||
if hasattr(child, "qweight"):
|
||||
weight = child.qweight
|
||||
elif hasattr(child, "W_q"):
|
||||
weight = child.W_q
|
||||
elif hasattr(child, "weight"):
|
||||
weight = child.weight
|
||||
elif getattr(child, "in_proj_weight", None) is not None: # MHA
|
||||
weight = child.in_proj_weight
|
||||
else:
|
||||
weight = next(child.parameters())
|
||||
if not any(p.device == meta for p in module.parameters()):
|
||||
module.to(child.weight.device)
|
||||
module.to(weight.device)
|
||||
|
||||
def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
|
||||
for n, p in model.named_parameters():
|
||||
@ -209,25 +226,44 @@ class OFTModel(BaseTuner):
|
||||
|
||||
@staticmethod
|
||||
def _create_new_module(oft_config, adapter_name, target, **kwargs):
|
||||
if isinstance(target, BaseTunerLayer):
|
||||
target_base_layer = target.get_base_layer()
|
||||
else:
|
||||
target_base_layer = target
|
||||
# Collect dispatcher functions to decide what backend to use for the replaced OFT layer. The order matters,
|
||||
# because the first match is always used. Therefore, the default layers should be checked last.
|
||||
dispatchers = []
|
||||
|
||||
if isinstance(target_base_layer, torch.nn.Linear):
|
||||
if kwargs["fan_in_fan_out"]:
|
||||
warnings.warn(
|
||||
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
|
||||
"Setting fan_in_fan_out to False."
|
||||
)
|
||||
kwargs["fan_in_fan_out"] = oft_config.fan_in_fan_out = False
|
||||
new_module = Linear(target, adapter_name, **kwargs)
|
||||
elif isinstance(target_base_layer, torch.nn.Conv2d):
|
||||
new_module = Conv2d(target, adapter_name, **kwargs)
|
||||
else:
|
||||
# avoid eager bnb import
|
||||
if is_bnb_available():
|
||||
from .bnb import dispatch_bnb_8bit
|
||||
|
||||
dispatchers.append(dispatch_bnb_8bit)
|
||||
|
||||
if is_bnb_4bit_available():
|
||||
from .bnb import dispatch_bnb_4bit
|
||||
|
||||
dispatchers.append(dispatch_bnb_4bit)
|
||||
|
||||
dispatchers.extend(
|
||||
[
|
||||
dispatch_eetq,
|
||||
dispatch_aqlm,
|
||||
dispatch_awq,
|
||||
dispatch_gptq,
|
||||
dispatch_hqq,
|
||||
dispatch_inc,
|
||||
dispatch_default,
|
||||
]
|
||||
)
|
||||
|
||||
new_module = None
|
||||
for dispatcher in dispatchers:
|
||||
new_module = dispatcher(target, adapter_name, oft_config=oft_config, **kwargs)
|
||||
if new_module is not None: # first match wins
|
||||
break
|
||||
|
||||
if new_module is None:
|
||||
# no module could be matched
|
||||
raise ValueError(
|
||||
f"Target module {target} is not supported. "
|
||||
"Currently, only `torch.nn.Linear` and `torch.nn.Conv2d` are supported."
|
||||
f"Target module {target} is not supported. Currently, only the following modules are supported: "
|
||||
"`torch.nn.Linear`, `torch.nn.Conv2d`."
|
||||
)
|
||||
|
||||
return new_module
|
||||
@ -255,7 +291,11 @@ class OFTModel(BaseTuner):
|
||||
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
|
||||
module.enable_adapters(enabled)
|
||||
|
||||
def enable_adapter_layers(self):
|
||||
def enable_adapter_layers(self) -> None:
|
||||
"""Enable all adapters.
|
||||
|
||||
Call this if you have previously disabled all adapters and want to re-enable them.
|
||||
"""
|
||||
self._set_adapter_layers(enabled=True)
|
||||
|
||||
def disable_adapter_layers(self):
|
||||
@ -270,6 +310,20 @@ class OFTModel(BaseTuner):
|
||||
self._set_adapter_layers(enabled=False)
|
||||
|
||||
def set_adapter(self, adapter_name):
|
||||
"""Set the active adapter(s).
|
||||
|
||||
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
|
||||
not desired, use the following code.
|
||||
|
||||
```py
|
||||
>>> for name, param in model_peft.named_parameters():
|
||||
... if ...: # some check on name (ex. if 'lora' in name)
|
||||
... param.requires_grad = False
|
||||
```
|
||||
|
||||
Args:
|
||||
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
|
||||
"""
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, OFTLayer):
|
||||
if module.merged:
|
||||
@ -278,6 +332,17 @@ class OFTModel(BaseTuner):
|
||||
module.set_adapter(adapter_name)
|
||||
self.active_adapter = adapter_name
|
||||
|
||||
def _check_merge_allowed(self):
|
||||
"""Verify that the configuration supports merging.
|
||||
|
||||
Currently gptq quantization and replicated layers do not support merging.
|
||||
"""
|
||||
super()._check_merge_allowed()
|
||||
if getattr(self.model, "quantization_method", None) == "gptq":
|
||||
raise ValueError("Cannot merge OFT layers when the model is gptq quantized")
|
||||
if self.peft_config.get("layer_replication"):
|
||||
raise ValueError("Cannot merge OFT layers when base model layers are replicated")
|
||||
|
||||
@staticmethod
|
||||
def _prepare_adapter_config(peft_config, model_config):
|
||||
if peft_config.target_modules is None:
|
||||
@ -306,19 +371,16 @@ class OFTModel(BaseTuner):
|
||||
except AttributeError:
|
||||
continue
|
||||
with onload_layer(target):
|
||||
if hasattr(target, "base_layer"):
|
||||
if hasattr(target, "unload_and_optionally_merge_module"):
|
||||
# if layers have special unloading method, like MultiheadAttention, use that
|
||||
unloaded_module = target.unload_and_optionally_merge_module(
|
||||
merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
|
||||
)
|
||||
self._replace_module(parent, target_name, unloaded_module, target)
|
||||
elif hasattr(target, "base_layer"):
|
||||
if merge:
|
||||
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
|
||||
self._replace_module(parent, target_name, target.get_base_layer(), target)
|
||||
elif isinstance(target, ModulesToSaveWrapper):
|
||||
# save any additional trainable modules part of `modules_to_save`
|
||||
new_module = target.modules_to_save[target.active_adapter]
|
||||
if hasattr(new_module, "base_layer"):
|
||||
# check if the module is itself a tuner layer
|
||||
if merge:
|
||||
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
|
||||
new_module = new_module.get_base_layer()
|
||||
setattr(parent, target_name, new_module)
|
||||
|
||||
return self.model
|
||||
|
||||
|
@ -434,6 +434,11 @@ def set_peft_model_state_dict(
|
||||
return k
|
||||
|
||||
peft_model_state_dict = {renamed_dora_weights(k): v for k, v in peft_model_state_dict.items()}
|
||||
elif config.peft_type == PeftType.OFT:
|
||||
if any(".oft_r." in key for key in peft_model_state_dict):
|
||||
raise ValueError(
|
||||
"Trying to load old OFT checkpoint, which is no longer supported. Please install PEFT <= v0.15.2 to load it or train a new OFT adapter."
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -1884,18 +1884,18 @@ class TestSameAdapterDifferentDevices:
|
||||
config = OFTConfig(target_modules=["lin0"])
|
||||
model = get_peft_model(mlp, config)
|
||||
model = model.to(self.device)
|
||||
model.lin0.oft_r.cpu()
|
||||
model.lin0.oft_R.default.weight.cpu()
|
||||
|
||||
# check that the adapter is indeed on CPU and the base model on GPU
|
||||
assert model.lin0.oft_r.default.device.type == "cpu"
|
||||
assert model.lin0.oft_R.default.weight.device.type == "cpu"
|
||||
assert model.lin0.base_layer.weight.device.type == self.device
|
||||
|
||||
model.add_adapter("other", config)
|
||||
# check that after adding a new adapter, the old adapter is still on CPU
|
||||
assert model.lin0.oft_r.default.device.type == "cpu"
|
||||
assert model.lin0.oft_R.default.weight.device.type == "cpu"
|
||||
# the rest should be on GPU
|
||||
assert model.lin0.base_layer.weight.device.type == self.device
|
||||
assert model.lin0.oft_r.other.device.type == self.device
|
||||
assert model.lin0.oft_R.other.weight.device.type == self.device
|
||||
|
||||
def test_vera_add_new_adapter_does_not_change_device(self, mlp):
|
||||
# same as first test, but using VERA
|
||||
|
@ -300,26 +300,114 @@ TEST_CASES = [
|
||||
########
|
||||
# OFT #
|
||||
########
|
||||
("Vanilla MLP 1 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": "lin0"}),
|
||||
("Vanilla MLP 2 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"]}),
|
||||
("Vanilla MLP 5 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
|
||||
(
|
||||
"Vanilla MLP 1 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{"r": 2, "oft_block_size": 0, "target_modules": "lin0", "use_cayley_neumann": False},
|
||||
),
|
||||
(
|
||||
"Vanilla MLP 2 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{"r": 2, "oft_block_size": 0, "target_modules": ["lin0"], "use_cayley_neumann": False},
|
||||
),
|
||||
(
|
||||
"Vanilla MLP 5 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{
|
||||
"r": 2,
|
||||
"oft_block_size": 0,
|
||||
"target_modules": ["lin0"],
|
||||
"modules_to_save": ["lin1"],
|
||||
"use_cayley_neumann": False,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Vanilla MLP 6 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{
|
||||
"r": 2,
|
||||
"oft_block_size": 0,
|
||||
"target_modules": ["lin0"],
|
||||
"module_dropout": 0.1,
|
||||
"use_cayley_neumann": False,
|
||||
},
|
||||
),
|
||||
("Vanilla MLP 7 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "coft": True}),
|
||||
("Vanilla MLP 8 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "block_share": True}),
|
||||
("Vanilla MLP 9 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "coft": True, "block_share": True}),
|
||||
("Conv2d 1 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"]}),
|
||||
("Conv2d 3 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"], "coft": True}),
|
||||
("Conv2d 4 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"], "block_share": True}),
|
||||
("Conv2d 5 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"], "coft": True, "block_share": True}),
|
||||
(
|
||||
"Vanilla MLP 7 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{"r": 2, "oft_block_size": 0, "target_modules": ["lin0"], "coft": True, "eps": 1e-2},
|
||||
),
|
||||
(
|
||||
"Vanilla MLP 8 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{"r": 2, "oft_block_size": 0, "target_modules": ["lin0"], "block_share": True, "use_cayley_neumann": False},
|
||||
),
|
||||
(
|
||||
"Vanilla MLP 9 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{"r": 2, "oft_block_size": 0, "target_modules": ["lin0"], "coft": True, "eps": 1e-2, "block_share": True},
|
||||
),
|
||||
(
|
||||
"Vanilla MLP 10 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{"r": 0, "oft_block_size": 2, "target_modules": ["lin0"], "use_cayley_neumann": True},
|
||||
),
|
||||
(
|
||||
"Vanilla MLP 11 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{"r": 0, "oft_block_size": 2, "target_modules": ["lin0"], "use_cayley_neumann": False},
|
||||
),
|
||||
(
|
||||
"Vanilla MLP 12 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{
|
||||
"r": 0,
|
||||
"oft_block_size": 2,
|
||||
"target_modules": ["lin0"],
|
||||
"coft": True,
|
||||
"eps": 1e-2,
|
||||
"block_share": True,
|
||||
"use_cayley_neumann": True,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Vanilla MLP 13 OFT",
|
||||
"MLP",
|
||||
OFTConfig,
|
||||
{
|
||||
"r": 0,
|
||||
"oft_block_size": 2,
|
||||
"target_modules": ["lin0"],
|
||||
"coft": True,
|
||||
"eps": 1e-2,
|
||||
"block_share": True,
|
||||
"use_cayley_neumann": False,
|
||||
},
|
||||
),
|
||||
("Conv2d 1 OFT", "Conv2d", OFTConfig, {"r": 5, "oft_block_size": 0, "target_modules": ["conv2d"]}),
|
||||
("Conv2d 3 OFT", "Conv2d", OFTConfig, {"r": 5, "oft_block_size": 0, "target_modules": ["conv2d"], "coft": True}),
|
||||
(
|
||||
"Conv2d 4 OFT",
|
||||
"Conv2d",
|
||||
OFTConfig,
|
||||
{"r": 5, "oft_block_size": 0, "target_modules": ["conv2d"], "block_share": True},
|
||||
),
|
||||
(
|
||||
"Conv2d 5 OFT",
|
||||
"Conv2d",
|
||||
OFTConfig,
|
||||
{"r": 5, "oft_block_size": 0, "target_modules": ["conv2d"], "coft": True, "block_share": True},
|
||||
),
|
||||
########
|
||||
# HRA #
|
||||
########
|
||||
@ -1629,6 +1717,9 @@ class TestPeftCustomModel(PeftCommonTester):
|
||||
if issubclass(config_cls, (IA3Config, LoraConfig)) and model_id in conv_ids: # more instability with Conv
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
|
||||
if issubclass(config_cls, OFTConfig):
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
|
||||
if config_kwargs.get("use_dora") and model_id == "EmbConv1D":
|
||||
atol, rtol = 1e-4, 1e-4
|
||||
|
||||
@ -2214,7 +2305,7 @@ class TestPeftCustomModel(PeftCommonTester):
|
||||
LoHaConfig(target_modules=["lin0"], init_weights=False),
|
||||
AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False, total_step=1),
|
||||
IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"], init_ia3_weights=False),
|
||||
OFTConfig(target_modules=["lin0"], init_weights=False, r=2),
|
||||
OFTConfig(target_modules=["lin0"], init_weights=False, r=2, oft_block_size=0),
|
||||
BOFTConfig(target_modules=["lin0"], init_weights=False, boft_block_size=2),
|
||||
HRAConfig(target_modules=["lin0"], init_weights=False),
|
||||
BoneConfig(target_modules=["lin0"], init_weights=False, r=2),
|
||||
@ -3385,33 +3476,30 @@ class TestRequiresGrad:
|
||||
|
||||
def test_requires_grad_oft_different_targets(self):
|
||||
# test two different OFT adapters that target different modules
|
||||
config0 = OFTConfig(target_modules=["lin0"], r=2)
|
||||
config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
|
||||
peft_model = get_peft_model(MLP(), config0)
|
||||
|
||||
config1 = OFTConfig(target_modules=["lin1"], r=2, inference_mode=True)
|
||||
config1 = OFTConfig(target_modules=["lin1"], r=2, oft_block_size=0, inference_mode=True)
|
||||
peft_model.add_adapter("adapter1", config1)
|
||||
|
||||
# active adapter is still "default"
|
||||
self.check_requires_grad(
|
||||
peft_model,
|
||||
"base_model.model.lin0.oft_r.default",
|
||||
"base_model.model.lin0.oft_s.default",
|
||||
"base_model.model.lin0.oft_R.default.weight",
|
||||
)
|
||||
|
||||
# set config0 as active, should not change anything
|
||||
peft_model.set_adapter("default")
|
||||
self.check_requires_grad(
|
||||
peft_model,
|
||||
"base_model.model.lin0.oft_r.default",
|
||||
"base_model.model.lin0.oft_s.default",
|
||||
"base_model.model.lin0.oft_R.default.weight",
|
||||
)
|
||||
|
||||
# change activate pter to pter1
|
||||
peft_model.set_adapter("adapter1")
|
||||
self.check_requires_grad(
|
||||
peft_model,
|
||||
"base_model.model.lin1.oft_r.adapter1",
|
||||
"base_model.model.lin1.oft_s.adapter1",
|
||||
"base_model.model.lin1.oft_R.adapter1.weight",
|
||||
)
|
||||
|
||||
# disable all pters
|
||||
@ -3421,39 +3509,35 @@ class TestRequiresGrad:
|
||||
# after context is exited, return to the previous state
|
||||
self.check_requires_grad(
|
||||
peft_model,
|
||||
"base_model.model.lin1.oft_r.adapter1",
|
||||
"base_model.model.lin1.oft_s.adapter1",
|
||||
"base_model.model.lin1.oft_R.adapter1.weight",
|
||||
)
|
||||
|
||||
def test_requires_grad_oft_same_targets(self):
|
||||
# same as previous test, except that OFT adapters target the same layer
|
||||
config0 = OFTConfig(target_modules=["lin0"], r=2)
|
||||
config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
|
||||
peft_model = get_peft_model(MLP(), config0)
|
||||
|
||||
config1 = OFTConfig(target_modules=["lin0"], r=2, inference_mode=True)
|
||||
config1 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0, inference_mode=True)
|
||||
peft_model.add_adapter("adapter1", config1)
|
||||
|
||||
# active adapter is still "default"
|
||||
self.check_requires_grad(
|
||||
peft_model,
|
||||
"base_model.model.lin0.oft_r.default",
|
||||
"base_model.model.lin0.oft_s.default",
|
||||
"base_model.model.lin0.oft_R.default.weight",
|
||||
)
|
||||
|
||||
# set config0 as active, should not change anything
|
||||
peft_model.set_adapter("default")
|
||||
self.check_requires_grad(
|
||||
peft_model,
|
||||
"base_model.model.lin0.oft_r.default",
|
||||
"base_model.model.lin0.oft_s.default",
|
||||
"base_model.model.lin0.oft_R.default.weight",
|
||||
)
|
||||
|
||||
# change activate adapter to adapter1
|
||||
peft_model.set_adapter("adapter1")
|
||||
self.check_requires_grad(
|
||||
peft_model,
|
||||
"base_model.model.lin0.oft_r.adapter1",
|
||||
"base_model.model.lin0.oft_s.adapter1",
|
||||
"base_model.model.lin0.oft_R.adapter1.weight",
|
||||
)
|
||||
|
||||
# disable all adapters
|
||||
@ -3464,8 +3548,7 @@ class TestRequiresGrad:
|
||||
peft_model.set_adapter("adapter1")
|
||||
self.check_requires_grad(
|
||||
peft_model,
|
||||
"base_model.model.lin0.oft_r.adapter1",
|
||||
"base_model.model.lin0.oft_s.adapter1",
|
||||
"base_model.model.lin0.oft_R.adapter1.weight",
|
||||
)
|
||||
|
||||
def test_requires_grad_hra_different_targets(self):
|
||||
|
@ -32,6 +32,7 @@ from transformers import (
|
||||
from peft import (
|
||||
AdaLoraConfig,
|
||||
LoraConfig,
|
||||
OFTConfig,
|
||||
PeftModel,
|
||||
get_peft_model,
|
||||
prepare_model_for_kbit_training,
|
||||
@ -103,6 +104,43 @@ class PeftGPTQModelCommonTests(unittest.TestCase):
|
||||
assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A
|
||||
assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A
|
||||
|
||||
def test_oft_gptq_quantization_from_pretrained_safetensors(self):
|
||||
r"""
|
||||
Tests that the gptqmodel quantization using OFT works as expected with safetensors weights.
|
||||
"""
|
||||
from transformers import GPTQConfig
|
||||
|
||||
model_id = "marcsun13/opt-350m-gptq-4bit"
|
||||
quantization_config = GPTQConfig(bits=4, use_exllama=False)
|
||||
kwargs = {
|
||||
"pretrained_model_name_or_path": model_id,
|
||||
"torch_dtype": torch.float16,
|
||||
"device_map": "auto",
|
||||
"quantization_config": quantization_config,
|
||||
}
|
||||
model = AutoModelForCausalLM.from_pretrained(**kwargs)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
config = OFTConfig(task_type="CAUSAL_LM")
|
||||
peft_model = get_peft_model(model, config)
|
||||
peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
peft_model.save_pretrained(tmp_dir)
|
||||
model = AutoModelForCausalLM.from_pretrained(**kwargs)
|
||||
model = PeftModel.from_pretrained(model, tmp_dir)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))
|
||||
|
||||
# loading a 2nd adapter works, #1239
|
||||
model.load_adapter(tmp_dir, "adapter2")
|
||||
model.set_adapter("adapter2")
|
||||
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))
|
||||
|
||||
# check that both adapters are in the same layer
|
||||
assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.oft_R
|
||||
assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.oft_R
|
||||
|
||||
|
||||
@require_gptqmodel
|
||||
@require_optimum
|
||||
@ -186,6 +224,58 @@ class PeftGPTQModelTests(unittest.TestCase):
|
||||
# assert loss is not None
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
def test_oft_causal_lm_training(self):
|
||||
r"""
|
||||
Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
|
||||
correctly.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.causal_lm_model_id,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
quantization_config=self.quantization_config,
|
||||
)
|
||||
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
config = OFTConfig(
|
||||
r=0,
|
||||
oft_block_size=8,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
model = get_peft_model(model, config)
|
||||
|
||||
data = load_dataset_english_quotes()
|
||||
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_dataset=data["train"],
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_steps=2,
|
||||
max_steps=3,
|
||||
learning_rate=2e-4,
|
||||
fp16=True,
|
||||
logging_steps=1,
|
||||
output_dir=tmp_dir,
|
||||
),
|
||||
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
|
||||
)
|
||||
model.config.use_cache = False
|
||||
trainer.train()
|
||||
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
assert "adapter_config.json" in os.listdir(tmp_dir)
|
||||
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
|
||||
|
||||
# assert loss is not None
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
@pytest.mark.single_gpu_tests
|
||||
def test_adalora_causalLM(self):
|
||||
r"""
|
||||
@ -315,6 +405,68 @@ class PeftGPTQModelTests(unittest.TestCase):
|
||||
# assert loss is not None
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
@pytest.mark.multi_gpu_tests
|
||||
@require_torch_multi_accelerator
|
||||
def test_oft_causal_lm_training_multi_accelerator(self):
|
||||
r"""
|
||||
Test the CausalLM training on a multi-accelerator device. The test would simply fail if the adapters are not
|
||||
set correctly.
|
||||
"""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.causal_lm_model_id,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
quantization_config=self.quantization_config,
|
||||
)
|
||||
|
||||
assert set(model.hf_device_map.values()) == set(range(device_count))
|
||||
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
setattr(model, "model_parallel", True)
|
||||
setattr(model, "is_parallelizable", True)
|
||||
|
||||
config = OFTConfig(
|
||||
r=0,
|
||||
oft_block_size=8,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
model = get_peft_model(model, config)
|
||||
|
||||
data = load_dataset_english_quotes()
|
||||
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_dataset=data["train"],
|
||||
args=TrainingArguments(
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
warmup_steps=2,
|
||||
max_steps=3,
|
||||
learning_rate=2e-4,
|
||||
fp16=True,
|
||||
logging_steps=1,
|
||||
output_dir=tmp_dir,
|
||||
),
|
||||
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
|
||||
)
|
||||
model.config.use_cache = False
|
||||
trainer.train()
|
||||
|
||||
model.cpu().save_pretrained(tmp_dir)
|
||||
|
||||
assert "adapter_config.json" in os.listdir(tmp_dir)
|
||||
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
|
||||
|
||||
# assert loss is not None
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
def test_non_default_adapter_name(self):
|
||||
# See issue 1346
|
||||
config = LoraConfig(
|
||||
@ -350,6 +502,42 @@ class PeftGPTQModelTests(unittest.TestCase):
|
||||
assert n_trainable_default == n_trainable_other
|
||||
assert n_total_default == n_total_other
|
||||
|
||||
def test_oft_non_default_adapter_name(self):
|
||||
# See issue 1346
|
||||
config = OFTConfig(
|
||||
r=0,
|
||||
oft_block_size=8,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# default adapter name
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.causal_lm_model_id,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
quantization_config=self.quantization_config,
|
||||
)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
model = get_peft_model(model, config)
|
||||
n_trainable_default, n_total_default = model.get_nb_trainable_parameters()
|
||||
|
||||
# other adapter name
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.causal_lm_model_id,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
quantization_config=self.quantization_config,
|
||||
)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
model = get_peft_model(model, config, adapter_name="other")
|
||||
n_trainable_other, n_total_other = model.get_nb_trainable_parameters()
|
||||
|
||||
assert n_trainable_other > 0
|
||||
# sanity check
|
||||
assert n_trainable_default == n_trainable_other
|
||||
assert n_total_default == n_total_other
|
||||
|
||||
def test_load_lora(self):
|
||||
model_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit"
|
||||
adapter_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit-lora"
|
||||
|
@ -134,12 +134,15 @@ DIFFUSERS_CONFIGS = [
|
||||
{
|
||||
"text_encoder": {
|
||||
"r": 1,
|
||||
"oft_block_size": 0,
|
||||
"target_modules": ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
|
||||
"module_dropout": 0.0,
|
||||
"init_weights": False,
|
||||
"use_cayley_neumann": False,
|
||||
},
|
||||
"unet": {
|
||||
"r": 1,
|
||||
"oft_block_size": 0,
|
||||
"target_modules": [
|
||||
"proj_in",
|
||||
"proj_out",
|
||||
@ -152,6 +155,7 @@ DIFFUSERS_CONFIGS = [
|
||||
],
|
||||
"module_dropout": 0.0,
|
||||
"init_weights": False,
|
||||
"use_cayley_neumann": False,
|
||||
},
|
||||
},
|
||||
),
|
||||
|
@ -46,7 +46,9 @@ CONFIGS = {
|
||||
"lora": LoraConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
|
||||
"loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
|
||||
"lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
|
||||
"oft": OFTConfig(r=1, target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
|
||||
"oft": OFTConfig(
|
||||
r=1, oft_block_size=0, target_modules=["convolution"], modules_to_save=["classifier", "normalization"]
|
||||
),
|
||||
"hra": HRAConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
|
||||
# TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no
|
||||
# common denominator for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel:
|
||||
|
Reference in New Issue
Block a user