ENH Make OFT faster and more memory efficient (#2575)

Make OFT faster and more memory efficient. This new version of OFT is
not backwards compatible with older checkpoints and vice versa. To load
older checkpoints, downgrade PEFT to 0.15.2 or lower.
Author: Zeju Qiu
Date: 2025-06-26 14:27:03 +02:00
Committed by: GitHub
Parent: e34852f7b6
Commit: d936478f07
18 changed files with 2049 additions and 316 deletions


@ -16,9 +16,9 @@ rendered properly in your Markdown viewer.
# Orthogonal Finetuning (OFT and BOFT)
This conceptual guide gives a brief overview of [OFT](https://huggingface.co/papers/2306.07280) and [BOFT](https://huggingface.co/papers/2311.06243), a parameter-efficient fine-tuning technique that utilizes orthogonal matrix to multiplicatively transform the pretrained weight matrices.
This conceptual guide gives a brief overview of [OFT](https://huggingface.co/papers/2306.07280), [OFTv2](https://www.arxiv.org/abs/2506.19847) and [BOFT](https://huggingface.co/papers/2311.06243), parameter-efficient fine-tuning techniques that use an orthogonal matrix to multiplicatively transform the pretrained weight matrices.
To achieve efficient fine-tuning, OFT represents the weight updates with an orthogonal transformation. The orthogonal transformation is parameterized by an orthogonal matrix multiplied to the pretrained weight matrix. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesnt receive any further adjustments. To produce the final results, both the original and the adapted weights are multiplied togethor.
To achieve efficient fine-tuning, OFT represents the weight updates with an orthogonal transformation. The orthogonal transformation is parameterized by an orthogonal matrix that multiplies the pretrained weight matrix. These new matrices can be trained to adapt to the new data while keeping the overall number of changes low. The original weight matrix remains frozen and doesn't receive any further adjustments. To produce the final results, the original and the adapted weights are multiplied together.
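To make the multiplicative update concrete, the reparameterization can be sketched as follows (which side the rotation is applied on, and whether it acts on the input or output dimension, is an implementation convention):

$$
W' = R\,W_0, \qquad R^\top R = I, \qquad R = \mathrm{blockdiag}(R_1, \dots, R_r), \qquad R_i = (I + Q_i)(I - Q_i)^{-1}, \quad Q_i = -Q_i^\top
$$

Since each $Q_i$ is skew-symmetric, only its entries above the diagonal are free parameters, so an adapter with $r$ blocks of size $b$ trains $r \cdot b(b-1)/2$ parameters per layer; at initialization $Q_i = 0$, which gives $R = I$, so the adapted model starts out identical to the pretrained one.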
Orthogonal Butterfly (BOFT) generalizes OFT with Butterfly factorization and further improves its parameter efficiency and finetuning flexibility. In short, OFT can be viewed as a special case of BOFT. Different from LoRA that uses additive low-rank weight updates, BOFT uses multiplicative orthogonal weight updates. The comparison is shown below.
@ -58,13 +58,25 @@ As with other methods supported by PEFT, to fine-tune a model using OFT or BOFT,
4. Train the `PeftModel` as you normally would train the base model.
### OFT-specific parameters
`OFTConfig` allows you to control how OFT is applied to the base model through the following parameters:
- `r`: OFT rank, the number of OFT blocks per injected layer. A **bigger** `r` results in sparser update matrices with **fewer** trainable parameters. **Note**: You can only specify either `r` or `oft_block_size`, but not both simultaneously, because `r` × `oft_block_size` = layer dimension. For simplicity, specify one of the two and the other is inferred (see the configuration sketch after this list). The default is `r = 0`; for better clarity, you are advised to set `oft_block_size` instead.
- `oft_block_size`: OFT block size across different layers. A **bigger** `oft_block_size` results in denser update matrices with **more** trainable parameters. **Note**: Please choose `oft_block_size` so that it divides the layer's input dimension (`in_features`), e.g., 4, 8, 16. You can only specify either `r` or `oft_block_size`, but not both simultaneously, because `r` × `oft_block_size` = layer dimension. The default is `oft_block_size = 32`.
- `use_cayley_neumann`: Specifies whether to use the Cayley-Neumann parameterization (efficient but approximate) or the vanilla Cayley parameterization (exact but computationally expensive because of the matrix inverse). We recommend setting it to `True` for better efficiency, but performance may be slightly worse because of the approximation error. Please test both settings (`True` and `False`) depending on your needs. Default is `False`.
- `module_dropout`: The multiplicative dropout probability: OFT blocks are randomly set to the identity during training, analogous to the dropout layer in LoRA.
- `bias`: Specifies if the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"oft_only"`.
- `target_modules`: The modules (for example, attention blocks) to inject the OFT matrices into.
- `modules_to_save`: List of modules apart from the OFT matrices to be set as trainable and saved in the final checkpoint. These typically include the model's custom head that is randomly initialized for the fine-tuning task.
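A minimal configuration sketch, assuming a causal language model whose attention projections are named `q_proj` and `v_proj` (the model id and target modules are placeholders; adjust them to your model):

```py
from transformers import AutoModelForCausalLM
from peft import OFTConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("model_name")  # placeholder model id

# Specify either `oft_block_size` (recommended) or `r`, never both;
# the other value is inferred from the layer dimension.
config = OFTConfig(
    oft_block_size=32,
    use_cayley_neumann=True,              # faster, but orthogonality is approximate
    module_dropout=0.0,
    target_modules=["q_proj", "v_proj"],  # placeholder target modules
    bias="none",
    task_type="CAUSAL_LM",
)

peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```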
### BOFT-specific parameters
`BOFTConfig` allows you to control how OFT/BOFT is applied to the base model through the following parameters:
`BOFTConfig` allows you to control how BOFT is applied to the base model through the following parameters:
- `boft_block_size`: the BOFT matrix block size across different layers, expressed in `int`. Smaller block size results in sparser update matrices with fewer trainable parameters. **Note**, please choose `boft_block_size` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
- `boft_block_size`: the BOFT matrix block size across different layers, expressed in `int`. A **bigger** `boft_block_size` results in denser update matrices with **more** trainable parameters. **Note**: please choose `boft_block_size` so that it divides most layers' input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
specify either `boft_block_size` or `boft_block_num`, but not both simultaneously, and do not leave both at 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed in `int`. Fewer blocks result in sparser update matrices with fewer trainable parameters. **Note**, please choose `boft_block_num` to be divisible by most layer's input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
- `boft_block_num`: the number of BOFT matrix blocks across different layers, expressed in `int`. A **bigger** `boft_block_num` results in sparser update matrices with **fewer** trainable parameters. **Note**: please choose `boft_block_num` so that it divides most layers' input dimension (`in_features`), e.g., 4, 8, 16. Also, please only
specify either `boft_block_size` or `boft_block_num`, but not both simultaneously, and do not leave both at 0, because `boft_block_size` x `boft_block_num` must equal the layer's input dimension.
- `boft_n_butterfly_factor`: the number of butterfly factors. **Note**: for `boft_n_butterfly_factor=1`, BOFT is the same as vanilla OFT; for `boft_n_butterfly_factor=2`, the effective block size of OFT becomes twice as big and the number of blocks becomes half.
- `bias`: specify if the `bias` parameters should be trained. Can be `"none"`, `"all"` or `"boft_only"`.
@ -74,6 +86,52 @@ specify either `boft_block_size` or `boft_block_num`, but not both simultaneousl
## OFT Example Usage
To use OFT for quantized fine-tuning with [TRL](https://github.com/huggingface/trl) (for `SFT`, `PPO`, or `DPO` fine-tuning), follow the outline below:
```py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTTrainer
from peft import OFTConfig

use_quantization = True  # set to False to fine-tune without 4-bit quantization

bnb_config = None
if use_quantization:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_storage=torch.bfloat16,
    )

model = AutoModelForCausalLM.from_pretrained(
    "model_name",  # replace with the model you want to fine-tune
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained("model_name")

# Configure OFT
peft_config = OFTConfig(
    oft_block_size=32,
    use_cayley_neumann=True,
    target_modules="all-linear",
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model=model,
    train_dataset=ds["train"],    # your training dataset
    peft_config=peft_config,
    tokenizer=tokenizer,
    args=training_arguments,      # your training arguments
    data_collator=collator,       # your data collator
)
trainer.train()
```
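After training, the adapter can be saved and loaded back with the usual PEFT workflow. A short sketch, reusing the placeholder names from above (note that checkpoints produced by this version of OFT cannot be loaded with PEFT 0.15.2 or older):

```py
from transformers import AutoModelForCausalLM
from peft import PeftModel

trainer.save_model("oft-adapter")  # saves only the OFT adapter weights

base_model = AutoModelForCausalLM.from_pretrained("model_name")
model = PeftModel.from_pretrained(base_model, "oft-adapter")

# Optionally fold the orthogonal rotation into the base weights for inference
model = model.merge_and_unload()
```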
## BOFT Example Usage
For examples of applying the BOFT method to various downstream tasks, please refer to the following guides:


@ -12,13 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available
from peft.utils import register_peft_method
from .config import OFTConfig
from .gptq import GPTQOFTLinear
from .layer import Conv2d, Linear, OFTLayer
from .model import OFTModel
__all__ = ["Conv2d", "Linear", "OFTConfig", "OFTLayer", "OFTModel"]
__all__ = [
"Conv2d",
"GPTQOFTLinear",
"Linear",
"OFTConfig",
"OFTLayer",
"OFTModel",
]
register_peft_method(name="oft", config_cls=OFTConfig, model_cls=OFTModel)
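# The quantization-backend layers below are resolved lazily so that optional
# dependencies such as bitsandbytes or eetq are only imported when the
# corresponding attribute is actually requested.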
def __getattr__(name):
if (name == "Linear8bitLt") and is_bnb_available():
from .bnb import Linear8bitLt
return Linear8bitLt
if (name == "Linear4bit") and is_bnb_4bit_available():
from .bnb import Linear4bit
return Linear4bit
if (name == "EetqOFTLinear") and is_eetq_available():
from .eetq import EetqOFTLinear
return EetqOFTLinear
raise AttributeError(f"module {__name__} has no attribute {name}")

src/peft/tuners/oft/aqlm.py (new file)

@ -0,0 +1,105 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
from peft.import_utils import is_aqlm_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer
if is_aqlm_available():
from aqlm import QuantizedLinear
class AqlmOFTLinear(torch.nn.Module, OFTLayer):
def __init__(
self,
base_layer,
adapter_name: str,
r: int = 0,
oft_block_size: int = 32,
module_dropout: float = 0.0,
init_weights: bool = True,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
**kwargs,
):
super().__init__()
OFTLayer.__init__(self, base_layer)
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
init_weights=init_weights,
coft=coft,
eps=eps,
block_share=block_share,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
if self.disable_adapters:
return self.base_layer(x)
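# Rotate the input activations with the (approximately) orthogonal OFT module,
# then pass them through the frozen quantized layer; this is equivalent to
# transforming the frozen weight along its input dimension.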
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_R.keys():
continue
oft_R = self.oft_R[active_adapter]
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = x.dtype
x = self._cast_input_dtype(x, oft_R.weight.dtype)
x = oft_R(x)
result = self.base_layer(x)
if requires_conversion:
result = result.to(expected_dtype)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
def dispatch_aqlm(
target: torch.nn.Module,
adapter_name: str,
**kwargs: Any,
) -> Optional[torch.nn.Module]:
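# Dispatch helpers like this one return a wrapped OFT module only when the
# unwrapped base layer matches their quantization backend; otherwise they
# return None so that the next dispatcher can be tried.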
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if is_aqlm_available() and isinstance(target_base_layer, QuantizedLinear):
new_module = AqlmOFTLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.codes
return new_module

src/peft/tuners/oft/awq.py (new file)

@ -0,0 +1,119 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.metadata as importlib_metadata
from typing import Any, Optional
import packaging.version
import torch
from peft.import_utils import is_auto_awq_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer
class AwqOFTLinear(torch.nn.Module, OFTLayer):
def __init__(
self,
base_layer,
adapter_name,
r: int = 0,
oft_block_size: int = 32,
module_dropout: float = 0.0,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
init_weights: bool = True,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
**kwargs,
):
super().__init__()
OFTLayer.__init__(self, base_layer)
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
coft=coft,
eps=eps,
block_share=block_share,
init_weights=init_weights,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
def forward(self, x: torch.Tensor):
if self.disable_adapters:
result = self.quant_linear_module(x)
return result
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_R.keys():
continue
oft_R = self.oft_R[active_adapter]
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = x.dtype
x = self._cast_input_dtype(x, oft_R.weight.dtype)
x = oft_R(x)
if requires_conversion:
x = x.to(expected_dtype)
result = self.quant_linear_module(x)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
def dispatch_awq(
target: torch.nn.Module,
adapter_name: str,
**kwargs: Any,
) -> Optional[torch.nn.Module]:
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if is_auto_awq_available():
from awq.modules.linear import WQLinear_GEMM
if isinstance(target_base_layer, WQLinear_GEMM):
# Raise the error only at the dispatch level
AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))
if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
raise ImportError(
f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT."
)
new_module = AwqOFTLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight
return new_module

src/peft/tuners/oft/bnb.py (new file)

@ -0,0 +1,388 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from typing import Optional
import bitsandbytes as bnb
import torch
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.integrations import dequantize_bnb_weight
from .layer import OFTLayer
if is_bnb_available():
class Linear8bitLt(torch.nn.Module, OFTLayer):
# OFT implemented in a dense layer
def __init__(
self,
base_layer: torch.nn.Module,
adapter_name: str,
r: int = 8,
oft_block_size: int = 0,
module_dropout: float = 0.0,
init_weights: bool = True,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
**kwargs,
) -> None:
super().__init__()
OFTLayer.__init__(self, base_layer)
self.fan_in_fan_out = False
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
coft=coft,
eps=eps,
block_share=block_share,
init_weights=init_weights,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter not in self.oft_R.keys():
continue
warnings.warn("Merge oft module to 8-bit linear may get different generations due to rounding errors.")
weight = self.get_base_layer().weight
state = self.get_base_layer().state
if state.SCB is None:
state.SCB = weight.SCB
# Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
# dequantization directly
output = dequantize_bnb_weight(weight, state=state)
oft_data = self.get_delta_weight(active_adapter)
output = torch.transpose(output, 0, 1)
w_data = torch.mm(oft_data, output.to(oft_data.dtype))
w_data = torch.transpose(w_data, 0, 1)
w_data = w_data.to(oft_data.dtype).to(oft_data.device)
if safe_merge and not torch.isfinite(w_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
state.reset_grads()
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter not in self.oft_R.keys():
continue
warnings.warn(
"Unmerge oft module to 8-bit linear may get different generations due to rounding errors."
)
weight = self.get_base_layer().weight
state = self.get_base_layer().state
if state.SCB is None:
state.SCB = weight.SCB
output = dequantize_bnb_weight(weight, state=state)
oft_data = self.get_delta_weight(active_adapter)
output = torch.transpose(output, 0, 1)
w_data = torch.mm(oft_data.t(), output.to(oft_data.dtype))
w_data = torch.transpose(w_data, 0, 1)
w_data = w_data.to(oft_data.dtype).to(oft_data.device)
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
state.reset_grads()
def get_delta_weight(self, adapter):
return self.oft_R[adapter].get_weight()
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_R.keys():
continue
oft_R = self.oft_R[active_adapter]
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = x.dtype
x = self._cast_input_dtype(x, oft_R.weight.dtype)
x = oft_R(x)
if requires_conversion:
x = x.to(expected_dtype)
result = self.base_layer(x, *args, **kwargs)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, **kwargs):
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
loaded_in_8bit = kwargs.get("loaded_in_8bit", False)
if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
eightbit_kwargs = kwargs.copy()
eightbit_kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
"threshold": target.state.threshold,
"index": target.index,
}
)
new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
return new_module
if is_bnb_4bit_available():
class Linear4bit(torch.nn.Module, OFTLayer):
# OFT implemented in a dense layer
def __init__(
self,
base_layer: torch.nn.Module,
adapter_name: str,
r: int = 8,
oft_block_size: int = 0,
module_dropout: float = 0.0,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
init_weights: bool = True,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
**kwargs,
) -> None:
super().__init__()
OFTLayer.__init__(self, base_layer)
self.fan_in_fan_out = False
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
coft=coft,
eps=eps,
block_share=block_share,
init_weights=init_weights,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter not in self.oft_R.keys():
continue
warnings.warn("Merge oft module to 4-bit linear may get different generations due to rounding errors.")
# Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
weight = self.get_base_layer().weight
kwargs = weight.__dict__
output = dequantize_bnb_weight(weight, state=weight.quant_state)
oft_data = self.get_delta_weight(active_adapter)
output = torch.transpose(output, 0, 1)
w_data = torch.mm(oft_data, output.to(oft_data.dtype))
w_data = torch.transpose(w_data, 0, 1)
w_data = w_data.to(oft_data.dtype).to(oft_data.device)
if safe_merge and not torch.isfinite(w_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
if "bnb_quantized" in kwargs:
kwargs["bnb_quantized"] = False
kwargs["requires_grad"] = False
kwargs.pop("data", None)
# torch.compile can introduce attributes preceded by '_', remove them
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter not in self.oft_R.keys():
continue
warnings.warn(
"Unmerge oft module to 4-bit linear may get different generations due to rounding errors."
)
weight = self.get_base_layer().weight
kwargs = weight.__dict__
output = dequantize_bnb_weight(weight, state=weight.quant_state)
oft_data = self.get_delta_weight(active_adapter)
output = torch.transpose(output, 0, 1)
w_data = torch.mm(oft_data.t(), output.to(oft_data.dtype))
w_data = torch.transpose(w_data, 0, 1)
w_data = w_data.to(oft_data.dtype).to(oft_data.device)
if "bnb_quantized" in kwargs:
kwargs["bnb_quantized"] = False
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
def get_delta_weight(self, adapter):
return self.oft_R[adapter].get_weight()
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
# The reason is that in some cases, an error can occur that backprop
# does not work on a manipulated view. This issue may be solved with
# newer PyTorch versions but this would need extensive testing to be
# sure.
# result = result.clone()
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_R.keys():
continue
oft_R = self.oft_R[active_adapter]
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = x.dtype
x = self._cast_input_dtype(x, oft_R.weight.dtype)
x = oft_R(x)
if requires_conversion:
x = x.to(expected_dtype)
result = self.base_layer(x, *args, **kwargs)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
def dispatch_bnb_4bit(target: torch.nn.Module, adapter_name: str, **kwargs):
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
loaded_in_4bit = kwargs.get("loaded_in_4bit", False)
if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{
"compute_dtype": target_base_layer.compute_dtype,
"compress_statistics": target_base_layer.weight.compress_statistics,
"quant_type": target_base_layer.weight.quant_type,
}
)
new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
return new_module


@ -67,9 +67,9 @@ class OFTConfig(PeftConfig):
Whether to share the OFT parameters between blocks or not. This is `False` by default.
"""
r: int = field(default=8, metadata={"help": "OFT rank, number of OFT blocks per injected layer."})
r: int = field(default=0, metadata={"help": "OFT rank, number of OFT blocks per injected layer."})
oft_block_size: int = field(
default=0,
default=32,
metadata={
"help": "OFT block size across different layers.",
"note": "You can only specify either r or oft_block_size, but not both simultaneously, because r x oft_block_size = layer dimension.",
@ -144,6 +144,18 @@ class OFTConfig(PeftConfig):
default=False,
metadata={"help": "Whether to share the OFT parameters between blocks or not."},
)
use_cayley_neumann: bool = field(
default=True,
metadata={
"help": "Whether to use the Cayley-Neumann Formulation of OFT or not. Set to True to improve computational efficiency but comes at costs of bigger approximation error for orthogonality."
},
)
num_cayley_neumann_terms: int = field(
default=5,
metadata={
"help": "Number of Cayley-Neumann terms to use. Higher number results in less approximation error for orthogonality."
},
)
def __post_init__(self):
super().__post_init__()

src/peft/tuners/oft/eetq.py (new file)

@ -0,0 +1,116 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
from peft.import_utils import is_eetq_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer
if is_eetq_available():
from eetq import EetqLinear
class EetqOFTLinear(torch.nn.Module, OFTLayer):
def __init__(
self,
base_layer,
adapter_name,
r: int = 0,
oft_block_size: int = 0,
module_dropout: float = 0.0,
init_weights: bool = True,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
fan_in_fan_out: bool = False,
**kwargs,
):
super().__init__()
OFTLayer.__init__(self, base_layer)
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
init_weights=init_weights,
coft=coft,
eps=eps,
block_share=block_share,
fan_in_fan_out=fan_in_fan_out,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
def forward(self, x: torch.Tensor):
if self.disable_adapters:
return self.quant_linear_module(x)
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_R.keys():
continue
oft_R = self.oft_R[active_adapter]
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = x.dtype
x = self._cast_input_dtype(x, oft_R.weight.dtype)
x = oft_R(x)
result = self.quant_linear_module(x)
if requires_conversion:
result = result.to(expected_dtype)
return result
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
raise AttributeError("Merging LoRA layers is not supported for Eetq layers.")
def unmerge(self) -> None:
raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.")
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
def dispatch_eetq(
target: torch.nn.Module,
adapter_name: str,
**kwargs: Any,
) -> Optional[torch.nn.Module]:
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if is_eetq_available() and isinstance(target_base_layer, EetqLinear):
new_module = EetqOFTLinear(target, adapter_name, **kwargs)
target.weight = target_base_layer.weight
if hasattr(target, "bias"):
target.bias = target_base_layer.bias
return new_module

src/peft/tuners/oft/gptq.py (new file)

@ -0,0 +1,118 @@
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
from peft.import_utils import is_gptqmodel_available
from peft.tuners.oft.layer import OFTLayer
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import get_auto_gptq_quant_linear
class GPTQOFTLinear(torch.nn.Module, OFTLayer):
def __init__(
self,
base_layer,
adapter_name: str,
r: int = 8,
oft_block_size: int = 0,
module_dropout: float = 0.0,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
init_weights: bool = True,
**kwargs,
):
super().__init__()
OFTLayer.__init__(self, base_layer)
# self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
# for backwards compatibility
self.quant_linear_module = base_layer
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
coft=coft,
eps=eps,
block_share=block_share,
init_weights=init_weights,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
def forward(self, x: torch.Tensor):
# note: logic differs from default Linear because merging is not supported
if self.disable_adapters:
return self.quant_linear_module(x)
# fallback output in case no active adapter applies to this layer
result = self.quant_linear_module(x)
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_R.keys():
continue
oft_R = self.oft_R[active_adapter]
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = x.dtype
x = self._cast_input_dtype(x, oft_R.weight.dtype)
x = oft_R(x)
result = self.quant_linear_module(x)
if requires_conversion:
result = result.to(expected_dtype)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
def dispatch_gptq(
target: torch.nn.Module,
adapter_name: str,
**kwargs: Any,
) -> Optional[torch.nn.Module]:
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
cfg = kwargs.get("gptq_quantization_config", None)
if is_gptqmodel_available():
from gptqmodel.nn_modules.qlinear import BaseQuantLinear
if isinstance(target_base_layer, BaseQuantLinear):
new_module = GPTQOFTLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight
else:
quant_linear = get_auto_gptq_quant_linear(cfg)
if quant_linear is not None and isinstance(target_base_layer, quant_linear):
new_module = GPTQOFTLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight
return new_module

src/peft/tuners/oft/hqq.py (new file)

@ -0,0 +1,186 @@
# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import warnings
from typing import Optional
import torch
from peft.import_utils import is_hqq_available
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from .layer import OFTLayer
if is_hqq_available():
from hqq.core.quantize import HQQLinear
class HqqOFTLinear(torch.nn.Module, OFTLayer):
# OFT implemented in a dense layer
def __init__(
self,
base_layer: torch.nn.Module,
adapter_name: str,
r: int = 8,
oft_block_size: int = 0,
module_dropout: float = 0.0,
init_weights: bool = True,
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
**kwargs,
) -> None:
super().__init__()
OFTLayer.__init__(self, base_layer)
self.fan_in_fan_out = False
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
init_weights=init_weights,
coft=coft,
eps=eps,
block_share=block_share,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter not in self.oft_R.keys():
continue
layer = self.get_base_layer()
quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
output = layer.dequantize()
oft_data = self.get_delta_weight(active_adapter)
output = torch.transpose(output, 0, 1)
w_data = torch.mm(oft_data, output.to(oft_data.dtype))
w_data = torch.transpose(w_data, 0, 1)
w_data = w_data.to(oft_data.dtype).to(oft_data.device)
if safe_merge and not torch.isfinite(w_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
quant_config.pop("offload_meta", None)
new_hqq_layer.quantize(w_data, **quant_config)
self.base_layer = new_hqq_layer
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter not in self.oft_R.keys():
continue
layer = self.get_base_layer()
quant_config = {**copy.deepcopy(layer.quant_config), "offload_meta": layer.offload_meta}
output = layer.dequantize()
oft_data = self.get_delta_weight(active_adapter)
output = torch.transpose(output, 0, 1)
w_data = torch.mm(oft_data.t(), output.to(oft_data.dtype))
w_data = torch.transpose(w_data, 0, 1)
w_data = w_data.to(oft_data.dtype).to(oft_data.device)
new_hqq_layer = HQQLinear(None, quant_config, compute_dtype=layer.compute_dtype, device=layer.device)
quant_config.pop("offload_meta", None)
new_hqq_layer.quantize(w_data, **quant_config)
self.base_layer = new_hqq_layer
def get_delta_weight(self, adapter):
return self.oft_R[adapter].get_weight()
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_R.keys():
continue
oft_R = self.oft_R[active_adapter]
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = x.dtype
x = self._cast_input_dtype(x, oft_R.weight.dtype)
x = oft_R(x)
result = self.base_layer(x, *args, **kwargs)
if requires_conversion:
result = result.to(expected_dtype)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
def dispatch_hqq(target: torch.nn.Module, adapter_name: str, **kwargs):
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if is_hqq_available() and isinstance(target_base_layer, HQQLinear):
new_module = HqqOFTLinear(target_base_layer, adapter_name, **kwargs)
return new_module


@ -0,0 +1,78 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: PEFT tests related to INC are handled under Optimum-Habana repository:
# - LLMs: https://github.com/huggingface/optimum-habana/blob/main/tests/test_peft_inference.py
# - Diffusers: https://github.com/huggingface/optimum-habana/blob/main/tests/test_diffusers.py
from typing import Optional
import torch
from peft.import_utils import is_inc_available
from peft.tuners.tuners_utils import BaseTunerLayer
from .layer import Linear
if is_inc_available():
class IncOFTLinear(Linear):
def __init__(
self,
base_layer: torch.nn.Module,
adapter_name: str,
**kwargs,
):
super().__init__(base_layer, adapter_name, **kwargs)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged.
Defaults to `None`.
"""
raise NotImplementedError("Merging OFT with INC layers is not yet implemented")
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
raise NotImplementedError("Unmerging OFT from INC layers is not yet implemented")
def dispatch_inc(target: torch.nn.Module, adapter_name: str, **kwargs):
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if is_inc_available():
from neural_compressor.torch.algorithms.fp8_quant._quant_common.helper_modules import (
PatchedLinear,
)
if isinstance(target_base_layer, PatchedLinear):
new_module = IncOFTLinear(target, adapter_name, **kwargs)
return new_module


@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import math
import warnings
from typing import Any, Optional, Union
@ -23,6 +22,8 @@ import torch.nn.functional as F
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from .config import OFTConfig
class MultiplicativeDropoutLayer(nn.Module):
"""
@ -48,7 +49,7 @@ class MultiplicativeDropoutLayer(nn.Module):
the number of OFT blocks, and `H` is the size of the square blocks along the last two dimensions,
the block size in OFT.
"""
if self.training:
if self.training and self.p > 0:
# Ensure the last two dimensions are the same
if x.shape[-1] != x.shape[-2]:
raise ValueError("The last two dimensions of input should be the same!")
@ -68,15 +69,249 @@ class MultiplicativeDropoutLayer(nn.Module):
return x
class OFTRotationModule(nn.Module):
def __init__(
self,
r,
n_elements,
block_size,
in_features,
coft=False,
eps=6e-5,
block_share=False,
kernel_size=(0, 0),
use_cayley_neumann=True,
num_cayley_neumann_terms=5,
):
super().__init__()
self.r = r
self.n_elements = n_elements
self.block_size = block_size
self.in_features = in_features
self.weight = nn.Parameter(torch.empty(r, n_elements))
self.coft = coft
self.eps = eps
self.block_share = block_share
# Conv2d specific parameters
self.kernel_size = kernel_size
self.use_cayley_neumann = use_cayley_neumann
self.num_cayley_neumann_terms = num_cayley_neumann_terms
# Create indices for upper triangle (excluding diagonal)
self.rows, self.cols = torch.triu_indices(block_size, block_size, 1)
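# Each OFT block stores only the block_size * (block_size - 1) / 2 entries above the
# diagonal of a skew-symmetric matrix; the full matrix (and its orthogonal Cayley
# image) is rebuilt on the fly, which is where the memory savings come from.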
def _pytorch_skew_symmetric(self, vec, block_size):
batch_size = vec.shape[0]
matrix = torch.zeros(batch_size, block_size, block_size, device=vec.device, dtype=vec.dtype)
matrix[:, self.rows, self.cols] = vec
matrix = matrix - matrix.transpose(-2, -1)
return matrix
def _pytorch_skew_symmetric_inv(self, matrix, block_size):
batch_size = matrix.shape[0]
# Extract the upper triangular elements
vec = matrix[:, self.rows, self.cols]
return vec
def _cayley_batch(
self, Q: torch.Tensor, block_size: int, use_cayley_neumann: bool = True, num_neumann_terms: int = 5
) -> torch.Tensor:
"""
Map a batch of skew-symmetric parameter vectors to (approximately) orthogonal matrices via the Cayley transform.
Args:
Q: A batch of upper-triangular parameter vectors of shape (b, n_elements), one vector per OFT block.
"""
b, _ = Q.shape
previous_dtype = Q.dtype
# Q_skew = SkewSymmetric.apply(Q, block_size)
Q_skew = self._pytorch_skew_symmetric(Q, block_size)
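# Cayley-Neumann branch: since (I - Q)^{-1} = I + Q + Q^2 + ..., the Cayley transform
# (I + Q)(I - Q)^{-1} can be approximated by I + 2Q + 2Q^2 + ... truncated after
# num_neumann_terms terms, trading exact orthogonality for avoiding the matrix solve
# used in the exact branch below.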
if use_cayley_neumann:
R = torch.eye(block_size, device=Q.device, dtype=Q.dtype).repeat(b, 1, 1)
if num_neumann_terms > 1:
R.add_(Q_skew, alpha=2.0)
if num_neumann_terms > 2:
Q_squared = torch.bmm(Q_skew, Q_skew)
R.add_(Q_squared, alpha=2.0)
Q_power = Q_squared
for i in range(3, num_neumann_terms):
Q_power = torch.bmm(Q_power, Q_skew)
R.add_(Q_power, alpha=2.0)
else:
id_mat = (
torch.eye(Q_skew.shape[-1], device=Q_skew.device)
.unsqueeze(0)
.expand(b, Q_skew.shape[-1], Q_skew.shape[-1])
)
R = torch.linalg.solve(id_mat + Q_skew, id_mat - Q_skew, left=False)
return R.to(previous_dtype)
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52
def _project_batch(self, Q, eps=1e-5):
oft_R = self._pytorch_skew_symmetric(Q, self.block_size)
# scaling factor for each of the smaller block matrix
eps = eps * 1 / torch.sqrt(torch.tensor(oft_R.shape[0]))
I = ( # noqa: E741
torch.zeros((oft_R.size(1), oft_R.size(1)), device=oft_R.device, dtype=oft_R.dtype)
.unsqueeze(0)
.expand_as(oft_R)
)
diff = oft_R - I
norm_diff = torch.norm(oft_R - I, dim=(1, 2), keepdim=True)
mask = (norm_diff <= eps).bool()
out = torch.where(mask, oft_R, I + eps * (diff / norm_diff))
return self._pytorch_skew_symmetric_inv(out, self.block_size)
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155
def _block_diagonal(self, oft_R: torch.Tensor, rank: int) -> torch.Tensor:
if oft_R.shape[0] == 1:
# block share
blocks = [oft_R[0, ...] for i in range(rank)]
else:
blocks = [oft_R[i, ...] for i in range(rank)]
# Use torch.block_diag to create the block diagonal matrix
A = torch.block_diag(*blocks)
return A
def _unfold(self, x):
"""
Unfold with stride=1, padding=0 to preserve spatial dimensions. Only use kernel_size from base layer to define
patch size.
"""
batch_size, in_channels, in_height, in_width = x.shape
if isinstance(self.kernel_size, int):
kernel_height, kernel_width = self.kernel_size, self.kernel_size
else:
kernel_height, kernel_width = self.kernel_size
stride_h = stride_w = 1
pad_h = pad_w = 0
# output dimensions
out_height = (in_height + 2 * pad_h - kernel_height) // stride_h + 1
out_width = (in_width + 2 * pad_w - kernel_width) // stride_w + 1
# Reshape input from [B, C, H, W] to [B, C, H_out, W_out, K_H, K_W]
x_unfolded = x.unfold(2, kernel_height, stride_h).unfold(3, kernel_width, stride_w)
x_unfolded = x_unfolded.permute(0, 2, 3, 1, 4, 5).contiguous()
x_unfolded = x_unfolded.view(batch_size * out_height * out_width, -1)
return x_unfolded
def _fold(self, x_unfolded, orig_shape):
"""
Fold back to preserve spatial dimensions.
"""
batch_size, in_channels, in_height, in_width = orig_shape
if isinstance(self.kernel_size, int):
kernel_height, kernel_width = self.kernel_size, self.kernel_size
else:
kernel_height, kernel_width = self.kernel_size
# With stride=1, padding=0:
out_height = in_height - kernel_height + 1
out_width = in_width - kernel_width + 1
# Reshape: [B*H_out*W_out, C*K_H*K_W] -> [B, H_out, W_out, C, K_H, K_W]
x_reshaped = x_unfolded.view(batch_size, out_height, out_width, in_channels, kernel_height, kernel_width)
# Permute to: [B, C, H_out, W_out, K_H, K_W]
x_reshaped = x_reshaped.permute(0, 3, 1, 2, 4, 5).contiguous()
# Use F.fold to reconstruct 4D tensor
x_folded = F.fold(
x_reshaped.view(batch_size, in_channels * kernel_height * kernel_width, out_height * out_width),
output_size=(in_height, in_width),
kernel_size=(kernel_height, kernel_width),
stride=(1, 1),
)
return x_folded
def forward(self, x):
# Build the (approximately) orthogonal rotation from the skew-symmetric parameters
# and apply it block-wise to the input activations.
required_dtype = x.dtype
if required_dtype != self.weight.dtype:
x = x.to(self.weight.dtype)
orig_shape = x.shape
if self.coft:
with torch.no_grad():
self.weight.copy_(self._project_batch(self.weight, eps=self.eps))
orth_rotate = self._cayley_batch(
self.weight, self.block_size, self.use_cayley_neumann, self.num_cayley_neumann_terms
)
# Unfold the input for Conv2d layer
if len(orig_shape) == 4:
x = self._unfold(x)
folded_shape = x.shape
rank = self.in_features // self.block_size if self.block_share else self.r
batch_dims = x.shape[:-1]
x_reshaped = x.reshape(*batch_dims, rank, self.block_size)
if self.block_share:
orth_rotate = orth_rotate.repeat(rank, 1, 1)
x_rotated_reshaped = torch.einsum("...rk,rkc->...rc", x_reshaped, orth_rotate)
x_rotated = x_rotated_reshaped.reshape(*folded_shape)
if len(orig_shape) == 4:
x_rotated = self._fold(x_rotated, orig_shape)
return x_rotated.to(required_dtype)
def get_weight(self):
"""
Materialize the full block-diagonal (approximately) orthogonal rotation matrix from the current
skew-symmetric parameters of this module.
"""
weight = self.weight
if self.coft:
with torch.no_grad():
weight = self._project_batch(weight, eps=self.eps)
self.weight.copy_(weight)
orth_rotate = self._cayley_batch(
weight, self.block_size, self.use_cayley_neumann, self.num_cayley_neumann_terms
)
rank = self.r if not self.block_share else self.in_features // self.block_size
return self._block_diagonal(orth_rotate, rank)
class OFTLayer(BaseTunerLayer):
"""
Implements the OFT layer.
"""
# All names of layers that may contain adapter weights
adapter_layer_names = ("oft_r", "oft_s")
# other_param_names is defined on parent class
other_param_names = ("r", "oft_block_size", "oft_dropout")
# All names of layers that may contain (trainable) adapter weights
adapter_layer_names: tuple[str, ...] = ("oft_R",)
# All names of other parameters that may contain adapter-related parameters
other_param_names: tuple[str, ...] = ("r", "oft_block_size", "oft_dropout")
def __init__(self, base_layer: nn.Module, **kwargs) -> None:
"""
@ -89,15 +324,11 @@ class OFTLayer(BaseTunerLayer):
base_layer: the pretrained model layer
"""
self.base_layer = base_layer
# OFT info
self.oft_r = nn.ParameterDict({})
self.oft_s = nn.ParameterDict({})
self.oft_R = nn.ModuleDict({})
self.oft_block_size = {}
self.r = {}
self.oft_block_size = {}
self.oft_dropout = nn.ModuleDict({})
self.coft = {}
self.eps = {}
self.block_share = {}
# Mark the weight as unmerged
self._disable_adapters = False
self.merged_adapters = []
@ -106,20 +337,44 @@ class OFTLayer(BaseTunerLayer):
self.kwargs = kwargs
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
elif isinstance(base_layer, nn.Conv2d):
in_features, out_features = base_layer.in_channels, base_layer.out_channels
elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
# QuantLinear
in_features, out_features = base_layer.infeatures, base_layer.outfeatures
elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
# Megatron ColumnParallelLinear,RowParallelLinear
in_features, out_features = base_layer.input_size, base_layer.output_size
elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear":
# AQLM QuantLinear
in_features, out_features = base_layer.in_features, base_layer.out_features
elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
# Awq layers
in_features, out_features = base_layer.in_features, base_layer.out_features
elif base_layer.__class__.__name__ == "EetqLinear":
# Eetq layers
in_features, out_features = base_layer.in_features, base_layer.out_features
elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
# HQQ layers
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
raise ValueError(f"Unsupported layer type {type(base_layer)}")
# possibly support user provided custom layer types using dynamic dispatch
if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
in_features, out_features = None, None
warnings.warn(
f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning
)
self.in_features = in_features
self.out_features = out_features
@property
def _available_adapters(self) -> set[str]:
return {*self.oft_r}
return {*self.oft_R}
def set_scale(self, adapter, scale):
if adapter not in self.scaling:
@ -133,19 +388,31 @@ class OFTLayer(BaseTunerLayer):
return
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_r.keys():
if active_adapter not in self.oft_R.keys():
continue
warnings.warn("Scaling operation for OFT not supported! Automatically set scale to 1.")
def unscale_layer(self, scale=None) -> None:
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_r.keys():
if active_adapter not in self.oft_R.keys():
continue
warnings.warn("Unscaling operation for OFT not supported! Keeping scale to 1.")
def update_layer(self, adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights):
def update_layer(
self,
adapter_name,
r,
oft_block_size,
module_dropout,
coft,
eps,
block_share,
init_weights,
use_cayley_neumann,
num_cayley_neumann_terms,
):
"""
Update the linear layer with trainable OFT weights. Override for other layer types.
"""
@ -189,20 +456,19 @@ class OFTLayer(BaseTunerLayer):
"Something went wrong, please report this error: https://github.com/huggingface/peft/issues"
)
self.coft[adapter_name] = coft
self.block_share[adapter_name] = block_share
self.eps[adapter_name] = eps * math.ceil(self.out_features / r) * math.ceil(self.out_features / r)
# Create weights with provided shape
if block_share:
self.oft_r[adapter_name] = nn.Parameter(
torch.empty(1, math.ceil(self.in_features / r), math.ceil(self.in_features / r))
)
else:
self.oft_r[adapter_name] = nn.Parameter(
torch.empty(r, math.ceil(self.in_features / r), math.ceil(self.in_features / r))
)
self.oft_s[adapter_name] = nn.Parameter(torch.empty(int(self.out_features), 1))
n_elements = oft_block_size * (oft_block_size - 1) // 2
self.oft_R[adapter_name] = OFTRotationModule(
r if not block_share else 1,
n_elements,
oft_block_size,
self.in_features,
coft=coft,
eps=eps,
block_share=block_share,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
# Initialize weights
self.reset_oft_parameters(adapter_name, init_weights)
@ -220,63 +486,16 @@ class OFTLayer(BaseTunerLayer):
Reset the OFT parameters.
"""
if init_weights is False:
nn.init.normal_(self.oft_r[adapter_name], mean=0.0, std=0.1)
nn.init.normal_(self.oft_s[adapter_name], mean=1.0, std=0.1)
nn.init.normal_(self.oft_R[adapter_name].weight, mean=0.0, std=0.1)
return
if adapter_name in self.oft_r.keys():
if adapter_name in self.oft_R.keys():
if init_weights is True:
# initialize oft_r to zero
nn.init.zeros_(self.oft_r[adapter_name])
nn.init.ones_(self.oft_s[adapter_name])
# initialize oft_R to zero
nn.init.zeros_(self.oft_R[adapter_name].weight)
else:
raise ValueError(f"Unknown initialization {init_weights=}")
def _cayley_batch(self, data: torch.Tensor) -> torch.Tensor:
"""
Perform the Cayley parametrization on a batch of skew-symmetric matrices.
Args:
data: A batch of skew-symmetric matrices of shape (b, r, c).
"""
b, r, c = data.shape
# Ensure the input matrix is skew-symmetric
skew_mat = 0.5 * (data - data.transpose(1, 2))
id_mat = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) # noqa: E741
# Perform the Cayley parametrization
Q = torch.linalg.solve(id_mat + skew_mat, id_mat - skew_mat, left=False)
return Q
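The exact Cayley parametrization above costs a linear solve per block. The new `use_cayley_neumann` / `num_cayley_neumann_terms` options indicate that this solve can be traded for a truncated series; the sketch below illustrates the idea with the expansion (I + S)^-1 ≈ I - S + S² - ..., which is an assumption about the approximation rather than the exact formulation inside `OFTRotationModule`:

```py
import torch

def cayley_exact(S: torch.Tensor) -> torch.Tensor:
    # Q = (I + S)^-1 (I - S) for skew-symmetric S, as in _cayley_batch
    I = torch.eye(S.shape[-1], dtype=S.dtype)
    return torch.linalg.solve(I + S, I - S)

def cayley_neumann(S: torch.Tensor, num_terms: int = 5) -> torch.Tensor:
    # replace the inverse with the truncated Neumann series I - S + S^2 - S^3 + ...
    I = torch.eye(S.shape[-1], dtype=S.dtype)
    inv_approx = I.clone()
    power = I.clone()
    for _ in range(1, num_terms):
        power = power @ (-S)
        inv_approx = inv_approx + power
    return inv_approx @ (I - S)

A = 0.01 * torch.randn(8, 8)
S = 0.5 * (A - A.T)  # small skew-symmetric generator
print((cayley_exact(S) - cayley_neumann(S)).abs().max())  # tiny while ||S|| stays small
```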
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155
def _block_diagonal(self, oft_r: torch.Tensor, rank: int) -> torch.Tensor:
if oft_r.shape[0] == 1:
# block share
blocks = [oft_r[0, ...] for i in range(rank)]
else:
blocks = [oft_r[i, ...] for i in range(rank)]
# Use torch.block_diag to create the block diagonal matrix
A = torch.block_diag(*blocks)
return A
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52
def _project_batch(self, oft_r, eps=1e-5):
# scaling factor for each of the smaller block matrices
eps = eps * 1 / torch.sqrt(torch.tensor(oft_r.shape[0]))
I = ( # noqa: E741
torch.zeros((oft_r.size(1), oft_r.size(1)), device=oft_r.device, dtype=oft_r.dtype)
.unsqueeze(0)
.expand_as(oft_r)
)
diff = oft_r - I
norm_diff = torch.norm(oft_r - I, dim=(1, 2), keepdim=True)
mask = (norm_diff <= eps).bool()
out = torch.where(mask, oft_r, I + eps * (diff / norm_diff))
return out
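Note that the tensor named `I` in `_project_batch` is actually a zero matrix, so the projection simply clamps the Frobenius norm of each skew-symmetric generator to at most ε, keeping the resulting rotation close to the identity (the COFT constraint). A standalone version of the same clamping, for illustration only (the function name is a stand-in):

```py
import torch

def project_to_ball(blocks: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # clamp the Frobenius norm of each block to at most eps (same effect as _project_batch)
    eps = eps / torch.sqrt(torch.tensor(float(blocks.shape[0])))
    norms = torch.norm(blocks, dim=(1, 2), keepdim=True)
    scale = torch.where(norms <= eps, torch.ones_like(norms), eps / norms)
    return blocks * scale

blocks = torch.randn(4, 3, 3)
print(torch.norm(project_to_ball(blocks, eps=0.1), dim=(1, 2)))  # every norm <= 0.1 / sqrt(4)
```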
def adjust_oft_parameters(self, in_features, params):
"""
Adjust the OFT parameters to be divisible by the in_features dimension.
@ -311,6 +530,8 @@ class Linear(nn.Module, OFTLayer):
coft: bool = False,
eps: float = 6e-5,
block_share: bool = False,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
init_weights: Union[bool, str] = True,
is_target_conv_1d_layer: bool = False,
@ -322,7 +543,18 @@ class Linear(nn.Module, OFTLayer):
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights)
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
coft=coft,
eps=eps,
block_share=block_share,
init_weights=init_weights,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
@ -349,13 +581,11 @@ class Linear(nn.Module, OFTLayer):
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weights = base_layer.weight.data
oft_mat, oft_s = self.get_delta_weight(active_adapter)
oft_mat = self.get_delta_weight(active_adapter)
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = orig_weights * oft_s
if not torch.isfinite(orig_weights).all():
raise ValueError(
@ -364,12 +594,11 @@ class Linear(nn.Module, OFTLayer):
base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
else:
oft_mat, oft_s = self.get_delta_weight(active_adapter)
orig_weights = base_layer.weight.data
oft_mat = self.get_delta_weight(active_adapter)
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = orig_weights * oft_s
base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
@ -387,15 +616,15 @@ class Linear(nn.Module, OFTLayer):
orig_dtype = base_layer.weight.dtype
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.oft_r.keys():
oft_mat, oft_s = self.get_delta_weight(active_adapter)
if active_adapter in self.oft_R.keys():
oft_mat = self.get_delta_weight(active_adapter)
orig_weights = self.get_base_layer().weight.data
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)
base_layer.weight.data = (orig_weights * (1 / oft_s)).to(orig_dtype)
base_layer.weight.data = orig_weights.to(orig_dtype)
def get_delta_weight(self, adapter_name) -> torch.Tensor:
"""
@ -405,21 +634,8 @@ class Linear(nn.Module, OFTLayer):
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
oft_r = self.oft_r[adapter_name]
oft_s = self.oft_s[adapter_name]
rank = self.r[adapter_name]
coft = self.coft[adapter_name]
eps = self.eps[adapter_name]
if coft:
with torch.no_grad():
oft_r.copy_(self._project_batch(oft_r, eps=eps))
orth_rotate = self._cayley_batch(oft_r)
weight = self._block_diagonal(orth_rotate, rank)
return weight, oft_s
return self.oft_R[adapter_name].get_weight()
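With the per-row scaling vector `oft_s` gone, the delta weight is just the block-diagonal orthogonal matrix returned by `OFTRotationModule.get_weight()`, and merging reduces to the transpose/matmul/transpose pattern used above. A quick numerical check of that pattern with a stand-in orthogonal `R`:

```py
import torch

out_f, in_f = 4, 6
W = torch.randn(out_f, in_f)          # stands in for base_layer.weight
A = torch.randn(in_f, in_f)
S = 0.5 * (A - A.T)
I = torch.eye(in_f)
R = torch.linalg.solve(I + S, I - S)  # orthogonal stand-in for get_delta_weight()

W_merged = (R @ W.T).T                # merge: R acts on the input dimension of W
W_restored = (R.T @ W_merged.T).T     # unmerge path above applies R^T

print((W_restored - W).abs().max())   # ~1e-6, since R is orthogonal
```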
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
previous_dtype = x.dtype
@ -431,42 +647,15 @@ class Linear(nn.Module, OFTLayer):
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
oft_rotation = torch.eye(self.in_features, device=x.device)
oft_scale = torch.ones((int(self.out_features), 1), device=x.device)
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_r.keys():
if active_adapter not in self.oft_R.keys():
continue
oft_r = self.oft_r[active_adapter]
oft_s = self.oft_s[active_adapter]
dropout = self.oft_dropout[active_adapter]
oft_R = self.oft_R[active_adapter]
rank = self.r[active_adapter]
coft = self.coft[active_adapter]
eps = self.eps[active_adapter]
x = self._cast_input_dtype(x, oft_R.weight.dtype)
x = oft_R(x)
if coft:
with torch.no_grad():
oft_r.copy_(self._project_batch(oft_r, eps=eps))
orth_rotate = self._cayley_batch(oft_r)
orth_rotate = dropout(orth_rotate)
oft_mat = self._block_diagonal(orth_rotate, rank)
oft_rotation = oft_mat @ oft_rotation
oft_scale = oft_s * oft_scale
x = x.to(self.get_base_layer().weight.data.dtype)
orig_weight = self.get_base_layer().weight.data
orig_weight = torch.transpose(orig_weight, 0, 1)
rotated_weight = torch.mm(oft_rotation, orig_weight.to(oft_rotation.dtype))
rotated_weight = torch.transpose(rotated_weight, 0, 1)
scaled_rotated_weight = rotated_weight * oft_scale
x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
bias = self._cast_input_dtype(self.get_base_layer().bias, scaled_rotated_weight.dtype)
result = F.linear(input=x, weight=scaled_rotated_weight, bias=bias)
result = self.base_layer(x.to(previous_dtype), *args, **kwargs)
result = result.to(previous_dtype)
return result
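The rewritten forward never materializes a rotated copy of the weight: it rotates the activations with `oft_R` and then calls the untouched base layer, which is where the speed and memory savings come from. For a plain linear layer the two formulations agree; the check below assumes the rotation module applies `R` to the input features (the exact orientation is an implementation detail):

```py
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, in_f, out_f = 2, 8, 4
x = torch.randn(batch, in_f)
W = torch.randn(out_f, in_f)          # base_layer.weight
A = torch.randn(in_f, in_f)
S = 0.5 * (A - A.T)
I = torch.eye(in_f)
R = torch.linalg.solve(I + S, I - S)  # orthogonal rotation

# old path: build the rotated weight, then run the linear layer
y_old = F.linear(x, (R @ W.T).T)

# new path: rotate the activations, keep the weight as-is
y_new = F.linear(x @ R, W)

print(torch.allclose(y_old, y_new, atol=1e-6))  # True
```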
@ -491,6 +680,8 @@ class Conv2d(nn.Module, OFTLayer):
eps: float = 6e-5,
block_share: bool = False,
init_weights: Union[bool, str] = True,
use_cayley_neumann: bool = False,
num_cayley_neumann_terms: int = 5,
**kwargs,
) -> None:
super().__init__()
@ -500,9 +691,32 @@ class Conv2d(nn.Module, OFTLayer):
self._active_adapter = adapter_name
# Create adapter and set it active
self.update_layer(adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights)
self.update_layer(
adapter_name,
r,
oft_block_size=oft_block_size,
module_dropout=module_dropout,
coft=coft,
eps=eps,
block_share=block_share,
init_weights=init_weights,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
def update_layer(self, adapter_name, r, oft_block_size, module_dropout, coft, eps, block_share, init_weights):
def update_layer(
self,
adapter_name,
r,
oft_block_size,
module_dropout,
coft,
eps,
block_share,
init_weights,
use_cayley_neumann,
num_cayley_neumann_terms,
):
"""
Update the conv2d layer with trainable OFT weights.
"""
@ -515,6 +729,9 @@ class Conv2d(nn.Module, OFTLayer):
# layer information from the base layer
base_layer = self.get_base_layer()
if base_layer.dilation[0] > 1:
raise ValueError("Conv2d with dilation > 1 is not supported by OFT.")
conv_filter_dim = self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
if r == 0 and oft_block_size != 0:
@ -536,20 +753,20 @@ class Conv2d(nn.Module, OFTLayer):
"Something went wrong, please report this error: https://github.com/huggingface/peft/issues"
)
self.coft[adapter_name] = coft
self.block_share[adapter_name] = block_share
self.eps[adapter_name] = eps * math.ceil(self.out_features / r) * math.ceil(self.out_features / r)
# Create weights with provided shape
if block_share:
self.oft_r[adapter_name] = nn.Parameter(
torch.empty(1, math.ceil(conv_filter_dim / r), math.ceil(conv_filter_dim / r))
)
else:
self.oft_r[adapter_name] = nn.Parameter(
torch.empty(r, math.ceil(conv_filter_dim / r), math.ceil(conv_filter_dim / r))
)
self.oft_s[adapter_name] = nn.Parameter(torch.empty(int(self.out_features), 1))
n_elements = oft_block_size * (oft_block_size - 1) // 2
self.oft_R[adapter_name] = OFTRotationModule(
r if not block_share else 1,
n_elements,
oft_block_size,
conv_filter_dim,
coft=coft,
eps=eps,
block_share=block_share,
kernel_size=base_layer.kernel_size,
use_cayley_neumann=use_cayley_neumann,
num_cayley_neumann_terms=num_cayley_neumann_terms,
)
# Initialize weights
self.reset_oft_parameters(adapter_name, init_weights)
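For `Conv2d` the rotation acts on the flattened filter dimension `in_features * kernel_size * kernel_size` rather than on `in_features` alone, so `oft_block_size` has to tile that product. A quick sanity check (the divisibility requirement here is inferred from the `conv_filter_dim` computation above):

```py
in_channels, kernel_size = 64, 3
conv_filter_dim = in_channels * kernel_size * kernel_size  # 576, the dimension OFT rotates
oft_block_size = 8
assert conv_filter_dim % oft_block_size == 0               # 576 / 8 = 72 blocks
```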
@ -581,14 +798,14 @@ class Conv2d(nn.Module, OFTLayer):
return
for active_adapter in adapter_names:
if active_adapter in self.oft_r.keys():
if active_adapter in self.oft_R.keys():
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weights = base_layer.weight.data.clone()
oft_mat, oft_s = self.get_delta_weight(active_adapter)
oft_mat = self.get_delta_weight(active_adapter)
orig_weights = orig_weights.view(
self.out_features, self.in_features * base_layer.kernel_size[0] * base_layer.kernel_size[0]
@ -596,14 +813,13 @@ class Conv2d(nn.Module, OFTLayer):
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = orig_weights * oft_s
orig_weights = orig_weights.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
)
base_layer.weight.data = orig_weights.contiguous().to(orig_dtype)
else:
oft_mat, oft_s = self.get_delta_weight(active_adapter)
oft_mat = self.get_delta_weight(active_adapter)
orig_weights = base_layer.weight.data.clone()
orig_weights = orig_weights.view(
@ -612,7 +828,6 @@ class Conv2d(nn.Module, OFTLayer):
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat, orig_weights.to(oft_mat.dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = orig_weights * oft_s
orig_weights = orig_weights.view(
self.out_features, self.in_features, base_layer.kernel_size[0], base_layer.kernel_size[0]
)
@ -633,8 +848,8 @@ class Conv2d(nn.Module, OFTLayer):
orig_dtype = base_layer.weight.dtype
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.oft_r.keys():
oft_mat, oft_s = self.get_delta_weight(active_adapter)
if active_adapter in self.oft_R.keys():
oft_mat = self.get_delta_weight(active_adapter)
orig_weights = self.get_base_layer().weight.data.clone()
orig_weights = orig_weights.view(
@ -644,7 +859,6 @@ class Conv2d(nn.Module, OFTLayer):
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = torch.mm(oft_mat.t(), orig_weights.to(oft_mat.dtype))
orig_weights = torch.transpose(orig_weights, 0, 1)
orig_weights = orig_weights * (1 / oft_s)
orig_weights = orig_weights.view(
self.out_features,
self.in_features,
@ -662,21 +876,8 @@ class Conv2d(nn.Module, OFTLayer):
adapter (str):
The name of the adapter for which the delta weight should be computed.
"""
oft_r = self.oft_r[adapter_name]
oft_s = self.oft_s[adapter_name]
rank = self.r[adapter_name]
coft = self.coft[adapter_name]
eps = self.eps[adapter_name]
if coft:
with torch.no_grad():
oft_r.copy_(self._project_batch(oft_r, eps=eps))
orth_rotate = self._cayley_batch(oft_r)
weight = self._block_diagonal(orth_rotate, rank)
return weight, oft_s
return self.oft_R[adapter_name].get_weight()
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
previous_dtype = x.dtype
@ -688,62 +889,15 @@ class Conv2d(nn.Module, OFTLayer):
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
oft_rotation = torch.eye(
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
device=x.device,
)
oft_scale = torch.ones((int(self.out_features), 1), device=x.device)
for active_adapter in self.active_adapters:
if active_adapter not in self.oft_r.keys():
if active_adapter not in self.oft_R.keys():
continue
oft_r = self.oft_r[active_adapter]
oft_s = self.oft_s[active_adapter]
dropout = self.oft_dropout[active_adapter]
rank = self.r[active_adapter]
coft = self.coft[active_adapter]
eps = self.eps[active_adapter]
oft_R = self.oft_R[active_adapter]
x = self._cast_input_dtype(x, oft_R.weight.dtype)
x = oft_R(x)
if coft:
with torch.no_grad():
oft_r.copy_(self._project_batch(oft_r, eps=eps))
orth_rotate = self._cayley_batch(oft_r)
orth_rotate = dropout(orth_rotate)
oft_mat = self._block_diagonal(orth_rotate, rank)
oft_rotation = oft_mat @ oft_rotation
oft_scale = oft_s * oft_scale
orig_weights = self.base_layer.weight.data
orig_weights = orig_weights.view(
self.out_features,
self.in_features * self.get_base_layer().kernel_size[0] * self.get_base_layer().kernel_size[0],
)
orig_weights = torch.transpose(orig_weights, 0, 1)
oft_rotation = oft_rotation.to(previous_dtype)
orig_weights = orig_weights.to(previous_dtype)
rotated_weight = torch.mm(oft_rotation, orig_weights)
rotated_weight = torch.transpose(rotated_weight, 0, 1)
scaled_rotated_weight = rotated_weight * oft_scale
scaled_rotated_weight = scaled_rotated_weight.view(
self.out_features,
self.in_features,
self.get_base_layer().kernel_size[0],
self.get_base_layer().kernel_size[0],
)
x = self._cast_input_dtype(x, scaled_rotated_weight.dtype)
bias = self._cast_input_dtype(self.get_base_layer().bias, scaled_rotated_weight.dtype)
result = F.conv2d(
input=x,
weight=scaled_rotated_weight,
bias=bias,
padding=self.get_base_layer().padding[0],
stride=self.get_base_layer().stride[0],
)
result = self.base_layer(x.to(previous_dtype), *args, **kwargs)
result = result.to(previous_dtype)
return result
@ -751,3 +905,30 @@ class Conv2d(nn.Module, OFTLayer):
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
def dispatch_default(
target: torch.nn.Module,
adapter_name: str,
oft_config: OFTConfig,
**kwargs,
) -> Optional[torch.nn.Module]:
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if isinstance(target_base_layer, torch.nn.Conv2d):
new_module = Conv2d(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = oft_config.fan_in_fan_out = False
new_module = Linear(target, adapter_name, **kwargs)
return new_module


@ -21,6 +21,7 @@ import torch
from torch import nn
from tqdm import tqdm
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
@ -31,10 +32,17 @@ from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
_get_submodules,
get_quantization_config,
)
from .aqlm import dispatch_aqlm
from .awq import dispatch_awq
from .config import OFTConfig
from .layer import Conv2d, Linear, OFTLayer
from .eetq import dispatch_eetq
from .gptq import dispatch_gptq
from .hqq import dispatch_hqq
from .inc import dispatch_inc
from .layer import OFTLayer, dispatch_default
class OFTModel(BaseTuner):
@ -126,7 +134,6 @@ class OFTModel(BaseTuner):
if current_key is None:
raise ValueError("Current Key shouldn't be `None`")
bias = hasattr(target, "bias") and target.bias is not None
kwargs = {
"r": oft_config.r,
"oft_block_size": oft_config.oft_block_size,
@ -134,14 +141,24 @@ class OFTModel(BaseTuner):
"coft": oft_config.coft,
"eps": oft_config.eps,
"block_share": oft_config.block_share,
"use_cayley_neumann": oft_config.use_cayley_neumann,
"num_cayley_neumann_terms": oft_config.num_cayley_neumann_terms,
"fan_in_fan_out": oft_config.fan_in_fan_out,
"init_weights": oft_config.init_weights,
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
}
kwargs["bias"] = bias
quant_methods = ["gptq", "aqlm", "awq"]
for quant_method in quant_methods:
quantization_config = get_quantization_config(self.model, method=quant_method)
if quantization_config is not None:
kwargs[f"{quant_method}_quantization_config"] = quantization_config
# If it is not an OFTLayer, create a new module, else update it with new adapters
if not isinstance(target, OFTLayer):
new_module = self._create_new_module(oft_config, adapter_name, target, **kwargs)
device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
new_module = self._create_new_module(oft_config, adapter_name, target, device_map=device_map, **kwargs)
if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
@ -155,6 +172,8 @@ class OFTModel(BaseTuner):
coft=oft_config.coft,
eps=oft_config.eps,
block_share=oft_config.block_share,
use_cayley_neumann=oft_config.use_cayley_neumann,
num_cayley_neumann_terms=oft_config.num_cayley_neumann_terms,
init_weights=oft_config.init_weights,
)
@ -167,24 +186,22 @@ class OFTModel(BaseTuner):
if hasattr(child, "base_layer"):
child = child.base_layer
if not hasattr(new_module, "base_layer"):
new_module.weight = child.weight
if hasattr(child, "bias"):
new_module.bias = child.bias
if getattr(child, "state", None) is not None:
if hasattr(new_module, "base_layer"):
new_module.base_layer.state = child.state
else:
new_module.state = child.state
new_module.to(child.weight.device)
meta = torch.device("meta")
# dispatch to correct device
for name, module in new_module.named_modules():
if self.prefix in name:
if (self.prefix in name) or ("ranknum" in name):
if hasattr(child, "qweight"):
weight = child.qweight
elif hasattr(child, "W_q"):
weight = child.W_q
elif hasattr(child, "weight"):
weight = child.weight
elif getattr(child, "in_proj_weight", None) is not None: # MHA
weight = child.in_proj_weight
else:
weight = next(child.parameters())
if not any(p.device == meta for p in module.parameters()):
module.to(child.weight.device)
module.to(weight.device)
def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
for n, p in model.named_parameters():
@ -209,25 +226,44 @@ class OFTModel(BaseTuner):
@staticmethod
def _create_new_module(oft_config, adapter_name, target, **kwargs):
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
# Collect dispatcher functions to decide what backend to use for the replaced OFT layer. The order matters,
# because the first match is always used. Therefore, the default layers should be checked last.
dispatchers = []
if isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = oft_config.fan_in_fan_out = False
new_module = Linear(target, adapter_name, **kwargs)
elif isinstance(target_base_layer, torch.nn.Conv2d):
new_module = Conv2d(target, adapter_name, **kwargs)
else:
# avoid eager bnb import
if is_bnb_available():
from .bnb import dispatch_bnb_8bit
dispatchers.append(dispatch_bnb_8bit)
if is_bnb_4bit_available():
from .bnb import dispatch_bnb_4bit
dispatchers.append(dispatch_bnb_4bit)
dispatchers.extend(
[
dispatch_eetq,
dispatch_aqlm,
dispatch_awq,
dispatch_gptq,
dispatch_hqq,
dispatch_inc,
dispatch_default,
]
)
new_module = None
for dispatcher in dispatchers:
new_module = dispatcher(target, adapter_name, oft_config=oft_config, **kwargs)
if new_module is not None: # first match wins
break
if new_module is None:
# no module could be matched
raise ValueError(
f"Target module {target} is not supported. "
"Currently, only `torch.nn.Linear` and `torch.nn.Conv2d` are supported."
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`, `torch.nn.Conv2d`."
)
return new_module
@ -255,7 +291,11 @@ class OFTModel(BaseTuner):
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
module.enable_adapters(enabled)
def enable_adapter_layers(self):
def enable_adapter_layers(self) -> None:
"""Enable all adapters.
Call this if you have previously disabled all adapters and want to re-enable them.
"""
self._set_adapter_layers(enabled=True)
def disable_adapter_layers(self):
@ -270,6 +310,20 @@ class OFTModel(BaseTuner):
self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name):
"""Set the active adapter(s).
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
not desired, use the following code.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (e.g. if 'oft_' in name)
... param.requires_grad = False
```
Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
"""
for module in self.model.modules():
if isinstance(module, OFTLayer):
if module.merged:
@ -278,6 +332,17 @@ class OFTModel(BaseTuner):
module.set_adapter(adapter_name)
self.active_adapter = adapter_name
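A short usage sketch of the adapter-switching API on an OFT model (the base model and target module names are placeholders; the config values mirror those used in the tests below):

```py
from peft import OFTConfig, get_peft_model

config = OFTConfig(r=0, oft_block_size=8, target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(base_model, config)  # registers the "default" adapter
peft_model.add_adapter("other", config)          # added adapters are not automatically trainable
peft_model.set_adapter("other")                  # activates "other" and marks it trainable
```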
def _check_merge_allowed(self):
"""Verify that the configuration supports merging.
Currently gptq quantization and replicated layers do not support merging.
"""
super()._check_merge_allowed()
if getattr(self.model, "quantization_method", None) == "gptq":
raise ValueError("Cannot merge OFT layers when the model is gptq quantized")
if self.peft_config.get("layer_replication"):
raise ValueError("Cannot merge OFT layers when base model layers are replicated")
@staticmethod
def _prepare_adapter_config(peft_config, model_config):
if peft_config.target_modules is None:
@ -306,19 +371,16 @@ class OFTModel(BaseTuner):
except AttributeError:
continue
with onload_layer(target):
if hasattr(target, "base_layer"):
if hasattr(target, "unload_and_optionally_merge_module"):
# if layers have special unloading method, like MultiheadAttention, use that
unloaded_module = target.unload_and_optionally_merge_module(
merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
)
self._replace_module(parent, target_name, unloaded_module, target)
elif hasattr(target, "base_layer"):
if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
new_module = target.modules_to_save[target.active_adapter]
if hasattr(new_module, "base_layer"):
# check if the module is itself a tuner layer
if merge:
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
new_module = new_module.get_base_layer()
setattr(parent, target_name, new_module)
return self.model


@ -434,6 +434,11 @@ def set_peft_model_state_dict(
return k
peft_model_state_dict = {renamed_dora_weights(k): v for k, v in peft_model_state_dict.items()}
elif config.peft_type == PeftType.OFT:
if any(".oft_r." in key for key in peft_model_state_dict):
raise ValueError(
"Trying to load old OFT checkpoint, which is no longer supported. Please install PEFT <= v0.15.2 to load it or train a new OFT adapter."
)
else:
raise NotImplementedError
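Because old-format checkpoints keep their parameters under `oft_r` keys while the new format uses `oft_R`, an existing adapter file can be inspected before loading; the path below is a placeholder and the key check mirrors the guard above:

```py
from safetensors.torch import load_file

state_dict = load_file("old_adapter/adapter_model.safetensors")  # placeholder path
if any("oft_r" in key for key in state_dict):  # case-sensitive, so new "oft_R" keys don't match
    print("Old-format OFT checkpoint: load it with peft<=0.15.2 or retrain with the new OFT.")
```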


@ -1884,18 +1884,18 @@ class TestSameAdapterDifferentDevices:
config = OFTConfig(target_modules=["lin0"])
model = get_peft_model(mlp, config)
model = model.to(self.device)
model.lin0.oft_r.cpu()
model.lin0.oft_R.default.weight.cpu()
# check that the adapter is indeed on CPU and the base model on GPU
assert model.lin0.oft_r.default.device.type == "cpu"
assert model.lin0.oft_R.default.weight.device.type == "cpu"
assert model.lin0.base_layer.weight.device.type == self.device
model.add_adapter("other", config)
# check that after adding a new adapter, the old adapter is still on CPU
assert model.lin0.oft_r.default.device.type == "cpu"
assert model.lin0.oft_R.default.weight.device.type == "cpu"
# the rest should be on GPU
assert model.lin0.base_layer.weight.device.type == self.device
assert model.lin0.oft_r.other.device.type == self.device
assert model.lin0.oft_R.other.weight.device.type == self.device
def test_vera_add_new_adapter_does_not_change_device(self, mlp):
# same as first test, but using VERA


@ -300,26 +300,114 @@ TEST_CASES = [
########
# OFT #
########
("Vanilla MLP 1 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": "lin0"}),
("Vanilla MLP 2 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"]}),
("Vanilla MLP 5 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
(
"Vanilla MLP 1 OFT",
"MLP",
OFTConfig,
{"r": 2, "oft_block_size": 0, "target_modules": "lin0", "use_cayley_neumann": False},
),
(
"Vanilla MLP 2 OFT",
"MLP",
OFTConfig,
{"r": 2, "oft_block_size": 0, "target_modules": ["lin0"], "use_cayley_neumann": False},
),
(
"Vanilla MLP 5 OFT",
"MLP",
OFTConfig,
{
"r": 2,
"oft_block_size": 0,
"target_modules": ["lin0"],
"modules_to_save": ["lin1"],
"use_cayley_neumann": False,
},
),
(
"Vanilla MLP 6 OFT",
"MLP",
OFTConfig,
{
"r": 2,
"oft_block_size": 0,
"target_modules": ["lin0"],
"module_dropout": 0.1,
"use_cayley_neumann": False,
},
),
("Vanilla MLP 7 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "coft": True}),
("Vanilla MLP 8 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "block_share": True}),
("Vanilla MLP 9 OFT", "MLP", OFTConfig, {"r": 2, "target_modules": ["lin0"], "coft": True, "block_share": True}),
("Conv2d 1 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"]}),
("Conv2d 3 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"], "coft": True}),
("Conv2d 4 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"], "block_share": True}),
("Conv2d 5 OFT", "Conv2d", OFTConfig, {"r": 5, "target_modules": ["conv2d"], "coft": True, "block_share": True}),
(
"Vanilla MLP 7 OFT",
"MLP",
OFTConfig,
{"r": 2, "oft_block_size": 0, "target_modules": ["lin0"], "coft": True, "eps": 1e-2},
),
(
"Vanilla MLP 8 OFT",
"MLP",
OFTConfig,
{"r": 2, "oft_block_size": 0, "target_modules": ["lin0"], "block_share": True, "use_cayley_neumann": False},
),
(
"Vanilla MLP 9 OFT",
"MLP",
OFTConfig,
{"r": 2, "oft_block_size": 0, "target_modules": ["lin0"], "coft": True, "eps": 1e-2, "block_share": True},
),
(
"Vanilla MLP 10 OFT",
"MLP",
OFTConfig,
{"r": 0, "oft_block_size": 2, "target_modules": ["lin0"], "use_cayley_neumann": True},
),
(
"Vanilla MLP 11 OFT",
"MLP",
OFTConfig,
{"r": 0, "oft_block_size": 2, "target_modules": ["lin0"], "use_cayley_neumann": False},
),
(
"Vanilla MLP 12 OFT",
"MLP",
OFTConfig,
{
"r": 0,
"oft_block_size": 2,
"target_modules": ["lin0"],
"coft": True,
"eps": 1e-2,
"block_share": True,
"use_cayley_neumann": True,
},
),
(
"Vanilla MLP 13 OFT",
"MLP",
OFTConfig,
{
"r": 0,
"oft_block_size": 2,
"target_modules": ["lin0"],
"coft": True,
"eps": 1e-2,
"block_share": True,
"use_cayley_neumann": False,
},
),
("Conv2d 1 OFT", "Conv2d", OFTConfig, {"r": 5, "oft_block_size": 0, "target_modules": ["conv2d"]}),
("Conv2d 3 OFT", "Conv2d", OFTConfig, {"r": 5, "oft_block_size": 0, "target_modules": ["conv2d"], "coft": True}),
(
"Conv2d 4 OFT",
"Conv2d",
OFTConfig,
{"r": 5, "oft_block_size": 0, "target_modules": ["conv2d"], "block_share": True},
),
(
"Conv2d 5 OFT",
"Conv2d",
OFTConfig,
{"r": 5, "oft_block_size": 0, "target_modules": ["conv2d"], "coft": True, "block_share": True},
),
########
# HRA #
########
@ -1629,6 +1717,9 @@ class TestPeftCustomModel(PeftCommonTester):
if issubclass(config_cls, (IA3Config, LoraConfig)) and model_id in conv_ids: # more instability with Conv
atol, rtol = 1e-3, 1e-3
if issubclass(config_cls, OFTConfig):
atol, rtol = 1e-4, 1e-4
if config_kwargs.get("use_dora") and model_id == "EmbConv1D":
atol, rtol = 1e-4, 1e-4
@ -2214,7 +2305,7 @@ class TestPeftCustomModel(PeftCommonTester):
LoHaConfig(target_modules=["lin0"], init_weights=False),
AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False, total_step=1),
IA3Config(target_modules=["lin0"], feedforward_modules=["lin0"], init_ia3_weights=False),
OFTConfig(target_modules=["lin0"], init_weights=False, r=2),
OFTConfig(target_modules=["lin0"], init_weights=False, r=2, oft_block_size=0),
BOFTConfig(target_modules=["lin0"], init_weights=False, boft_block_size=2),
HRAConfig(target_modules=["lin0"], init_weights=False),
BoneConfig(target_modules=["lin0"], init_weights=False, r=2),
@ -3385,33 +3476,30 @@ class TestRequiresGrad:
def test_requires_grad_oft_different_targets(self):
# test two different OFT adapters that target different modules
config0 = OFTConfig(target_modules=["lin0"], r=2)
config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
peft_model = get_peft_model(MLP(), config0)
config1 = OFTConfig(target_modules=["lin1"], r=2, inference_mode=True)
config1 = OFTConfig(target_modules=["lin1"], r=2, oft_block_size=0, inference_mode=True)
peft_model.add_adapter("adapter1", config1)
# active adapter is still "default"
self.check_requires_grad(
peft_model,
"base_model.model.lin0.oft_r.default",
"base_model.model.lin0.oft_s.default",
"base_model.model.lin0.oft_R.default.weight",
)
# set config0 as active, should not change anything
peft_model.set_adapter("default")
self.check_requires_grad(
peft_model,
"base_model.model.lin0.oft_r.default",
"base_model.model.lin0.oft_s.default",
"base_model.model.lin0.oft_R.default.weight",
)
# change active adapter to adapter1
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.lin1.oft_r.adapter1",
"base_model.model.lin1.oft_s.adapter1",
"base_model.model.lin1.oft_R.adapter1.weight",
)
# disable all adapters
@ -3421,39 +3509,35 @@ class TestRequiresGrad:
# after context is exited, return to the previous state
self.check_requires_grad(
peft_model,
"base_model.model.lin1.oft_r.adapter1",
"base_model.model.lin1.oft_s.adapter1",
"base_model.model.lin1.oft_R.adapter1.weight",
)
def test_requires_grad_oft_same_targets(self):
# same as previous test, except that OFT adapters target the same layer
config0 = OFTConfig(target_modules=["lin0"], r=2)
config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
peft_model = get_peft_model(MLP(), config0)
config1 = OFTConfig(target_modules=["lin0"], r=2, inference_mode=True)
config1 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0, inference_mode=True)
peft_model.add_adapter("adapter1", config1)
# active adapter is still "default"
self.check_requires_grad(
peft_model,
"base_model.model.lin0.oft_r.default",
"base_model.model.lin0.oft_s.default",
"base_model.model.lin0.oft_R.default.weight",
)
# set config0 as active, should not change anything
peft_model.set_adapter("default")
self.check_requires_grad(
peft_model,
"base_model.model.lin0.oft_r.default",
"base_model.model.lin0.oft_s.default",
"base_model.model.lin0.oft_R.default.weight",
)
# change active adapter to adapter1
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.lin0.oft_r.adapter1",
"base_model.model.lin0.oft_s.adapter1",
"base_model.model.lin0.oft_R.adapter1.weight",
)
# disable all adapters
@ -3464,8 +3548,7 @@ class TestRequiresGrad:
peft_model.set_adapter("adapter1")
self.check_requires_grad(
peft_model,
"base_model.model.lin0.oft_r.adapter1",
"base_model.model.lin0.oft_s.adapter1",
"base_model.model.lin0.oft_R.adapter1.weight",
)
def test_requires_grad_hra_different_targets(self):


@ -32,6 +32,7 @@ from transformers import (
from peft import (
AdaLoraConfig,
LoraConfig,
OFTConfig,
PeftModel,
get_peft_model,
prepare_model_for_kbit_training,
@ -103,6 +104,43 @@ class PeftGPTQModelCommonTests(unittest.TestCase):
assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A
assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.lora_A
def test_oft_gptq_quantization_from_pretrained_safetensors(self):
r"""
Tests that the gptqmodel quantization using OFT works as expected with safetensors weights.
"""
from transformers import GPTQConfig
model_id = "marcsun13/opt-350m-gptq-4bit"
quantization_config = GPTQConfig(bits=4, use_exllama=False)
kwargs = {
"pretrained_model_name_or_path": model_id,
"torch_dtype": torch.float16,
"device_map": "auto",
"quantization_config": quantization_config,
}
model = AutoModelForCausalLM.from_pretrained(**kwargs)
model = prepare_model_for_kbit_training(model)
config = OFTConfig(task_type="CAUSAL_LM")
peft_model = get_peft_model(model, config)
peft_model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))
with tempfile.TemporaryDirectory() as tmp_dir:
peft_model.save_pretrained(tmp_dir)
model = AutoModelForCausalLM.from_pretrained(**kwargs)
model = PeftModel.from_pretrained(model, tmp_dir)
model = prepare_model_for_kbit_training(model)
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))
# loading a 2nd adapter works, #1239
model.load_adapter(tmp_dir, "adapter2")
model.set_adapter("adapter2")
model.generate(input_ids=torch.LongTensor([[0, 2, 3, 1]]).to(peft_model.device))
# check that both adapters are in the same layer
assert "default" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.oft_R
assert "adapter2" in model.base_model.model.model.decoder.layers[0].self_attn.q_proj.oft_R
@require_gptqmodel
@require_optimum
@ -186,6 +224,58 @@ class PeftGPTQModelTests(unittest.TestCase):
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None
def test_oft_causal_lm_training(self):
r"""
Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
correctly.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
torch_dtype=torch.float16,
device_map="auto",
quantization_config=self.quantization_config,
)
model = prepare_model_for_kbit_training(model)
config = OFTConfig(
r=0,
oft_block_size=8,
target_modules=["q_proj", "v_proj"],
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
data = load_dataset_english_quotes()
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=2e-4,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()
model.cpu().save_pretrained(tmp_dir)
assert "adapter_config.json" in os.listdir(tmp_dir)
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None
@pytest.mark.single_gpu_tests
def test_adalora_causalLM(self):
r"""
@ -315,6 +405,68 @@ class PeftGPTQModelTests(unittest.TestCase):
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None
@pytest.mark.multi_gpu_tests
@require_torch_multi_accelerator
def test_oft_causal_lm_training_multi_accelerator(self):
r"""
Test the CausalLM training on a multi-accelerator device. The test would simply fail if the adapters are not
set correctly.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
torch_dtype=torch.float16,
device_map="auto",
quantization_config=self.quantization_config,
)
assert set(model.hf_device_map.values()) == set(range(device_count))
model = prepare_model_for_kbit_training(model)
setattr(model, "model_parallel", True)
setattr(model, "is_parallelizable", True)
config = OFTConfig(
r=0,
oft_block_size=8,
target_modules=["q_proj", "v_proj"],
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
data = load_dataset_english_quotes()
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=2e-4,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()
model.cpu().save_pretrained(tmp_dir)
assert "adapter_config.json" in os.listdir(tmp_dir)
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None
def test_non_default_adapter_name(self):
# See issue 1346
config = LoraConfig(
@ -350,6 +502,42 @@ class PeftGPTQModelTests(unittest.TestCase):
assert n_trainable_default == n_trainable_other
assert n_total_default == n_total_other
def test_oft_non_default_adapter_name(self):
# See issue 1346
config = OFTConfig(
r=0,
oft_block_size=8,
target_modules=["q_proj", "v_proj"],
task_type="CAUSAL_LM",
)
# default adapter name
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
torch_dtype=torch.float16,
device_map="auto",
quantization_config=self.quantization_config,
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, config)
n_trainable_default, n_total_default = model.get_nb_trainable_parameters()
# other adapter name
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
torch_dtype=torch.float16,
device_map="auto",
quantization_config=self.quantization_config,
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, config, adapter_name="other")
n_trainable_other, n_total_other = model.get_nb_trainable_parameters()
assert n_trainable_other > 0
# sanity check
assert n_trainable_default == n_trainable_other
assert n_total_default == n_total_other
def test_load_lora(self):
model_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit"
adapter_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit-lora"


@ -134,12 +134,15 @@ DIFFUSERS_CONFIGS = [
{
"text_encoder": {
"r": 1,
"oft_block_size": 0,
"target_modules": ["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
"module_dropout": 0.0,
"init_weights": False,
"use_cayley_neumann": False,
},
"unet": {
"r": 1,
"oft_block_size": 0,
"target_modules": [
"proj_in",
"proj_out",
@ -152,6 +155,7 @@ DIFFUSERS_CONFIGS = [
],
"module_dropout": 0.0,
"init_weights": False,
"use_cayley_neumann": False,
},
},
),


@ -46,7 +46,9 @@ CONFIGS = {
"lora": LoraConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
"loha": LoHaConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
"lokr": LoKrConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
"oft": OFTConfig(r=1, target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
"oft": OFTConfig(
r=1, oft_block_size=0, target_modules=["convolution"], modules_to_save=["classifier", "normalization"]
),
"hra": HRAConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"]),
# TODO: cannot use BOFT because some convolutional kernel dimensions are even (64) and others odd (147). There is no
# common denominator for the boft_block_size except 1, but using 1 results in an error in the fbd_cuda kernel: