FEAT RoAd: 2D Rotary Adaptation (#2678)

Implements RoAd from https://arxiv.org/pdf/2409.00119

Supports mixed adapter batches.
This commit is contained in:
ppetrushkov
2025-08-19 15:45:38 +02:00
committed by GitHub
parent b5ace6a8c4
commit ce5c2044f1
23 changed files with 2443 additions and 50 deletions

View File

@ -132,6 +132,8 @@
title: C3A
- local: package_reference/miss
title: MiSS
- local: package_reference/road
title: RoAd
title: Adapters
- sections:

View File

@ -0,0 +1,31 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# RoAd
[RoAd](https://arxiv.org/pdf/2409.00119) is a parameter-efficient finetuning technique that adapts large language models by learning a small set of 2×2 rotation matrices (and optional scaling factors) applied to pairs of hidden dimensions. RoAd achieves competitive or superior performance compared to other PEFT methods with under 0.1% trainable parameters. Unlike LoRA's batched low-rank updates, RoAd's sparse rotations reformulate to simple elementwise operations, yielding significantly higher serving throughput when handling heterogeneous requests in the same batch, i.e. serving multiple adapters simultaneously. Moreover, RoAd integrates seamlessly into a distributed interchange intervention framework, interpreting its sparse 2D rotations as task-specific interventions within learned subspaces of hidden representations. These orthogonal subspaces can be composed to merge multiple task-specific behaviors, such as multilingual capabilities or instruction following, without additional fine-tuning, enabling modular, interpretable adaptations in LLMs.
Finetuning with RoAd typically requires a higher learning rate than LoRA or similar methods, around 1e-3. Currently, RoAd only supports linear layers, and it can be used on models quantized with bitsandbytes (4-bit or 8-bit).
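As a minimal sketch (the model id and target modules below are placeholders, not recommendations), RoAd can be applied on top of a 4-bit bitsandbytes-quantized model like this:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import RoadConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder model id
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
    device_map="auto",
)
config = RoadConfig(variant="road_1", target_modules=["q_proj", "v_proj"])
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()
```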
For running inference with different RoAd adapters in the same batch, see [Inference with different LoRA adapters in the same batch](../developer_guides/lora#inference-with-different-lora-adapters-in-the-same-batch).
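The mechanism is the same as for LoRA: pass `adapter_names` (one entry per sample, with `"__base__"` selecting the plain base model) to the forward or generate call. A minimal sketch, assuming two RoAd adapters named `"adapter_1"` and `"adapter_2"` have already been loaded onto `peft_model` and a matching `tokenizer` is available:

```python
inputs = tokenizer(["Hello", "Bonjour", "Hola"], return_tensors="pt", padding=True).to(peft_model.device)
# one adapter name per sample in the batch
output = peft_model.generate(**inputs, adapter_names=["adapter_1", "adapter_2", "__base__"], max_new_tokens=20)
```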
## RoadConfig
[[autodoc]] tuners.road.config.RoadConfig
## RoadModel
[[autodoc]] tuners.road.model.RoadModel

View File

@ -0,0 +1,88 @@
# RoAd: 3-in-1: 2D Rotary Adaptation for Efficient Finetuning, Efficient Batching and Composability
## Introduction
[RoAd](https://arxiv.org/pdf/2409.00119) is a novel method that adapts LLMs using simple 2D rotations. It is highly parameter-efficient,
achieving strong performance with less than 0.1% trainable parameters.
RoAd also supports efficient serving of mixed-adapter requests within a batch, incurring only element-wise computation overhead rather than costly batch matrix multiplications.
Additionally, it improves model interpretability through structured and composable transformations.
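To make the core operation concrete, the sketch below (with arbitrary illustrative values) spells out how a single pair of hidden activations is rotated and scaled by a trainable angle θ and scale α; the actual implementation applies this elementwise across all pairs.

```python
import torch

theta = torch.tensor(0.3)  # trainable angle for this pair
alpha = torch.tensor(1.1)  # trainable scale for this pair
x0, xn = torch.tensor(0.5), torch.tensor(-0.2)  # one 2D pair of hidden activations

# scaled 2D rotation, applied elementwise at inference time
y0 = alpha * torch.cos(theta) * x0 - alpha * torch.sin(theta) * xn
yn = alpha * torch.sin(theta) * x0 + alpha * torch.cos(theta) * xn
print(y0, yn)
```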
## Quick start
```python
import torch
from peft import RoadConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

road_config = RoadConfig(
    variant="road_1",
)
peft_model = get_peft_model(model, road_config)

training_args = SFTConfig(dataset_text_field="text", max_seq_length=2048, learning_rate=1e-3, output_dir="road-llama-7b")
trainer = SFTTrainer(
    model=peft_model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
peft_model.save_pretrained("road-llama-7b")
```
RoAd requires a higher learning rate than LoRA and similar approaches; set it to around 1e-3 (as in the quick start above).
Run the finetuning script with:
```bash
python examples/road_finetuning/road_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco
```
RoAd also supports quantization. To use 4-bit quantization try:
```bash
python examples/road_finetuning/road_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --quantize
```
### Full example of the script
```bash
python road_finetuning.py \
--base_model "PATH_TO_MODEL" \
--data_path "PATH_TO_DATASET" \
--output_dir "PATH_TO_OUTPUT_DIR" \
--batch_size 1 \
--num_epochs 3 \
--learning_rate 1e-3 \
--cutoff_len 512 \
--val_set_size 500 \
--quantize \
--eval_step 10 \
--save_step 100 \
--device "cuda:0" \
--variant road_1 \
--road_target_modules "q_proj,k_proj,v_proj,o_proj" \
--hub_model_id "YOUR_HF_REPO" \
--push_to_hub
```
## Use the model on 🤗
You can load and use the model like any other 🤗 Transformers model.
```python
from transformers import AutoModel
model = AutoModel.from_pretrained("ppetrushkov/llama-2-7b-sql-road-test")
```
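If the repository contains a PEFT adapter checkpoint (as produced by `save_pretrained` above), you can also load it together with its base model via peft directly; a sketch, assuming the base model is resolvable from the adapter config:

```python
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("ppetrushkov/llama-2-7b-sql-road-test")
```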
## Citation
```
@inproceedings{
liao2024in,
title={3-in-1: 2D Rotary Adaptation for Efficient Finetuning, Efficient Batching and Composability},
author={Baohao Liao and Christof Monz},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=rYjYwuM6yH}
}
```

View File

@ -0,0 +1,203 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from peft import RoadConfig, get_peft_model, prepare_model_for_kbit_training
def train_model(
base_model: str,
data_path: str,
output_dir: str,
batch_size: int,
num_epochs: int,
learning_rate: float,
cutoff_len: int,
val_set_size: int,
quantize: bool,
eval_step: int,
save_step: int,
device: str,
variant: str,
road_target_modules: str,
hub_model_id: str,
push_to_hub: bool,
):
os.environ["TOKENIZERS_PARALLELISM"] = "false"
hf_token = os.getenv("HF_TOKEN")
# Setup device
device = torch.device(device)
print(f"Using device: {device}")
# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
# Optionally quantize the base model with bitsandbytes (4-bit NF4)
if quantize:
model = AutoModelForCausalLM.from_pretrained(
base_model,
token=hf_token,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=(
torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
),
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
),
)
# setup for quantized training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
else:
model = AutoModelForCausalLM.from_pretrained(base_model, token=hf_token, device_map="auto")
# RoAd config for the PEFT model
road_config = RoadConfig(
variant=variant,  # RoAd variant: road_1, road_2 or road_4
target_modules=(
road_target_modules.split(",")
if road_target_modules
else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
),
)
# get the peft model with RoAd config
model = get_peft_model(model, road_config)
model.to(device)  # move the model to the selected device
tokenizer.pad_token = tokenizer.eos_token
# Load the dataset
dataset = load_dataset(data_path)
def tokenize_function(examples):
inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len)
inputs["labels"] = inputs["input_ids"].copy() # setting labels for a language modeling task
return inputs
# Tokenize the dataset and prepare for training
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
# Data collator to dynamically pad the batched examples
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# Define training arguments
training_args = TrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
warmup_steps=100,
weight_decay=0.01,
logging_dir="./logs",
logging_steps=eval_step,
save_steps=save_step,
save_total_limit=2,
push_to_hub=push_to_hub,
hub_model_id=hub_model_id,
gradient_accumulation_steps=16,
fp16=True,
learning_rate=learning_rate,
hub_token=hf_token,
)
# Clear CUDA cache to free memory
torch.cuda.empty_cache()
# Initialize the Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["test"],
data_collator=data_collator,
)
# Start model training
trainer.train()
# Save and push the trained model and tokenizer
if push_to_hub:
# Push the main model to the hub
trainer.push_to_hub(commit_message="Fine-tuned model")
# Save the model and tokenizer locally
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Fine-tune LLaMA with RoAd and PEFT")
parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name")
parser.add_argument(
"--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name"
)
parser.add_argument(
"--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
)
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
parser.add_argument("--learning_rate", type=float, default=3e-3, help="Learning rate")
parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization")
parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size")
parser.add_argument("--quantize", action="store_true", help="Use quantization")
parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training")
parser.add_argument(
"--variant", type=str, default="road_1", choices=["road_1", "road_2", "road_4"], help="RoAD variant"
)
parser.add_argument(
"--road_target_modules", type=str, default=None, help="Comma-separated list of target modules for RoAd"
)
parser.add_argument(
"--hub_model_id",
type=str,
default="path/to/repo",
help="Repository name to push the model on the Hugging Face Hub",
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub")
args = parser.parse_args()
train_model(
base_model=args.base_model,
data_path=args.data_path,
output_dir=args.output_dir,
batch_size=args.batch_size,
num_epochs=args.num_epochs,
learning_rate=args.learning_rate,
cutoff_len=args.cutoff_len,
val_set_size=args.val_set_size,
quantize=args.quantize,
eval_step=args.eval_step,
save_step=args.save_step,
device=args.device,
variant=args.variant,
road_target_modules=args.road_target_modules,
hub_model_id=args.hub_model_id,
push_to_hub=args.push_to_hub,
)

View File

@ -0,0 +1,12 @@
{
"auto_mapping": null,
"base_model_name_or_path": null,
"group_size": 64,
"inference_mode": false,
"init_weights": true,
"peft_type": "ROAD",
"revision": null,
"target_modules": null,
"task_type": null,
"variant": "road_2"
}

View File

@ -0,0 +1,5 @@
{
"optimizer_kwargs": {
"lr": 1e-3
}
}

View File

@ -93,6 +93,8 @@ from .tuners import (
PromptTuningInit,
RandLoraConfig,
RandLoraModel,
RoadConfig,
RoadModel,
ShiraConfig,
ShiraModel,
TrainableTokensConfig,
@ -194,6 +196,8 @@ __all__ = [
"PromptTuningInit",
"RandLoraConfig",
"RandLoraModel",
"RoadConfig",
"RoadModel",
"ShiraConfig",
"ShiraModel",
"TaskType",

View File

@ -42,6 +42,7 @@ from .poly import PolyConfig, PolyModel
from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
from .randlora import RandLoraConfig, RandLoraModel
from .road import RoadConfig, RoadModel
from .shira import ShiraConfig, ShiraModel
from .trainable_tokens import TrainableTokensConfig, TrainableTokensModel
from .vblora import VBLoRAConfig, VBLoRAModel
@ -99,6 +100,8 @@ __all__ = [
"PromptTuningInit",
"RandLoraConfig",
"RandLoraModel",
"RoadConfig",
"RoadModel",
"ShiraConfig",
"ShiraModel",
"TrainableTokensConfig",

View File

@ -0,0 +1,47 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Based on implementation made available in https://github.com/ppetrushkov/peft/tree/road (not from paper authors)
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.utils import register_peft_method
from .config import RoadConfig
from .layer import Linear, RoadLayer
from .model import RoadModel
__all__ = [
"Linear",
"RoadConfig",
"RoadLayer",
"RoadModel",
]
register_peft_method(name="road", config_cls=RoadConfig, model_cls=RoadModel, is_mixed_compatible=True)
def __getattr__(name):
if (name == "Linear8bitLt") and is_bnb_available():
from .bnb import Linear8bitLt
return Linear8bitLt
if (name == "Linear4bit") and is_bnb_4bit_available():
from .bnb import Linear4bit
return Linear4bit
raise AttributeError(f"module {__name__} has no attribute {name}")

src/peft/tuners/road/bnb.py
View File

@ -0,0 +1,407 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import warnings
from typing import Any, Optional
import bitsandbytes as bnb
import torch
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.integrations import dequantize_bnb_weight
from .config import RoadVariant
from .layer import RoadLayer, _apply_road, _get_delta_weight
if is_bnb_available():
class Linear8bitLt(torch.nn.Module, RoadLayer):
# Road implemented in a dense layer
def __init__(
self,
base_layer: torch.nn.Module,
adapter_name: str,
variant: RoadVariant = "road_1",
group_size: int = 64,
init_weights: bool = True,
**kwargs,
) -> None:
super().__init__()
RoadLayer.__init__(self, base_layer)
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
variant=variant,
group_size=group_size,
init_weights=init_weights,
)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self._available_adapters:
warnings.warn(
"Merge road module to 8-bit linear may get different generations due to rounding errors."
)
weight = self.get_base_layer().weight
state = self.get_base_layer().state
if state.SCB is None:
state.SCB = weight.SCB
# Dequantize the result of identity matrix and int8 weight because bitsandbytes does not support int8
# dequantization directly
output = dequantize_bnb_weight(weight, state=state)
road_R = _get_delta_weight(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter].data,
self.road_alpha[active_adapter].data,
)
w_data = torch.matmul(road_R, output.to(road_R.dtype))
w_data = w_data.to(road_R.dtype).to(road_R.device).contiguous()
if safe_merge and not torch.isfinite(w_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
if self.get_base_layer().bias is not None:
bias = self.get_base_layer().bias
orig_dtype = bias.dtype
bias_data = bias.data
new_bias = torch.matmul(road_R, bias_data.to(road_R.dtype))
bias.data = new_bias.to(orig_dtype)
state.reset_grads()
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self._available_adapters:
warnings.warn(
"Unmerge road module to 8-bit linear may get different generations due to rounding errors."
)
weight = self.get_base_layer().weight
state = self.get_base_layer().state
if state.SCB is None:
state.SCB = weight.SCB
output = dequantize_bnb_weight(weight, state=state)
road_R = _get_delta_weight(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter].data,
self.road_alpha[active_adapter].data,
)
inv_road_R = torch.linalg.inv(road_R.to(torch.float32)).to(road_R.dtype)
w_data = torch.matmul(inv_road_R, output.to(road_R.dtype))
w_data = w_data.to(road_R.dtype).to(road_R.device).contiguous()
self.get_base_layer().weight = bnb.nn.Int8Params(
w_data.to("cpu"), requires_grad=False, has_fp16_weights=weight.has_fp16_weights
).to(weight.device)
if self.get_base_layer().bias is not None:
bias = self.get_base_layer().bias
orig_dtype = bias.dtype
bias_data = bias.data
new_bias = torch.matmul(inv_road_R, bias_data)
bias.data = new_bias.to(orig_dtype)
state.reset_grads()
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self._available_adapters:
continue
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
result = self._cast_input_dtype(result, self.road_theta[active_adapter].dtype)
result = _apply_road(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter],
self.road_alpha[active_adapter],
result,
)
if requires_conversion:
result = result.to(expected_dtype)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "road." + rep
def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, **kwargs):
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
loaded_in_8bit = kwargs.get("loaded_in_8bit", False)
if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
eightbit_kwargs = kwargs.copy()
eightbit_kwargs.update(
{
"has_fp16_weights": target.state.has_fp16_weights,
"threshold": target.state.threshold,
"index": target.index,
}
)
new_module = Linear8bitLt(target, adapter_name, **eightbit_kwargs)
return new_module
if is_bnb_4bit_available():
class Linear4bit(torch.nn.Module, RoadLayer):
# Road implemented in a dense layer
def __init__(
self,
base_layer: torch.nn.Module,
adapter_name: str,
variant: RoadVariant = "road_1",
group_size: int = 64,
init_weights: bool = True,
**kwargs,
) -> None:
super().__init__()
RoadLayer.__init__(self, base_layer)
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
variant=variant,
group_size=group_size,
init_weights=init_weights,
)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`list[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self._available_adapters:
warnings.warn(
"Merge oft module to 4-bit linear may get different generations due to rounding errors."
)
# Refer to https://gist.github.com/ChrisHayduk/1a53463331f52dca205e55982baf9930
weight = self.get_base_layer().weight
kwargs = weight.__dict__
output = dequantize_bnb_weight(weight, state=weight.quant_state)
road_R = _get_delta_weight(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter].data,
self.road_alpha[active_adapter].data,
)
w_data = torch.matmul(road_R, output.to(road_R.dtype))
w_data = w_data.to(road_R.dtype).to(road_R.device)
if safe_merge and not torch.isfinite(w_data).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
if "bnb_quantized" in kwargs:
kwargs["bnb_quantized"] = False
kwargs["requires_grad"] = False
kwargs.pop("data", None)
# torch.compile can introduce attributes preceded by '_', remove them
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
if self.get_base_layer().bias is not None:
bias = self.get_base_layer().bias
orig_dtype = bias.dtype
bias_data = bias.data
new_bias = torch.matmul(road_R, bias_data.to(road_R.dtype))
bias.data = new_bias.to(orig_dtype)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self._available_adapters:
warnings.warn(
"Unmerge oft module to 4-bit linear may get different generations due to rounding errors."
)
weight = self.get_base_layer().weight
kwargs = weight.__dict__
output = dequantize_bnb_weight(weight, state=weight.quant_state)
road_R = _get_delta_weight(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter].data,
self.road_alpha[active_adapter].data,
)
inv_road_R = torch.linalg.inv(road_R.to(torch.float32)).to(road_R.dtype)
w_data = torch.matmul(inv_road_R, output.to(road_R.dtype))
w_data = w_data.to(road_R.dtype).to(road_R.device)
if "bnb_quantized" in kwargs:
kwargs["bnb_quantized"] = False
kwargs["requires_grad"] = False
kwargs.pop("data", None)
self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device)
if self.get_base_layer().bias is not None:
bias = self.get_base_layer().bias
orig_dtype = bias.dtype
bias_data = bias.data
new_bias = torch.matmul(inv_road_R, bias_data)
bias.data = new_bias.to(orig_dtype)
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
# The reason is that in some cases, an error can occur that backprop
# does not work on a manipulated view. This issue may be solved with
# newer PyTorch versions but this would need extensive testing to be
# sure.
# result = result.clone()
for active_adapter in self.active_adapters:
if active_adapter not in self._available_adapters:
continue
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
result = self._cast_input_dtype(result, self.road_theta[active_adapter].dtype)
result = _apply_road(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter],
self.road_alpha[active_adapter],
result,
)
if requires_conversion:
result = result.to(expected_dtype)
return result
def __repr__(self) -> str:
rep = super().__repr__()
return "oft." + rep
def dispatch_bnb_4bit(target: torch.nn.Module, adapter_name: str, **kwargs):
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
loaded_in_4bit = kwargs.get("loaded_in_4bit", False)
if loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
fourbit_kwargs = kwargs.copy()
fourbit_kwargs.update(
{
"compute_dtype": target_base_layer.compute_dtype,
"compress_statistics": target_base_layer.weight.compress_statistics,
"quant_type": target_base_layer.weight.quant_type,
}
)
new_module = Linear4bit(target, adapter_name, **fourbit_kwargs)
return new_module

View File

@ -0,0 +1,126 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from peft.config import PeftConfig
from peft.utils import PeftType
RoadVariant = Literal["road_1", "road_2", "road_4"]
@dataclass
class RoadConfig(PeftConfig):
"""
This is the configuration class to store the configuration of a [`RoadModel`]. RoAd adapter is proposed in
https://arxiv.org/pdf/2409.00119.
Args:
variant (Union[`RoadVariant`, `str`]):
The variant of the Road model to use. It can be one of road_1, road_2, or road_4. Refer to the paper for
more details.
- road_1: Uses one scale and one angle shared by both elements of each 2D pair.
This variant has the lowest number of parameters; it stores a number of parameters equal to the output
hidden size for each layer that RoAd is applied to.
- road_2: Uses a separate scale and angle for each element.
This variant has 2x the number of parameters compared to road_1.
- road_4: Uses two separate scales and angles for each element.
This variant has 4x the number of parameters compared to road_1.
group_size (`int`):
Group size defines how elements are grouped together into 2D vectors for rotation. Within each group,
element 0 is paired with element group_size/2, element 1 with element group_size/2+1, and so on. For
example, with group_size=4, elements (0, 2) and (1, 3) of each group form the rotated 2D pairs. This has
no effect on model performance, since elements are unordered, but it does affect inference speed when used
in e.g. vLLM. For best speed, a group size of at least 32 or 64 (the default) is recommended. Note that
the model hidden size (or hidden size per partition when used with tensor parallelism) must be divisible
by group_size, so for very small models you might need to reduce this parameter.
init_weights (`bool`):
Whether to perform initialization of RoAd weights.
target_modules (`Optional[Union[List[str], str]]`):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
strings, either an exact match will be performed or it is checked if the name of the module ends with any
of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen (if
the model is a PreTrainedModel, the output layer excluded). If this is not specified, modules will be
chosen according to the model architecture. If the architecture is not known, an error will be raised -- in
this case, you should specify the target modules manually.
modules_to_save (`List[str]`):
List of modules apart from Road layers to be set as trainable and saved in the final checkpoint.
"""
variant: Union[str, RoadVariant] = field(
default="road_1",
metadata={"help": ("Variant of the Road model to use.")},
)
group_size: int = field(
default=64,
metadata={
"help": (
"Group size defines how elements are grouped together into 2D vectors for rotation. "
"Within each group element 0 is paired with element group_size/2, "
"then element 1 is paired with element group_size/2+1 and so on. "
"This has no effect on the model performance, since elements are unordered, "
"however it has some effect on inference speed when used in e.g. VLLM. "
"For best speed group size of at least 64 is recommended. "
"Note that model hidden size (or hidden size per partition when used with tensor parallelism) "
"must be divisible by group_size, so for very small models you might need to reduce this parameter."
)
},
)
init_weights: bool = field(
default=True,
metadata={
"help": (
"Whether to initialize the weights of the RoAd layers with their default initialization. Don't change "
"this setting, except if you know exactly what you're doing."
),
},
)
target_modules: Optional[Union[list[str], str]] = field(
default=None,
metadata={
"help": (
"List of module names or regex expression of the module names to replace with Road."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'."
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D "
"(if the model is a PreTrainedModel, the output layer excluded)."
"If not specified, modules will be chosen according to the model architecture, If the architecture is "
"not known, an error will be raised -- in this case, you should specify the target modules manually."
),
},
)
modules_to_save: Optional[list[str]] = field(
default=None,
metadata={
"help": (
"List of modules apart from RoAd layers to be set as trainable and saved in the final checkpoint. For"
" example, in Sequence Classification or Token Classification tasks, the final layer"
" `classifier/score` are randomly initialized and as such need to be trainable and saved."
)
},
)
def __post_init__(self):
super().__post_init__()
self.peft_type = PeftType.ROAD
self.target_modules = (
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
)
if self.variant not in ["road_1", "road_2", "road_4"]:
raise ValueError(f"Invalid variant {self.variant} specified. Please choose from road_1, road_2 or road_4")
if self.group_size <= 0 or self.group_size % 2 != 0:
raise ValueError(f"The group_size must be divisible by 2 when using RoadLayer, but got {self.group_size}.")

View File

@ -0,0 +1,417 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Optional, Union
import torch
import torch.nn as nn
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from .config import RoadConfig, RoadVariant
class RoadLayer(BaseTunerLayer):
"""
Road layer.
Generally, the idea of RoAd is to split the input vector into many 2D vectors and rotate each 2D vector with its own
2D rotation matrix. For additional flexibility, each rotation matrix is multiplied by a trainable scale.
When applied to a vector as R @ x, each pair of elements of x is transformed as `y₀ = x₀ * α * cosθ - xₙ * α *
sinθ` and `yₙ = x₀ * α * sinθ + xₙ * α * cosθ`.
The scales α and angles θ are learned for each pair of elements and, moreover, each of the four entries of the
rotation matrix may differ (when using variant road_2 or road_4).
Note that instead of pairing two consecutive elements x₀ and x₁, we first split the whole vector into groups and
pair elements from the first half of each group with elements from its second half, which allows for a more
efficient inference implementation.
The adapter only needs to store the angles θ and scales α, rather than the full matrix R, and the inference
implementation only needs to perform elementwise vector multiplications.
For merging the weights, we make use of the following identity: R @ (W @ x + b) = (R @ W) @ x + R @ b. The lhs
is how it is used in the unmerged state (using the efficient elementwise implementation instead of a matrix
multiplication) and the rhs is how it is used in the merged state, where (R @ W) becomes the new weight matrix
and R @ b becomes the new bias.
"""
adapter_layer_names: tuple[str, ...] = ("road_theta", "road_alpha")
other_param_names: tuple[str, ...] = ("variant", "group_size")
def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, **kwargs) -> None:
self.base_layer = base_layer
self.variant = {}
self.group_size = {}
self.road_theta = nn.ParameterDict({})
self.road_alpha = nn.ParameterDict({})
self._disable_adapters = False
self.merged_adapters = []
base_layer = self.get_base_layer()
if isinstance(base_layer, nn.Linear):
in_features, out_features = base_layer.in_features, base_layer.out_features
else:
raise ValueError(f"Unsupported layer type '{type(base_layer)}' encountered, cannot apply RoAd adapter.")
self.in_features = in_features
self.out_features = out_features
@property
def _available_adapters(self) -> set[str]:
return {*self.road_theta}
def update_layer(
self,
adapter_name,
variant,
group_size,
init_weights,
):
self.variant[adapter_name] = variant
self.group_size[adapter_name] = group_size
if self.out_features % group_size != 0:
raise ValueError(
f"The out_features of the base layer must be divisible by group_size ({group_size}) when using RoadLayer."
)
# Actual trainable parameters
if variant == "road_1":
size = self.out_features // 2
elif variant == "road_2":
size = self.out_features
elif variant == "road_4":
size = self.out_features * 2
else:
raise ValueError(
f"Unsupported variant {variant} for RoadLayer. Supported variants are road_1, road_2, and road_4."
)
self.road_theta[adapter_name] = nn.Parameter(torch.empty(size))
self.road_alpha[adapter_name] = nn.Parameter(torch.empty(size))
self.reset_parameters(adapter_name, init_weights)
self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters)
def reset_parameters(self, adapter_name, init_weights):
if init_weights is False:
nn.init.normal_(self.road_theta[adapter_name].data, mean=0.0, std=0.5)
nn.init.normal_(self.road_alpha[adapter_name].data, mean=1.0, std=0.5)
return
nn.init.zeros_(self.road_theta[adapter_name].data)
nn.init.ones_(self.road_alpha[adapter_name].data)
class Linear(nn.Module, RoadLayer):
# Road implemented in a dense layer
def __init__(
self,
base_layer,
adapter_name: str,
variant: RoadVariant = "road_1",
group_size: int = 64,
init_weights: Union[bool, str] = True,
**kwargs,
) -> None:
super().__init__()
RoadLayer.__init__(self, base_layer, **kwargs)
self._active_adapter = adapter_name
self.update_layer(
adapter_name,
variant,
group_size,
init_weights=init_weights,
)
def _check_forward_args(self, x, *args, **kwargs):
"""Check if the arguments are compatible with the configs and state of the model"""
adapter_names = kwargs.get("adapter_names", None)
if adapter_names is None:
return
if len(x) != len(adapter_names):
msg = (
"Length of `adapter_names` should be the same as the number of inputs, but got "
f"{len(adapter_names)} and {len(x)} respectively."
)
raise ValueError(msg)
if self.merged:
# It is unclear what would be the right thing to do if users pass adapter_names and there are merged
# adapters. Therefore, it is better to raise an error in this case.
msg = "Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first."
raise ValueError(msg)
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
self._check_forward_args(x, *args, **kwargs)
adapter_names = kwargs.pop("adapter_names", None)
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
elif adapter_names is not None:
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
torch_result_dtype = result.dtype
for active_adapter in self.active_adapters:
if active_adapter not in self._available_adapters:
continue
result = self._cast_input_dtype(result, self.road_theta[active_adapter].dtype)
result = _apply_road(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter],
self.road_alpha[active_adapter],
result,
)
result = result.to(torch_result_dtype)
return result
def _mixed_batch_forward(
self, x: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
) -> torch.Tensor:
# This is a special method that handles the case when users pass the argument `adapter_names`. This is an
# extra argument that allows mixing different adapters in the same batch at inference time.
result = self.base_layer(x, *args, **kwargs)
unique_adapters = set(adapter_names)
sub_batch_indices_list = []
for adapter in unique_adapters:
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
for i, active_adapter in enumerate(unique_adapters):
if active_adapter == "__base__":
continue
if active_adapter not in self._available_adapters:
continue
dtype = self.road_theta[active_adapter].data.dtype
# getting the sub-batch, passing it to Road layers and updating the corresponding indices of the linear
# layer output
sub_batch = result[sub_batch_indices_list[i]].to(dtype)
result[sub_batch_indices_list[i]] = _apply_road(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter],
self.road_alpha[active_adapter],
sub_batch,
)
return result
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights
Args:
safe_merge (`bool`, *optional*):
If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
before merging the weights. This is useful if you want to check if the merge operation will produce
NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
Defaults to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
# no adapter to merge
return
for active_adapter in adapter_names:
if active_adapter in self._available_adapters:
base_layer = self.get_base_layer()
orig_dtype = base_layer.weight.dtype
road_R = _get_delta_weight(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter].data,
self.road_alpha[active_adapter].data,
)
if safe_merge:
# Note that safe_merge will be slower than the normal merge
# because of the copy operation.
orig_weight = base_layer.weight.data.clone()
orig_weight = torch.matmul(road_R.to(orig_dtype), orig_weight)
if not torch.isfinite(orig_weight).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)
base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
if base_layer.bias is not None:
orig_bias = base_layer.bias.clone()
orig_bias = torch.matmul(road_R.to(orig_dtype), orig_bias)
if not torch.isfinite(orig_bias).all():
raise ValueError(
f"NaNs detected in the merged bias. The adapter {active_adapter} seems to be broken"
)
base_layer.bias.data = orig_bias.contiguous().to(orig_dtype)
else:
orig_weight = base_layer.weight.data
orig_weight = torch.matmul(road_R.to(orig_dtype), orig_weight)
base_layer.weight.data = orig_weight.contiguous().to(orig_dtype)
if base_layer.bias is not None:
orig_bias = base_layer.bias.data
orig_bias = torch.matmul(road_R.to(orig_dtype), orig_bias)
base_layer.bias.data = orig_bias.contiguous().to(orig_dtype)
self.merged_adapters.append(active_adapter)
def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
# Going in reverse order
active_adapter = self.merged_adapters.pop()
if active_adapter in self._available_adapters:
weight = self.get_base_layer().weight
orig_dtype = weight.dtype
road_R = _get_delta_weight(
self.variant[active_adapter],
self.group_size[active_adapter],
self.road_theta[active_adapter].data,
self.road_alpha[active_adapter].data,
)
# Since our matrices are not necessarily orthogonal, we need the inverse instead of the transpose.
# In practice we expect this to essentially always work, since we start from a block-diagonal rotation matrix.
inv_road_R = torch.linalg.inv(road_R.to(torch.float32)).to(orig_dtype)
orig_weight = torch.matmul(inv_road_R, weight.data)
weight.data = orig_weight.contiguous()
if self.get_base_layer().bias is not None:
orig_bias = torch.matmul(inv_road_R, self.get_base_layer().bias.data)
self.get_base_layer().bias.data = orig_bias.contiguous()
def __repr__(self) -> str:
rep = super().__repr__()
return "road." + rep
def _get_delta_weight(variant: RoadVariant, group_size: int, road_theta: torch.Tensor, road_alpha: torch.Tensor):
first_col, second_col = _prepare_cols(variant, group_size, road_theta, road_alpha)
# To help understand the logic below consider how rope embeddings work
# here it is similar, but done in groups.
# https://discuss.huggingface.co/t/is-llama-rotary-embedding-implementation-correct/44509/3
# First column is simply put on the main diagonal
output_tensor = torch.diag(first_col)
# For second column we need to swap each half groups and add minus sign
size = second_col.shape[0]
swapped_second_col = second_col.reshape(-1, 2, group_size // 2)[:, [1, 0], :].flatten()
rotated_diag_second_col = torch.diag(swapped_second_col).reshape(-1, 2, group_size // 2, size)[:, [1, 0], :, :]
rotated_diag_second_col[:, 0, :, :] *= -1
rotated_diag_second_col = rotated_diag_second_col.reshape(size, size)
output_tensor += rotated_diag_second_col
return output_tensor
def _prepare_cols(
variant: RoadVariant, group_size: int, road_theta: torch.Tensor, road_alpha: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# In inference mode, this can be cached
if variant == "road_1":
# In each group there are only group_size // 2 parameters that are reused
road_theta = road_theta.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
road_alpha = road_alpha.reshape(-1, group_size // 2).repeat_interleave(2, dim=0).flatten()
theta_cos = road_theta.cos()
theta_sin = road_theta.sin()
first_col = road_alpha * theta_cos
second_col = road_alpha * theta_sin
elif variant == "road_2":
# Each group has exactly group_size parameters
theta_cos = road_theta.cos()
theta_sin = road_theta.sin()
first_col = road_alpha * theta_cos
second_col = road_alpha * theta_sin
elif variant == "road_4":
# Each group has 2*group_size parameters, first half used for first column, second half for second column
road_theta = road_theta.reshape(-1, 2, group_size)
theta_cos = road_theta[:, 0, :].cos().flatten()
theta_sin = road_theta[:, 1, :].sin().flatten()
road_alpha = road_alpha.reshape(-1, 2, group_size)
alpha_1 = road_alpha[:, 0, :].flatten()
alpha_2 = road_alpha[:, 1, :].flatten()
first_col = alpha_1 * theta_cos
second_col = alpha_2 * theta_sin
else:
raise ValueError(
f"Unsupported variant {variant} for RoadLayer. Supported variants are road_1, road_2, and road_4."
)
return first_col, second_col
def _apply_road(
variant: RoadVariant, group_size: int, road_theta: torch.Tensor, road_alpha: torch.Tensor, x: torch.Tensor
):
first_col, second_col = _prepare_cols(variant, group_size, road_theta, road_alpha)
# Split in half groups and join back
# See equation 4 in the RoAD paper
x_grouped = x.reshape(-1, 2, group_size // 2)
x1 = x_grouped[:, 0, :]
x2 = x_grouped[:, 1, :]
rotate_half_x = torch.stack((-x2, x1), dim=1).reshape(x.shape)
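# For the first half of each group rotate_half_x holds -x2 and for the second half it holds x1, so the line
# below computes y1 = first_col * x1 - second_col * x2 and y2 = first_col * x2 + second_col * x1,
# i.e. the scaled 2D rotation of each pair (x1, x2).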
result = x * first_col + rotate_half_x * second_col
return result
def dispatch_default(
target: torch.nn.Module,
adapter_name: str,
road_config: RoadConfig,
**kwargs,
) -> Optional[torch.nn.Module]:
new_module = None
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if isinstance(target_base_layer, torch.nn.Linear):
new_module = Linear(target, adapter_name, **kwargs)
return new_module

View File

@ -0,0 +1,342 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import operator
from contextlib import contextmanager
from functools import partial
from typing import Optional
import torch
from torch import nn
from tqdm import tqdm
from peft.import_utils import is_bnb_4bit_available, is_bnb_available
from peft.tuners.road.config import RoadConfig
from peft.tuners.tuners_utils import (
BaseTuner,
BaseTunerLayer,
check_target_module_exists,
onload_layer,
)
from peft.utils import (
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
ModulesToSaveWrapper,
_get_submodules,
)
from .layer import RoadLayer, dispatch_default
def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
# pre-forward hook to inject the adapter_names argument when using mixed adapter batches inference
kwargs["adapter_names"] = adapter_names
return args, kwargs
class RoadModel(BaseTuner):
""" """
prefix: str = "road_"
@staticmethod
def _prepare_adapter_config(road_config: RoadConfig, model_config: dict) -> RoadConfig:
if road_config.target_modules is None:
if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
raise ValueError("Please specify `target_modules` in `peft_config`")
road_config.target_modules = set(
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
)
return road_config
@staticmethod
def _check_target_module_exists(road_config, key):
return check_target_module_exists(road_config, key)
def _create_and_replace(
self,
road_config: RoadConfig,
adapter_name: str,
target: nn.Module,
target_name: str,
parent: nn.Module,
current_key,
) -> None:
if current_key is None:
raise ValueError("Current Key shouldn't be `None`")
# Regexp matching - Find key which matches current target_name in patterns provided
variant = road_config.variant
group_size = road_config.group_size
kwargs = {
"variant": variant,
"group_size": group_size,
"init_weights": road_config.init_weights,
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
}
# for torchao merging, we need the get_apply_tensor_subclass from the quantization config
try:
kwargs["get_apply_tensor_subclass"] = operator.attrgetter(
"hf_quantizer.quantization_config.get_apply_tensor_subclass"
)(self.model)
except AttributeError:
pass
if isinstance(target, RoadLayer):
target.update_layer(
adapter_name,
variant,
group_size,
init_weights=road_config.init_weights,
)
else:
device_map = self.model.hf_device_map if hasattr(self.model, "hf_device_map") else None
new_module = self._create_new_module(road_config, adapter_name, target, device_map=device_map, **kwargs)
if adapter_name not in self.active_adapters:
# adding an additional adapter: it is not automatically trainable
new_module.requires_grad_(False)
self._replace_module(parent, target_name, new_module, target)
def _replace_module(self, parent, child_name, new_module, child):
setattr(parent, child_name, new_module)
# It's not necessary to set requires_grad here, as that is handled by
# _mark_only_adapters_as_trainable
# child layer wraps the original module, unpack it
if hasattr(child, "base_layer"):
child = child.base_layer
meta = torch.device("meta")
# dispatch to correct device
for name, module in new_module.named_modules():
if (self.prefix in name) or ("ranknum" in name):
if hasattr(child, "qweight"):
weight = child.qweight
elif hasattr(child, "W_q"):
weight = child.W_q
elif hasattr(child, "weight"):
weight = child.weight
elif getattr(child, "in_proj_weight", None) is not None: # MHA
weight = child.in_proj_weight
else:
weight = next(child.parameters())
if not any(p.device == meta for p in module.parameters()):
module.to(weight.device)
@staticmethod
def _create_new_module(road_config: RoadConfig, adapter_name, target, **kwargs):
dispatchers = []
# avoid eager bnb import
if is_bnb_available():
from .bnb import dispatch_bnb_8bit
dispatchers.append(dispatch_bnb_8bit)
if is_bnb_4bit_available():
from .bnb import dispatch_bnb_4bit
dispatchers.append(dispatch_bnb_4bit)
dispatchers.extend(
[
dispatch_default,
]
)
new_module = None
for dispatcher in dispatchers:
new_module = dispatcher(target, adapter_name, road_config=road_config, **kwargs)
if new_module is not None: # first match wins
break
if new_module is None:
# no module could be matched
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`."
)
return new_module
def _mark_only_adapters_as_trainable(self, model: nn.Module):
for n, p in model.named_parameters():
if self.prefix not in n:
p.requires_grad = False
def _set_adapter_layers(self, enabled: bool = True) -> None:
for module in self.model.modules():
if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
module.enable_adapters(enabled)
def disable_adapter_layers(self) -> None:
self._set_adapter_layers(enabled=False)
def enable_adapter_layers(self) -> None:
self._set_adapter_layers(enabled=True)
def set_adapter(self, adapter_name: str | list[str]) -> None:
"""Set the active adapter(s).
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
not desired, use the following code.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
"""
for module in self.model.modules():
if isinstance(module, RoadLayer):
module.set_adapter(adapter_name)
self.active_adapter = adapter_name
def __getattr__(self, name: str):
"""Forward missing attributes to the wrapped module."""
try:
return super().__getattr__(name) # defer to nn.Module's logic
except AttributeError:
if name == "model": # see #1892: prevent infinite recursion if class is not initialized
raise
return getattr(self.model, name)
@contextmanager
def _enable_peft_forward_hooks(self, *args, **kwargs):
# If adapter_names is passed as an argument, we inject it into the forward arguments.
adapter_names = kwargs.pop("adapter_names", None)
if adapter_names is None:
# nothing to do
yield
return
if self.training:
raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
# Check that users only passed actually existing adapters.
# Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
# to check that there is at least one layer with the given name, or else something like typos can easily slip.
expected_adapters = set()
for layer in self.modules():
if isinstance(layer, RoadLayer):
expected_adapters |= layer.road_theta.keys()
unique_adapters = {name for name in adapter_names if name != "__base__"}
unexpected_adapters = unique_adapters - expected_adapters
if unexpected_adapters:
raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")
hook_handles = []
for module in self.modules():
if isinstance(module, RoadLayer):
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
hook_handles.append(handle)
# TODO LoRA also has hooks for beam search, ignore this for now
yield
for handle in hook_handles:
handle.remove()
def _unload_and_optionally_merge(
self,
merge=True,
progressbar: bool = False,
safe_merge: bool = False,
adapter_names: Optional[list[str]] = None,
):
if merge:
self._check_merge_allowed()
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
desc = "Unloading " + ("and merging " if merge else "") + "model"
for key in tqdm(key_list, disable=not progressbar, desc=desc):
try:
parent, target, target_name = _get_submodules(self.model, key)
except AttributeError:
continue
with onload_layer(target):
if hasattr(target, "base_layer"):
if merge:
target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
new_module = target.modules_to_save[target.active_adapter]
if hasattr(new_module, "base_layer"):
# check if the module is itself a tuner layer
if merge:
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
new_module = new_module.get_base_layer()
setattr(parent, target_name, new_module)
return self.model
def delete_adapter(self, adapter_name: str) -> None:
"""
Deletes an existing adapter.
Args:
adapter_name (str): Name of the adapter to be deleted.
"""
if adapter_name not in list(self.peft_config.keys()):
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]
key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, RoadLayer):
target.delete_adapter(adapter_name)
if new_adapter is None:
new_adapter = target.active_adapters[:]
self.active_adapter = new_adapter or []
self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)
def merge_and_unload(
self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
) -> torch.nn.Module:
r"""
This method merges the RoAd layers into the base model. This is needed if someone wants to use the base model
as a standalone model.
Args:
progressbar (`bool`):
whether to show a progressbar indicating the unload and merge process
safe_merge (`bool`):
whether to activate the safe merging check to check if there is any potential Nan in the adapter
weights
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
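Example (a minimal usage sketch; the model id and target modules below are only illustrative):

```py
>>> from transformers import AutoModelForCausalLM
>>> from peft import RoadConfig, get_peft_model

>>> base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
>>> model = get_peft_model(base_model, RoadConfig(target_modules=["q_proj", "v_proj"]))
>>> # ... train the adapter ...
>>> merged_model = model.merge_and_unload()
```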
"""
return self._unload_and_optionally_merge(
progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
)
def unload(self) -> torch.nn.Module:
"""
Gets back the base model by removing all the RoAd modules without merging. This gives back the original base
model.
"""
return self._unload_and_optionally_merge(merge=False)

View File

@ -44,6 +44,7 @@ class PeftType(str, enum.Enum):
- RANDLORA
- SHIRA
- C3A
- ROAD
"""
PROMPT_TUNING = "PROMPT_TUNING"
@ -69,6 +70,7 @@ class PeftType(str, enum.Enum):
BONE = "BONE"
MISS = "MISS"
RANDLORA = "RANDLORA"
ROAD = "ROAD"
TRAINABLE_TOKENS = "TRAINABLE_TOKENS"
SHIRA = "SHIRA"
C3A = "C3A"

View File

@ -48,6 +48,7 @@ from peft import (
OFTConfig,
PeftModel,
RandLoraConfig,
RoadConfig,
TaskType,
VBLoRAConfig,
VeraConfig,
@ -74,12 +75,14 @@ if is_bnb_available():
from peft.tuners.ia3 import Linear8bitLt as IA3Linear8bitLt
from peft.tuners.lora import Linear8bitLt as LoraLinear8bitLt
from peft.tuners.randlora import Linear8bitLt as RandLoraLinear8bitLt
from peft.tuners.road import Linear8bitLt as RoadLinear8bitLt
from peft.tuners.vera import Linear8bitLt as VeraLinear8bitLt
if is_bnb_4bit_available():
from peft.tuners.ia3 import Linear4bit as IA3Linear4bit
from peft.tuners.lora import Linear4bit as LoraLinear4bit
from peft.tuners.randlora import Linear4bit as RandLoraLinear4bit
from peft.tuners.road import Linear4bit as RoadLinear4bit
from peft.tuners.vera import Linear4bit as VeraLinear4bit
@ -292,6 +295,49 @@ class PeftGPUCommonTests(unittest.TestCase):
whisper_8bit = get_peft_model(whisper_8bit, config)
assert isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, IA3Linear8bitLt)
@require_bitsandbytes
@pytest.mark.multi_gpu_tests
@pytest.mark.single_gpu_tests
def test_road_bnb_8bit_quantization(self):
r"""
Test that the 8-bit quantization using RoAd works as expected
"""
whisper_8bit = WhisperForConditionalGeneration.from_pretrained(
self.audio_model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
opt_8bit = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
flan_8bit = AutoModelForSeq2SeqLM.from_pretrained(
self.seq2seq_model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
flan_road_config = RoadConfig(target_modules=["q", "v"], task_type="SEQ_2_SEQ_LM")
opt_road_config = RoadConfig(
target_modules=["q_proj", "v_proj", "fc2"],
task_type="CAUSAL_LM",
)
config = RoadConfig(target_modules=["q_proj", "v_proj", "fc2"])
flan_8bit = get_peft_model(flan_8bit, flan_road_config)
assert isinstance(flan_8bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, RoadLinear8bitLt)
opt_8bit = get_peft_model(opt_8bit, opt_road_config)
assert isinstance(opt_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RoadLinear8bitLt)
whisper_8bit = get_peft_model(whisper_8bit, config)
assert isinstance(whisper_8bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RoadLinear8bitLt)
@require_bitsandbytes
@pytest.mark.multi_gpu_tests
@pytest.mark.single_gpu_tests
@ -697,6 +743,49 @@ class PeftGPUCommonTests(unittest.TestCase):
whisper_4bit = get_peft_model(whisper_4bit, config)
assert isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, IA3Linear4bit)
@require_bitsandbytes
@pytest.mark.multi_gpu_tests
@pytest.mark.single_gpu_tests
def test_road_bnb_4bit_quantization(self):
r"""
Test that the 4-bit quantization using RoAd works as expected
"""
whisper_4bit = WhisperForConditionalGeneration.from_pretrained(
self.audio_model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
opt_4bit = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
flan_4bit = AutoModelForSeq2SeqLM.from_pretrained(
self.seq2seq_model_id,
device_map="auto",
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
flan_road_config = RoadConfig(target_modules=["q", "v"], task_type="SEQ_2_SEQ_LM")
opt_road_config = RoadConfig(
target_modules=["q_proj", "v_proj", "fc2"],
task_type="CAUSAL_LM",
)
config = RoadConfig(target_modules=["q_proj", "v_proj", "fc2"])
flan_4bit = get_peft_model(flan_4bit, flan_road_config)
assert isinstance(flan_4bit.base_model.model.encoder.block[0].layer[0].SelfAttention.q, RoadLinear4bit)
opt_4bit = get_peft_model(opt_4bit, opt_road_config)
assert isinstance(opt_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RoadLinear4bit)
whisper_4bit = get_peft_model(whisper_4bit, config)
assert isinstance(whisper_4bit.base_model.model.model.decoder.layers[0].self_attn.v_proj, RoadLinear4bit)
@pytest.mark.multi_gpu_tests
@require_torch_multi_accelerator
def test_lora_causal_lm_multi_gpu_inference(self):
@ -1520,6 +1609,98 @@ class PeftGPUCommonTests(unittest.TestCase):
layer.lora_A, layer.lora_B = la, lb
layer.lora_variant[adapter_name].init(layer, adapter_name=adapter_name) # should not raise an error
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_8bit_road_merging(self):
# Check results for merging, unmerging, unloading
torch.manual_seed(0)
model = AutoModelForCausalLM.from_pretrained(
"facebook/opt-125m",
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
torch_dtype=torch.float32,
).eval()
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
# compare outputs in probability space, because logits can have outliers
# and token ids are not precise enough
out_base = F.softmax(model(random_input).logits, dim=-1)
config = RoadConfig(
init_weights=False,
)
model = get_peft_model(model, config).eval()
with torch.inference_mode():
out_road = F.softmax(model(random_input).logits, dim=-1)
model.merge_adapter()
out_merged = F.softmax(model(random_input).logits, dim=-1)
model.unmerge_adapter()
out_unmerged = F.softmax(model(random_input).logits, dim=-1)
model = model.merge_and_unload()
out_unloaded = F.softmax(model(random_input).logits, dim=-1)
atol = 1e-3
rtol = 1
# sanity check that using RoAd changes the results
assert not torch.allclose(out_base, out_road, atol=atol, rtol=rtol)
assert torch.allclose(out_road, out_merged, atol=atol, rtol=rtol)
assert torch.allclose(out_road, out_unmerged, atol=atol, rtol=rtol)
assert torch.allclose(out_road, out_unloaded, atol=atol, rtol=rtol)
@require_non_cpu
@pytest.mark.single_gpu_tests
@require_bitsandbytes
def test_4bit_road_merging(self):
# Check results for merging, unmerging, unloading
torch.manual_seed(0)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
bnb_4bit_compute_dtype=torch.float32,
)
model = AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-random-LlamaForCausalLM",
quantization_config=bnb_config,
torch_dtype=torch.float32,
).eval()
random_input = torch.LongTensor([[1, 0, 1, 0, 1, 0]]).to(model.device)
# compare outputs in probability space, because logits can have outliers
# and token ids are not precise enough
out_base = model(random_input).logits
probs_base = F.softmax(out_base, dim=-1)
config = RoadConfig(
init_weights=False,
group_size=4,
)
model = get_peft_model(model, config).eval()
with torch.inference_mode():
out_road = model(random_input).logits
probs_road = F.softmax(out_road, dim=-1)
model.merge_adapter()
probs_merged = F.softmax(model(random_input).logits, dim=-1)
model.unmerge_adapter()
probs_unmerged = F.softmax(model(random_input).logits, dim=-1)
model = model.merge_and_unload()
probs_unloaded = F.softmax(model(random_input).logits, dim=-1)
atol = 1e-5
rtol = 1e-3
# sanity check that using RoAd changes the results
# here we compare the raw logits rather than the probabilities, since the latter may not be sensitive enough
assert not torch.allclose(out_base, out_road, atol=atol, rtol=rtol)
assert torch.allclose(probs_road, probs_merged, atol=atol, rtol=rtol)
assert torch.allclose(probs_road, probs_unmerged, atol=atol, rtol=rtol)
assert torch.allclose(probs_road, probs_unloaded, atol=atol, rtol=rtol)
def test_apply_GS_hra_inference(self):
# check for different result with and without apply_GS
model = AutoModelForCausalLM.from_pretrained(
@ -1984,3 +2165,21 @@ class TestSameAdapterDifferentDevices:
# the rest should be on GPU
assert model.lin0.base_layer.weight.device.type == self.device
assert model.lin0.hra_u.other.device.type == self.device
def test_road_add_new_adapter_does_not_change_device(self, mlp):
# same as first test, but using RoAd
config = RoadConfig(target_modules=["lin0"], group_size=2)
model = get_peft_model(mlp, config)
model = model.to(self.device)
model.lin0.road_theta.cpu()
# check that the adapter is indeed on CPU and the base model on GPU
assert model.lin0.road_theta.default.device.type == "cpu"
assert model.lin0.base_layer.weight.device.type == self.device
model.add_adapter("other", config)
# check that after adding a new adapter, the old adapter is still on CPU
assert model.lin0.road_theta.default.device.type == "cpu"
# the rest should be on GPU
assert model.lin0.base_layer.weight.device.type == self.device
assert model.lin0.road_theta.other.device.type == self.device

View File

@ -40,6 +40,7 @@ from peft import (
PromptEncoder,
PromptEncoderConfig,
PromptTuningConfig,
RoadConfig,
TaskType,
VBLoRAConfig,
VeraConfig,
@ -65,6 +66,7 @@ ALL_CONFIG_CLASSES = (
(PrefixTuningConfig, {}),
(PromptEncoderConfig, {}),
(PromptTuningConfig, {}),
(RoadConfig, {}),
(VeraConfig, {}),
(VBLoRAConfig, {}),
)

View File

@ -48,6 +48,7 @@ from peft import (
PeftModel,
PeftWarning,
RandLoraConfig,
RoadConfig,
ShiraConfig,
TaskType,
TrainableTokensConfig,
@ -720,6 +721,15 @@ TEST_CASES = [
"init_weights": True,
},
),
########
# RoAd #
########
("Vanilla MLP 1 RoAd", "MLP", RoadConfig, {"target_modules": "lin0", "group_size": 2}),
("Vanilla MLP 2 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "group_size": 2}),
("Vanilla MLP 3 RoAd", "MLP", RoadConfig, {"target_modules": ["lin1"], "group_size": 2}),
("Vanilla MLP 4 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0", "lin1"], "group_size": 2}),
("Vanilla MLP 5 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "variant": "road_2", "group_size": 2}),
("Vanilla MLP 6 RoAd", "MLP", RoadConfig, {"target_modules": ["lin0"], "variant": "road_4", "group_size": 2}),
]
# For this test matrix, each tuple consists of:
@ -933,6 +943,34 @@ MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES = [
{"target_modules": ["lin0"], "init_weights": False, "boft_block_size": 2},
{"target_modules": ["lin1"], "init_weights": False, "boft_block_size": 2},
),
(
"RoAd Same",
"road",
RoadConfig,
{"target_modules": ["lin0"], "init_weights": False, "group_size": 2},
{"target_modules": ["lin0"], "init_weights": False, "group_size": 2},
),
(
"RoAd Different",
"road",
RoadConfig,
{"target_modules": ["lin0"], "init_weights": False, "group_size": 2},
{"target_modules": ["lin1"], "init_weights": False, "group_size": 2},
),
(
"RoAd 2 Different",
"road",
RoadConfig,
{"target_modules": ["lin0"], "init_weights": False, "variant": "road_1", "group_size": 2},
{"target_modules": ["lin1"], "init_weights": False, "variant": "road_2", "group_size": 2},
),
(
"RoAd 4 Different",
"road",
RoadConfig,
{"target_modules": ["lin0"], "init_weights": False, "variant": "road_1", "group_size": 2},
{"target_modules": ["lin1"], "init_weights": False, "variant": "road_4", "group_size": 2},
),
]
PREFIXES = {
@ -951,6 +989,7 @@ PREFIXES = {
ShiraConfig: "shira_",
VBLoRAConfig: "vblora_",
BoneConfig: "bone_",
RoadConfig: "road_",
MissConfig: "miss_",
TrainableTokensConfig: "trainable_tokens_",
}
@ -4665,17 +4704,29 @@ class TestRequiresGrad:
)
# this is for PEFT methods that support mixed adapter batches.
MIXED_ADAPTER_TEST_CASES = [
(
"LoRA mixed adapter",
LoraConfig(target_modules=["lin0"], init_lora_weights=False),
LoraConfig(target_modules=["lin0"], r=16, init_lora_weights=False),
),
(
"RoAd mixed adapter",
RoadConfig(target_modules=["lin0"], group_size=2, init_weights=False),
RoadConfig(target_modules=["lin0"], group_size=2, variant="road_2", init_weights=False),
),
]
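# Note that the two configs in each entry deliberately differ (a different LoRA rank or RoAd variant) so that the
# two adapters produce distinguishable outputs within a mixed batch.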
class TestMixedAdapterBatches:
torch_device = infer_device()
@pytest.fixture
def mlp_lora(self):
def get_mlp_peft(self, config0, config1):
"""A simple MLP with 2 LoRA adapters"""
torch.manual_seed(0)
base_model = MLP().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin0"], r=16, init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
return peft_model
@ -4714,32 +4765,68 @@ class TestMixedAdapterBatches:
assert torch.allclose(output0[1::3], output_mixed[1::3])
assert torch.allclose(output1[2::3], output_mixed[2::3])
def test_mixed_adapter_batches_lora_mlp(self, mlp_lora):
@pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES)
def test_mixed_adapter_batches_mlp(self, test_name, config0, config1):
mlp_peft = self.get_mlp_peft(config0, config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(mlp_lora, inputs)
self.run_checks(mlp_peft, inputs)
def test_mixed_adapter_batches_lora_different_target_layers(self, mlp_lora):
@pytest.mark.parametrize(
"test_name, config0, config1",
[
(
"LoRA mixed adapter with different target layers",
LoraConfig(target_modules=["lin0"], init_lora_weights=False),
LoraConfig(target_modules=["lin1"], init_lora_weights=False),
),
(
"RoAd mixed adapter with different target layers",
RoadConfig(target_modules=["lin0"], group_size=2, init_weights=False),
RoadConfig(target_modules=["lin1"], group_size=2, init_weights=False),
),
],
)
def test_mixed_adapter_batches_different_target_layers(self, test_name, config0, config1):
base_model = MLP().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin1"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)
def test_mixed_adapter_batches_lora_multiple_modules_to_save(self, mlp_lora):
@pytest.mark.parametrize(
"test_name, config0, config1",
[
(
"LoRA mixed adapter with modules to save",
LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False),
LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False),
),
(
"RoAd mixed adapter with modules to save",
RoadConfig(target_modules=["lin0"], modules_to_save=["lin1"], group_size=2, init_weights=False),
RoadConfig(target_modules=["lin0"], modules_to_save=["lin1"], group_size=2, init_weights=False),
),
],
)
def test_mixed_adapter_batches_multiple_modules_to_save(self, test_name, config0, config1):
base_model = MLP().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)
def test_mixed_adapter_batches_lora_unsupported_layer_raises(self, mlp_lora):
@pytest.mark.parametrize(
"test_name, config0, config1",
[
(
"LoRA mixed adapter with unsupported layer",
LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False),
LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False),
),
],
)
def test_mixed_adapter_batches_unsupported_layer_raises(self, test_name, config0, config1):
base_model = MLPWithGRU().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin0"], modules_to_save=["gru"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
@ -4750,50 +4837,95 @@ class TestMixedAdapterBatches:
):
self.run_checks(peft_model, inputs)
def test_mixed_adapter_batches_lora_partly_overlapping_target_layers(self, mlp_lora):
@pytest.mark.parametrize(
"test_name, config0, config1",
[
(
"LoRA mixed adapter with overlapping layers",
LoraConfig(target_modules=["lin0"], init_lora_weights=False),
LoraConfig(target_modules=["lin0", "lin1"], init_lora_weights=False),
),
(
"RoAd mixed adapter with overlapping layers",
RoadConfig(target_modules=["lin0"], group_size=2, init_weights=False),
RoadConfig(target_modules=["lin0", "lin1"], group_size=2, init_weights=False),
),
],
)
def test_mixed_adapter_batches_partly_overlapping_target_layers(self, test_name, config0, config1):
base_model = MLP().to(self.torch_device).eval()
# the two adapters target partly overlapping sets of layers
config0 = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["lin0", "lin1"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)
def test_mixed_adapter_batches_lora_conv1d_emb(self):
@pytest.mark.parametrize(
"test_name, config0, config1",
[
(
"LoRA mixed adapter with conv1d",
LoraConfig(target_modules=["emb", "conv1d"], init_lora_weights=False),
LoraConfig(target_modules=["emb", "conv1d"], r=16, init_lora_weights=False),
),
],
)
def test_mixed_adapter_batches_lora_conv1d_emb(self, test_name, config0, config1):
base_model = ModelEmbConv1D().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["emb", "conv1d"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["emb", "conv1d"], r=16, init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)
def test_mixed_adapter_batches_lora_conv1d_emb_multiple_modules_to_save(self):
@pytest.mark.parametrize(
"test_name, config0, config1",
[
(
"LoRA mixed adapter with conv1d and emb and modules to save",
LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False),
LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False),
),
],
)
def test_mixed_adapter_batches_lora_conv1d_emb_multiple_modules_to_save(self, test_name, config0, config1):
base_model = ModelEmbConv1D().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["emb", "conv1d"], modules_to_save=["lin0"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
self.run_checks(peft_model, inputs)
def test_mixed_adapter_batches_lora_conv2d(self):
@pytest.mark.parametrize(
"test_name, config0, config1",
[
(
"LoRA mixed adapter with conv2d",
LoraConfig(target_modules=["conv2d"], init_lora_weights=False),
LoraConfig(target_modules=["conv2d"], r=16, init_lora_weights=False),
),
],
)
def test_mixed_adapter_batches_lora_conv2d(self, test_name, config0, config1):
base_model = ModelConv2D().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["conv2d"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["conv2d"], r=16, init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
inputs = {"X": torch.arange(270).view(6, 5, 3, 3).to(self.torch_device)}
self.run_checks(peft_model, inputs)
def test_mixed_adapter_batches_mha_raises(self):
@pytest.mark.parametrize(
"test_name, config0, config1",
[
(
"LoRA mixed adapter with mha",
LoraConfig(target_modules=["mha"], init_lora_weights=False),
LoraConfig(target_modules=["mha"], r=16, init_lora_weights=False),
),
],
)
def test_mixed_adapter_batches_mha_raises(self, test_name, config0, config1):
base_model = ModelMha().to(self.torch_device).eval()
config0 = LoraConfig(target_modules=["mha"], init_lora_weights=False)
config1 = LoraConfig(target_modules=["mha"], r=16, init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter0").eval()
peft_model.add_adapter("adapter1", config1)
@ -4802,56 +4934,76 @@ class TestMixedAdapterBatches:
with pytest.raises(TypeError, match=msg):
self.run_checks(peft_model, inputs)
def test_mixed_adapter_batches_lora_length_mismatch_raises(self, mlp_lora):
@pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES)
def test_mixed_adapter_batches_length_mismatch_raises(self, test_name, config0, config1):
mlp_peft = self.get_mlp_peft(config0, config1)
inputs = {
"X": torch.arange(90).view(-1, 10).to(self.torch_device),
"adapter_names": ["__base__"] * 5, # wrong length!
}
msg = r"Length of `adapter_names` should be the same as the number of inputs, but got "
with pytest.raises(ValueError, match=msg):
mlp_lora.forward(**inputs)
mlp_peft.forward(**inputs)
def test_mixed_adapter_batches_lora_training_mode_raises(self, mlp_lora):
@pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES)
def test_mixed_adapter_batches_training_mode_raises(self, test_name, config0, config1):
mlp_peft = self.get_mlp_peft(config0, config1)
inputs = {
"X": torch.arange(90).view(-1, 10).to(self.torch_device),
"adapter_names": ["__base__"] * 9,
}
mlp_lora = mlp_lora.train()
mlp_peft = mlp_peft.train()
msg = r"Cannot pass `adapter_names` when the model is in training mode."
with pytest.raises(ValueError, match=msg):
mlp_lora.forward(**inputs)
mlp_peft.forward(**inputs)
def test_mixed_adapter_batches_lora_disabled(self, mlp_lora):
@pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES)
def test_mixed_adapter_batches_disabled(self, test_name, config0, config1):
# Disabling adapters should have precedence over passing adapter names
mlp_peft = self.get_mlp_peft(config0, config1)
inputs = {"X": torch.arange(90).view(-1, 10).to(self.torch_device)}
with mlp_lora.disable_adapter():
output_disabled = mlp_lora(**inputs)
with mlp_peft.disable_adapter():
output_disabled = mlp_peft(**inputs)
adapters = ["__base__", "adapter0", "adapter1"]
inputs["adapter_names"] = [adapters[i % 3] for i in (range(len(inputs["X"])))]
with mlp_lora.disable_adapter():
output_mixed = mlp_lora.forward(**inputs)
with mlp_peft.disable_adapter():
output_mixed = mlp_peft.forward(**inputs)
assert torch.allclose(output_disabled, output_mixed)
def test_mixed_adapter_batches_lora_merged_raises(self, mlp_lora):
@pytest.mark.parametrize("test_name, config0, config1", MIXED_ADAPTER_TEST_CASES)
def test_mixed_adapter_batches_merged_raises(self, test_name, config0, config1):
# When there are merged adapters, passing adapter names should raise an error
mlp_peft = self.get_mlp_peft(config0, config1)
inputs = {
"X": torch.arange(90).view(-1, 10).to(self.torch_device),
"adapter_names": ["adapter0"] * 9,
}
mlp_lora.merge_adapter(["adapter0"])
mlp_peft.merge_adapter(["adapter0"])
msg = r"Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first."
with pytest.raises(ValueError, match=msg):
mlp_lora.forward(**inputs)
mlp_peft.forward(**inputs)
def test_mixed_adapter_batches_lora_wrong_adapter_name_raises(self):
@pytest.mark.parametrize(
"test_name, config",
[
(
"LoRA mixed batch wrong adapter name",
LoraConfig(target_modules=["lin0"], init_lora_weights=False),
),
(
"RoAD mixed batch wrong adapter name",
RoadConfig(target_modules=["lin0"], group_size=2, init_weights=False),
),
],
)
def test_mixed_adapter_batches_lora_wrong_adapter_name_raises(self, test_name, config):
# Ensure that all of the adapter names that are being passed actually exist
torch.manual_seed(0)
x = torch.arange(90).view(-1, 10).to(self.torch_device)
base_model = MLP().to(self.torch_device).eval()
config = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
peft_model = get_peft_model(base_model, config).eval()
peft_model.add_adapter(adapter_name="other", peft_config=config)
@ -4906,8 +5058,25 @@ class TestMixedAdapterBatches:
}
peft_model.forward(**inputs)
@pytest.mark.parametrize(
"test_name, config0, config1, factor",
[
(
"LoRA mixed adapter timing",
LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False),
LoraConfig(task_type="CAUSAL_LM", r=16, init_lora_weights=False),
2.0,
),
(
"RoAd mixed adapter timing",
RoadConfig(task_type="CAUSAL_LM", init_weights=False),
RoadConfig(task_type="CAUSAL_LM", variant="road_2", init_weights=False),
3.0,
),
],
)
@require_non_cpu
def test_mixed_adapter_batches_lora_opt_timing(self):
def test_mixed_adapter_batches_lora_opt_timing(self, test_name, config0, config1, factor):
# Use a more realistic model (opt-125m) and do a simple runtime check to ensure that mixed adapter batches
# don't add too much overhead. These types of tests are inherently flaky, so we try to add in some robustness.
logs = [] # store the time it takes to run each forward pass here
@ -4924,7 +5093,6 @@ class TestMixedAdapterBatches:
with timed():
output_base = base_model(**inputs).logits
config0 = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False)
peft_model = get_peft_model(base_model, config0, "adapter1").eval()
with timed():
output0 = peft_model(**inputs).logits
@ -4932,7 +5100,6 @@ class TestMixedAdapterBatches:
# sanity check, outputs are not the same
assert not torch.allclose(output_base, output0)
config1 = LoraConfig(task_type="CAUSAL_LM", r=16, init_lora_weights=False)
peft_model.add_adapter("adapter2", config1)
peft_model.set_adapter("adapter2")
with timed():
@ -4964,7 +5131,6 @@ class TestMixedAdapterBatches:
time_non_mixed = (time_base + time0 + time1) / 3
time_mixed = min(time_mixed)
factor = 2.0
assert time_mixed < factor * time_non_mixed
# Measure timing of running base and adapter separately vs using a mixed batch. Note that on CPU, the

View File

@ -42,6 +42,7 @@ from peft import (
PromptEncoderConfig,
PromptTuningConfig,
PromptTuningInit,
RoadConfig,
ShiraConfig,
VBLoRAConfig,
VeraConfig,
@ -193,6 +194,14 @@ ALL_CONFIGS = [
"num_virtual_tokens": 10,
},
),
(
RoadConfig,
{
"task_type": "CAUSAL_LM",
"variant": "road_1",
"group_size": 2,
},
),
(
ShiraConfig,
{
@ -242,11 +251,12 @@ def _skip_if_not_conv1d_supported(model_id, config_cls):
BoneConfig,
HRAConfig,
OFTConfig,
RoadConfig,
ShiraConfig,
C3AConfig,
MissConfig,
]:
pytest.skip("Skipping BOFT/HRA/OFT/Bone/SHiRA/C3A/MiSS for GPT2LMHeadModel")
pytest.skip("Skipping BOFT/HRA/OFT/Bone/Road/SHiRA/C3A/MiSS for GPT2LMHeadModel")
def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls):
@ -257,6 +267,7 @@ def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls):
OFTConfig,
BoneConfig,
C3AConfig,
RoadConfig,
MissConfig,
]:
pytest.skip("Skipping AdaLora/BOFT/HRA/OFT/Bone/MiSS for GPT2LMHeadModel")

View File

@ -31,6 +31,7 @@ from peft import (
PrefixTuningConfig,
PromptEncoderConfig,
PromptTuningConfig,
RoadConfig,
ShiraConfig,
TaskType,
VBLoRAConfig,
@ -155,6 +156,14 @@ ALL_CONFIGS = [
"task_type": "SEQ_2_SEQ_LM",
},
),
(
RoadConfig,
{
"task_type": "SEQ_2_SEQ_LM",
"variant": "road_1",
"group_size": 2,
},
),
(
ShiraConfig,
{

View File

@ -30,6 +30,7 @@ from peft import (
PromptEncoderConfig,
PromptLearningConfig,
PromptTuningConfig,
RoadConfig,
ShiraConfig,
VBLoRAConfig,
VeraConfig,
@ -155,6 +156,14 @@ ALL_CONFIGS = [
"num_virtual_tokens": 10,
},
),
(
RoadConfig,
{
"task_type": "FEATURE_EXTRACTION",
"variant": "road_1",
"group_size": 2,
},
),
(
ShiraConfig,
{

View File

@ -64,6 +64,7 @@ from peft import (
PrefixTuningConfig,
PromptEncoderConfig,
RandLoraConfig,
RoadConfig,
TaskType,
VeraConfig,
get_peft_model,
@ -1717,6 +1718,226 @@ class PeftBnbGPUExampleTests(unittest.TestCase):
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None
@pytest.mark.single_gpu_tests
def test_causal_lm_training_8bit_road(self):
r"""
Same as test_causal_lm_training but with RoAd
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
model = prepare_model_for_kbit_training(model)
config = RoadConfig(
variant="road_1",
target_modules=["q_proj", "v_proj"],
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=1e-3,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()
model.cpu().save_pretrained(tmp_dir)
assert "adapter_config.json" in os.listdir(tmp_dir)
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None
@pytest.mark.single_gpu_tests
def test_causal_lm_training_4bit_road(self):
r"""
Same as test_causal_lm_training_4bit but with RoAd
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
model = prepare_model_for_kbit_training(model)
config = RoadConfig(
variant="road_1",
target_modules=["q_proj", "v_proj"],
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=1e-3,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()
model.cpu().save_pretrained(tmp_dir)
assert "adapter_config.json" in os.listdir(tmp_dir)
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None
@pytest.mark.multi_gpu_tests
def test_causal_lm_training_multi_gpu_8bit_road(self):
r"""
Same as test_causal_lm_training_multi_gpu but with RoAd
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
assert set(model.hf_device_map.values()) == set(range(device_count))
assert {p.device.index for p in model.parameters()} == set(range(device_count))
model = prepare_model_for_kbit_training(model)
setattr(model, "model_parallel", True)
setattr(model, "is_parallelizable", True)
config = RoadConfig(
variant="road_1",
target_modules=["q_proj", "v_proj"],
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=1e-3,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()
model.cpu().save_pretrained(tmp_dir)
assert "adapter_config.json" in os.listdir(tmp_dir)
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None
@pytest.mark.multi_gpu_tests
def test_causal_lm_training_multi_gpu_4bit_road(self):
r"""
Same as test_causal_lm_training_multi_gpu_4bit but with RoAd
"""
with tempfile.TemporaryDirectory() as tmp_dir:
model = AutoModelForCausalLM.from_pretrained(
self.causal_lm_model_id,
device_map=DEVICE_MAP_MAP[self.causal_lm_model_id],
quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
assert set(model.hf_device_map.values()) == set(range(device_count))
assert {p.device.index for p in model.parameters()} == set(range(device_count))
model = prepare_model_for_kbit_training(model)
setattr(model, "model_parallel", True)
setattr(model, "is_parallelizable", True)
config = RoadConfig(
variant="road_1",
target_modules=["q_proj", "v_proj"],
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
data = load_dataset("Abirate/english_quotes")
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
warmup_steps=2,
max_steps=3,
learning_rate=1e-3,
fp16=True,
logging_steps=1,
output_dir=tmp_dir,
),
data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
)
model.config.use_cache = False
trainer.train()
model.cpu().save_pretrained(tmp_dir)
assert "adapter_config.json" in os.listdir(tmp_dir)
assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir)
# assert loss is not None
assert trainer.state.log_history[-1]["train_loss"] is not None
@pytest.mark.single_gpu_tests
def test_causal_lm_training_lora_resize_embeddings_trainable_tokens(self):
r"""

View File

@ -53,6 +53,7 @@ from peft import (
PeftWarning,
PrefixTuningConfig,
PromptTuningConfig,
RoadConfig,
VBLoRAConfig,
VeraConfig,
get_eva_state_dict,
@ -1768,6 +1769,83 @@ class TestC3AInitialization:
get_peft_model(model, config)
class TestRoadInitialization:
torch_device = infer_device()
def get_model(self):
class MLP(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.lin0 = nn.Linear(10, 30, bias=bias)
self.lin1 = nn.Linear(30, 2, bias=bias)
def forward(self, X):
X = self.lin0(X)
X = self.lin1(X)
return X
return MLP().to(self.torch_device)
def get_conv2d_model(self):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
# choose a large weight so that averages are close to expected values
self.linear = nn.Linear(1000, 1000)
self.embed = nn.Embedding(1000, 1000)
self.conv2d = nn.Conv2d(100, 100, 3)
def forward(self, x):
x_int = (100 * x).int()
x_4d = x.flatten().reshape(1, 100, 10, 10)
return self.linear(x), self.embed(x_int), self.conv2d(x_4d)
return MyModule().eval().to(self.torch_device)
def test_road_default_initialization(self):
torch.manual_seed(0)
model = self.get_model()
config = RoadConfig(target_modules=["lin0"], group_size=2)
model = get_peft_model(model, config)
weight_alpha = model.lin0.road_alpha["default"].data
weight_theta = model.lin0.road_theta["default"].data
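# with default initialization, RoAd should start as the identity transform: scalings (alpha) equal to one and
# rotation angles (theta) equal to zero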
assert torch.allclose(weight_alpha, torch.ones_like(weight_alpha))
assert torch.allclose(weight_theta, torch.zeros_like(weight_theta))
def test_road_with_odd_group_size(self):
group_size = 3 # odd values are not allowed
msg = f"The group_size must be divisible by 2 when using RoadLayer, but got {group_size}."
with pytest.raises(ValueError, match=re.escape(msg)):
RoadConfig(group_size=group_size)
def test_road_with_too_large_group_size(self):
group_size = 64 # larger than out_features
msg = (
f"The out_features of the base layer must be divisible by group_size ({group_size}) when using RoadLayer."
)
model = self.get_model()
config = RoadConfig(target_modules=["lin0"], group_size=group_size)
with pytest.raises(ValueError, match=re.escape(msg)):
get_peft_model(model, config)
def test_road_with_incompatible_group_size_with_out_features(self):
group_size = 4 # even, but 30 does not divide by 4
model = self.get_model()
config = RoadConfig(target_modules=["lin0"], group_size=group_size)
msg = (
f"The out_features of the base layer must be divisible by group_size ({group_size}) when using RoadLayer."
)
with pytest.raises(ValueError, match=re.escape(msg)):
get_peft_model(model, config)
def test_road_with_conv2d_layer(self):
model = self.get_conv2d_model()
config = RoadConfig(target_modules=["conv2d"], group_size=2)
msg = "Target module Conv2d(100, 100, kernel_size=(3, 3), stride=(1, 1)) is not supported. Currently, only the following modules are supported: `torch.nn.Linear`."
with pytest.raises(ValueError, match=re.escape(msg)):
get_peft_model(model, config)
class TestNoInfiniteRecursionDeepspeed:
# see #1892 for details
classes = [

View File

@ -30,6 +30,7 @@ from peft import (
PromptEncoderConfig,
PromptTuningConfig,
PromptTuningInit,
RoadConfig,
ShiraConfig,
VBLoRAConfig,
VeraConfig,
@ -156,6 +157,14 @@ ALL_CONFIGS = [
"num_virtual_tokens": 10,
},
),
(
RoadConfig,
{
"task_type": "SEQ_CLS",
"variant": "road_1",
"group_size": 2,
},
),
(
ShiraConfig,
{