Add Arrow + GenKnowSub to LoRA (#2644)

This PR adds support for Arrow, a modular routing mechanism for LoRA experts introduced here, as well as the refinement method GenKnowSub, proposed in our ACL 2025 Main Conference paper. GenKnowSub enhances Arrow by subtracting a general-domain LoRA from task-specific ones prior to routing, leading to improved generalisation and modularity.
Mohammadtaha Bagherifard
2025-09-08 15:51:37 +03:30
committed by GitHub
parent ed5c6eaa1a
commit 42db980676
15 changed files with 1859 additions and 19 deletions


@ -687,3 +687,148 @@ Using this feature has some drawbacks, namely:
- Increase the batch size.
- Try to avoid having a large number of different adapters in the same batch and prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only performing inference with a small handful of different adapters.
- Take a look at alternative implementations such as [LoRAX](https://github.com/predibase/lorax), [punica](https://github.com/punica-ai/punica), or [S-LoRA](https://github.com/S-LoRA/S-LoRA), which are specialized to work with a large number of different adapters.
## Composing and Reusing LoRA Adapters
### Arrow
[Arrow](https://huggingface.co/papers/2405.11157) is a modular routing algorithm designed to combine multiple pre-trained task-specific LoRA adapters to solve a given task. Rather than merging all adapters naively, Arrow introduces a **gradient-free, token-wise mixture-of-experts (MoE) routing mechanism**. At inference time, it first computes a _prototype_ for each LoRA by extracting the top right singular vector of its weight update via SVD. Each token representation is then compared to these prototypes via cosine similarity to obtain routing coefficients. Tokens are assigned to the top-k most relevant LoRA adapters, with the coefficients normalized through softmax, and their outputs linearly combined. This allows effective reuse of existing LoRA modules for new tasks and leads to stronger zero-shot generalization.
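Conceptually, the routing step looks like the following sketch (illustrative only and not part of the PEFT API; it mirrors the token-wise routing performed inside each LoRA linear layer, with per-expert scaling and dropout omitted for brevity):
```py
import torch

def arrow_route(tokens, prototypes, A_stack, B_stack, top_k=3, temperature=1.0):
    # tokens: (t, in_features), prototypes: (E, in_features)
    # A_stack: (E, r, in_features), B_stack: (E, out_features, r)
    sim = torch.abs(tokens @ prototypes.T)              # (t, E) sign-agnostic similarity to each prototype
    top_v, idx = torch.topk(sim, top_k, dim=1)          # keep the top-k experts per token
    scores = torch.full_like(sim, float("-inf"))
    scores.scatter_(1, idx, top_v)
    coeff = torch.softmax(scores / temperature, dim=1)  # (t, E), zero weight outside the top-k
    z = torch.einsum("tf,erf->ter", tokens, A_stack)    # project tokens into each expert's low-rank space
    y = torch.einsum("ter,eor->teo", z, B_stack)        # map back to the output space
    return torch.einsum("te,teo->to", coeff, y)         # weighted sum of expert outputs
```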
In PEFT, Arrow is enabled through ```ArrowConfig``` and ```create_arrow_model```. You can also configure parameters such as ```top_k``` (the number of LoRA adapters combined per token), ```router_temperature``` (the softmax temperature applied to the routing coefficients), and ```rng_seed``` (for reproducibility).
```py
from peft import create_arrow_model, ArrowConfig
from transformers import AutoModelForCausalLM
# Loading the model
base_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# Creating the Arrow config
arrow_config = ArrowConfig(
top_k=3,
router_temperature=1.0,
rng_seed=42,
)
# The LoRA adapters below were trained on a clustered FLAN dataset.
# Task clustering was performed using the Model-Based Clustering (MBC) method,
# as described in the Arrow paper.
# While one could train a separate LoRA for each task and let Arrow route tokens among them,
# training LoRAs on clusters of tasks instead provides an indirect optimization for
# transfer across the multi-task dataset.
task_specific_adapter_paths = [
f"TahaBa/phi3-mini-clustered-flan/ts_expert_{i}" for i in range(10)
]
# Creating the Arrow model
model = create_arrow_model(
base_model=base_model,
task_specific_adapter_paths=task_specific_adapter_paths,
arrow_config=arrow_config,
)
# The forward pass can now be called on this model, just like on a normal PeftModel.
```
Furthermore, you can add or remove adapters after calling ```create_arrow_model```—for example, to fine-tune a new adapter or discard an unnecessary one. Once the adapters are in place, you can activate the ```"arrow_router"``` for inference to use Arrow. Note that if you add a new LoRA adapter after ```create_arrow_model``` and want to fine-tune it, you must explicitly set the new adapter as active, since ```"arrow_router"``` is activated by default in ```create_arrow_model```.
```py
from trl import SFTTrainer, SFTConfig
# Adding a new adapter and activating it
model.add_adapter(adapter_name='new_adapter')
model.set_adapter('new_adapter')
# Now the model can be trained with `new_adapter` as the active adapter.
trainer = SFTTrainer(
model=model,
args=SFTConfig(...),
...
)
# Once the training is done, you can activate `arrow_router` and use it in inference
model.set_adapter('arrow_router') # Model is ready to be used at inference time now
```
### GenKnowSub
[GenKnowSub](https://aclanthology.org/2025.acl-short.54/) augments Arrow by purifying task-specific LoRA adapters before routing. The key idea is to subtract general knowledge encoded in LoRA space, based on the [forgetting-via-negation principle](https://huggingface.co/papers/2212.04089), so that task adapters become more isolated and focused on task-relevant signals. Concretely, GenKnowSub estimates a low-dimensional “general” subspace from a set of general (non-task-specific) LoRA adapters and removes this component from each task adapter's LoRA update prior to Arrow's token-wise routing. This typically improves compositionality and reduces interference when combining many task adapters.
In PEFT, enable GenKnowSub by setting ```use_gks=True``` in ArrowConfig, and providing ```general_adapter_paths``` in ```create_arrow_model```:
```py
from peft import create_arrow_model, ArrowConfig
from transformers import AutoModelForCausalLM
# Loading the model
base_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# Creating the Arrow config
arrow_config = ArrowConfig(
top_k=3,
router_temperature=1.0,
use_gks=True,
rng_seed=42,
)
# Paths to task-specific adapters, trained on the clustered FLAN dataset (as explained above).
task_specific_adapter_paths = [
f"TahaBa/phi3-mini-clustered-flan/ts_expert_{i}" for i in range(10)
]
# These general adapters were trained on English, German, and French Wikipedia data
# with a causal language modelling objective, using pairs of (507-token context, 5-token completion),
# with the loss computed on the completion.
general_adapter_paths = [
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langen/checkpoint-17",
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langfr/checkpoint-35",
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langger/checkpoint-17"
]
# Creating the Arrow model
model = create_arrow_model(
base_model=base_model,
task_specific_adapter_paths=task_specific_adapter_paths,
general_adapter_paths=general_adapter_paths,
arrow_config=arrow_config,
)
# The forward pass can now be called on this model, just like on a normal PeftModel.
```
To encode general knowledge, GenKnowSub subtracts the average of the provided general adapters from each task-specific adapter once, before routing begins. Furthermore, the ability to add or remove adapters after calling ```create_arrow_model``` (as described in the Arrow section) is still supported in this case.
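The purification itself is a simple operation in LoRA weight space. The sketch below is illustrative only (the helper name and arguments are hypothetical); in PEFT this logic runs lazily inside the Arrow layers on the first forward pass:
```py
import torch

def subtract_general_knowledge(task_As, task_Bs, general_As, general_Bs):
    # Average the general adapters in LoRA space
    avg_A = torch.stack(general_As).mean(dim=0)  # (r, in_features)
    avg_B = torch.stack(general_Bs).mean(dim=0)  # (out_features, r)
    # Subtract the general direction from every task-specific adapter
    # ("forgetting via negation")
    for A, B in zip(task_As, task_Bs):
        A.data.sub_(avg_A)
        B.data.sub_(avg_B)
```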
<Tip>
**Things to keep in mind when using Arrow + GenKnowSub:**
- All LoRA adapters (task-specific and general) must share the same ```rank``` and ```target_modules```.
- Any inconsistency in these settings will raise an error in ```create_arrow_model```.
- Having different scaling factors (```lora_alpha```) across task adapters is supported — Arrow handles them automatically.
- Merging the ```"arrow_router"``` is not supported, due to its dynamic routing behavior.
- In ```create_arrow_model```, task adapters are loaded as ```task_i``` and general adapters as ```gks_j``` (where ```i``` and ```j``` are indices). The function ensures consistency of ```target_modules```, ```rank```, and whether adapters are applied to ```Linear``` or ```Linear4bit``` layers. It then adds the ```"arrow_router"``` module and activates it. Any customization of this process requires overriding ```create_arrow_model```.
- This implementation is compatible with 4-bit quantization (via bitsandbytes):
```py
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
# Quantisation config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=False,
)
# Loading the model
base_model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct",
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config,
)
# Now call create_arrow_model() as explained above.
```
</Tip>


@ -32,6 +32,10 @@ The abstract from the paper is:
## Utility
### ArrowConfig
[[autodoc]] tuners.lora.config.ArrowConfig
### LoftQ
[[autodoc]] utils.loftq_utils.replace_lora_weights_loftq


@ -0,0 +1,375 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script provides a simple evaluation pipeline for multiple-choice reasoning datasets
(e.g., BoolQ, HellaSwag, ARC, OpenBookQA, Winogrande) with different composition strategies.
Usage examples:
python arrow_phi3_mini.py --strategy base --ds_name arc-challenge
python arrow_phi3_mini.py --strategy arrow --ds_name boolq
python arrow_phi3_mini.py --strategy gks --ds_name hswag
Key features:
- Supports three strategies:
"base" → Evaluate the quantized base model directly
"arrow" → Use Arrow modular routing with task-specific adapters
"gks" → Use Arrow + GenKnowSub (subtracting general-domain knowledge)
- Loads evaluation datasets from the Hugging Face Hub
- Implements a batched evaluation loop that computes per-option likelihoods and selects
the answer with the lowest average loss
- Reports simple accuracy
Implementation details:
- The base model is quantized to 4-bit using `BitsAndBytesConfig` (nf4, bf16 compute).
- For Arrow and GKS, task-specific adapters are loaded from the Hugging Face Hub:
TahaBa/phi3-mini-clustered-flan/ts_expert_i
- Task-specific adapters were trained on 10 clusters of FLAN tasks.
- The clusters were created using Model-Based Clustering (MBC):
1. Train a LoRA adapter for each individual task.
2. Apply k-means clustering to group tasks based on these adapters.
3. Train a LoRA adapter for each resulting cluster.
For more details, see the Arrow paper: https://huggingface.co/papers/2405.11157
- For GKS, general adapters are loaded from:
TahaBa/phi3-mini-general-adapters/...
- These adapters were trained on English, French, and German Wikipedia data
using a causal language modeling objective with (507-token context → 5-token completion) pairs.
- This setup encodes general knowledge into the LoRA space, which can then be
subtracted from task-specific adapters during inference to isolate and purify them.
For more details, see the GenKnowSub paper: https://huggingface.co/papers/2505.10939
- `evaluate_on_multi_choice_batched` handles tokenization, masking context tokens,
and computing per-choice log-likelihoods for fair comparison.
- Accuracy is printed at the end for the selected dataset.
This script is mainly meant for demonstration purposes and lightweight evaluation,
not full-scale benchmarking (batch size / max length can be tuned).
=======================================================================================
Results (evaluated with microsoft/Phi-3-mini-4k-instruct, 4-bit quantization):
| Dataset | Base Acc. | Arrow Acc. | Arrow+GKS Acc. |
|--------------|-----------|------------|----------------|
| ARC-Challenge| 0.4515 | 0.5418 | 0.5585 |
| ARC-Easy | 0.6894 | 0.8404 | 0.8473 |
| Winogrande | 0.5769 | 0.6550 | 0.6724 |
| BoolQ | 0.8146 | 0.8030 | 0.8247 |
| OpenBookQA | 0.43 | 0.448 | 0.472 |
| HellaSwag | 0.7318 | 0.7150 | 0.7376 |
Observations:
- Arrow generally improves over the base model by routing tokens to the most relevant task adapters.
- Applying GKS (general knowledge subtraction) consistently gives further gains compared to Arrow and Base.
These numbers are not meant as leaderboard results, but as a sanity check
to verify that the implementation works as expected and demonstrates
the benefits of Arrow and GenKnowSub.
"""
import argparse
import random
import numpy as np
import torch
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import ArrowConfig, create_arrow_model
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
MODEL_MAX_LEN = 2048
def parse_args():
parser = argparse.ArgumentParser(description="Training script with strategy selection")
parser.add_argument(
"--strategy",
type=str,
choices=["base", "arrow", "gks"],
default="base",
help="Training strategy to use: base, arrow, or gks",
)
parser.add_argument(
"--ds_name",
type=str,
choices=["boolq", "hswag", "arc-easy", "arc-challenge", "oqa", "wg"],
default="arc-challenge",
help="Dataset to use: boolq, hswag, arc-easy, arc-challenge, oqa, wg",
)
return parser.parse_args()
def read_test_dataset(ds_name):
if ds_name == "boolq":
ds = load_dataset("google/boolq", split="validation", trust_remote_code=True)
elif ds_name == "hswag":
ds = load_dataset("Rowan/hellaswag", split="validation", trust_remote_code=True)
elif ds_name == "arc-challenge":
ds = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="validation", trust_remote_code=True)
elif ds_name == "arc-easy":
ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="validation", trust_remote_code=True)
elif ds_name == "oqa":
ds = load_dataset("allenai/openbookqa", split="validation", trust_remote_code=True)
elif ds_name == "wg":
ds = load_dataset("allenai/winogrande", "winogrande_xl", split="validation", trust_remote_code=True)
else:
raise ValueError(f"Dataset {ds_name} is not supported yet.")
return ds
def extract_input_content(ds_name, row):
if ds_name == "boolq":
return f"[passage]{row['passage']}[question]{row['question']}"
if ds_name == "hswag":
return row["ctx"]
if (ds_name == "arc-challenge") or (ds_name == "arc-easy"):
return row["question"]
if ds_name == "oqa":
return row["question_stem"]
if ds_name == "wg":
return row["sentence"]
def create_multi_choice_options(row, ds_name):
options_texts = []
content = extract_input_content(ds_name, row)
if ds_name == "boolq":
choices = ["true", "false"]
if ds_name == "hswag":
choices = row["endings"]
if (ds_name == "arc-challenge") or (ds_name == "arc-easy"):
choices = row["choices"]["text"]
if ds_name == "wg":
choices = [row["option1"], row["option2"]]
if ds_name == "oqa":
choices = row["choices"]["text"]
for choice in choices:
options_texts.append(f"<|user|>\n{content}<|end|>\n<|assistant|>{choice}<|end|>\n")
return options_texts
def extract_multi_choice_target_index(row, ds_name):
if ds_name == "boolq":
return 0 if row["answer"] is True else 1
if ds_name == "hswag":
return int(row["label"])
if (ds_name == "arc-challenge") or (ds_name == "arc-easy"):
return row["choices"]["label"].index(row["answerKey"])
if ds_name == "wg":
return int(row["answer"]) - 1
if ds_name == "oqa":
return row["choices"]["label"].index(row["answerKey"])
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def compute_loglike_loss(logits, labels, reduction="none"):
bs = logits.size(0)
vocab_size = logits.size(-1)
labels = labels.squeeze(-1)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss(reduction=reduction)
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
# reshape back
if reduction == "none":
loss = loss.view((bs, -1))
non_zero_loss = (loss != 0).sum(dim=-1)
non_zero_loss[non_zero_loss == 0] = 1
loss = loss.sum(dim=-1) / non_zero_loss
return loss.float() # Convert to float32 before returning
def evaluate_on_multi_choice_batched(
eval_dataset, model, tokenizer, ds_name, labels, predictions, args, batch_size=32, max_length=512, device="cuda"
):
# Put the model in eval mode before scoring
model.eval()
for start in tqdm(
range(0, len(eval_dataset), batch_size), total=(len(eval_dataset) + batch_size - 1) // batch_size
):
rows = [eval_dataset[i] for i in range(start, min(start + batch_size, len(eval_dataset)))]
# Build the flattened option texts for this batch
all_texts = []
options_per_sample = [] # number of options for each sample
ctx_lens_per_option = [] # context length replicated per option
for row in rows:
# options: ["<|user|>...<|assistant|>choiceA<|end|>", ...]
options = create_multi_choice_options(row, ds_name)
options_per_sample.append(len(options))
# compute context length once per sample (note the -1 offset used when masking)
content = extract_input_content(ds_name, row)
context_prompt = f"<|user|>\n{content}<|end|>\n<|assistant|>"
ctx_len = len(tokenizer.encode(context_prompt)) - 1
all_texts.extend(options)
ctx_lens_per_option.extend([ctx_len] * len(options))
# collect gold label
labels.append(extract_multi_choice_target_index(row, ds_name))
# Tokenize all options in one go
tokenized = tokenizer(
all_texts,
return_tensors="pt",
padding=True,
truncation=True,
max_length=max_length,
)
tokenized = {k: v.to(device) for k, v in tokenized.items()}
# Create masked labels: ignore context and padding
masked_labels = tokenized["input_ids"].clone()
for i, ctx_len in enumerate(ctx_lens_per_option):
masked_labels[i, :ctx_len] = -100
masked_labels[tokenized["attention_mask"] == 0] = -100
with torch.no_grad():
logits = model(input_ids=tokenized["input_ids"], attention_mask=tokenized["attention_mask"]).logits
# per-sequence losses
losses = compute_loglike_loss(logits, masked_labels, reduction="none").detach().cpu()
# Reduce per sample (argmin across its options)
idx = 0
for n_opt in options_per_sample:
pred = torch.argmin(losses[idx : idx + n_opt]).item()
predictions.append(pred)
idx += n_opt
print(
f"Accuracy for dataset {args.ds_name} and strategy {args.strategy} is: {accuracy_score(labels, predictions)}"
)
if __name__ == "__main__":
args = parse_args()
print(f"Selected strategy: {args.strategy}")
print(f"Dataset name: {args.ds_name}")
# Loading the tokeniser
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
use_fast=True,
padding_side="right",
model_max_length=MODEL_MAX_LEN,
)
# Quantisation config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=False,
)
# Loading the model
base_model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config,
)
# Loading the test dataset
test_dataset = read_test_dataset(args.ds_name)
print(f"{args.ds_name} is loaded with size: {len(test_dataset)}.")
labels, predictions = [], []
if args.strategy == "base":
# Batch-wise inference
with torch.no_grad():
evaluate_on_multi_choice_batched(
test_dataset,
base_model,
tokenizer,
args.ds_name,
labels,
predictions,
args,
batch_size=64, # tune this
max_length=512, # tune if options are long
device="cuda",
)
else:
general_adapter_paths = []
if args.strategy == "gks":
arrow_config = ArrowConfig(
top_k=3,
router_temperature=1.0,
use_gks=True,
)
# General adapter paths from the hub
general_adapter_paths = [
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langen/checkpoint-17",
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langfr/checkpoint-35",
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langger/checkpoint-17",
]
else:
arrow_config = ArrowConfig(
top_k=3,
router_temperature=1.0,
)
# Task-specific adapter paths from the hub
task_specific_adapter_paths = [f"TahaBa/phi3-mini-clustered-flan/ts_expert_{i}" for i in range(10)]
# Creating the Arrow model
model = create_arrow_model(
base_model=base_model,
task_specific_adapter_paths=task_specific_adapter_paths,
general_adapter_paths=general_adapter_paths,
arrow_config=arrow_config,
)
# Batch-wise inference
with torch.no_grad():
evaluate_on_multi_choice_batched(
test_dataset,
model,
tokenizer,
args.ds_name,
labels,
predictions,
args,
batch_size=32, # tune this
max_length=512, # tune if options are long
device="cuda",
)


@ -0,0 +1,8 @@
torch
transformers
accelerate
datasets
scikit-learn
tqdm
numpy
bitsandbytes


@ -50,6 +50,7 @@ from .tuners import (
AdaLoraModel,
AdaptionPromptConfig,
AdaptionPromptModel,
ArrowConfig,
BOFTConfig,
BOFTModel,
BoneConfig,
@ -105,6 +106,7 @@ from .tuners import (
VeraModel,
XLoraConfig,
XLoraModel,
create_arrow_model,
get_eva_state_dict,
initialize_lora_eva_weights,
)
@ -134,6 +136,7 @@ __all__ = [
"AdaLoraModel",
"AdaptionPromptConfig",
"AdaptionPromptModel",
"ArrowConfig",
"AutoPeftModel",
"AutoPeftModelForCausalLM",
"AutoPeftModelForFeatureExtraction",
@ -212,6 +215,7 @@ __all__ = [
"XLoraModel",
"bloom_model_postprocess_past_key_value",
"cast_mixed_precision_params",
"create_arrow_model",
"get_eva_state_dict",
"get_layer_status",
"get_model_status",


@ -25,11 +25,13 @@ from .ln_tuning import LNTuningConfig, LNTuningModel
from .loha import LoHaConfig, LoHaModel
from .lokr import LoKrConfig, LoKrModel
from .lora import (
ArrowConfig,
EvaConfig,
LoftQConfig,
LoraConfig,
LoraModel,
LoraRuntimeConfig,
create_arrow_model,
get_eva_state_dict,
initialize_lora_eva_weights,
)
@ -55,6 +57,7 @@ __all__ = [
"AdaLoraModel",
"AdaptionPromptConfig",
"AdaptionPromptModel",
"ArrowConfig",
"BOFTConfig",
"BOFTModel",
"BoneConfig",
@ -112,6 +115,7 @@ __all__ = [
"VeraModel",
"XLoraConfig",
"XLoraModel",
"create_arrow_model",
"get_eva_state_dict",
"initialize_lora_eva_weights",
]


@ -15,7 +15,8 @@
from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available
from peft.utils import register_peft_method
from .config import EvaConfig, LoftQConfig, LoraConfig, LoraRuntimeConfig
from .arrow import create_arrow_model
from .config import ArrowConfig, EvaConfig, LoftQConfig, LoraConfig, LoraRuntimeConfig
from .eva import get_eva_state_dict, initialize_lora_eva_weights
from .gptq import GPTQLoraLinear
from .layer import Conv2d, Conv3d, Embedding, Linear, LoraLayer, ParamWrapper
@ -23,6 +24,7 @@ from .model import LoraModel
__all__ = [
"ArrowConfig",
"Conv2d",
"Conv3d",
"Embedding",
@ -35,6 +37,7 @@ __all__ = [
"LoraModel",
"LoraRuntimeConfig",
"ParamWrapper",
"create_arrow_model",
"get_eva_state_dict",
"initialize_lora_eva_weights",
]


@ -0,0 +1,476 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from typing import Any
import torch
from torch import nn
from transformers import PreTrainedModel
from .config import ArrowConfig
TASK_ADAPTER_PREFIX = "task_"
GKS_ADAPTER_PREFIX = "gks_"
class ArrowLoraLinearLayer(nn.Module):
"""
This class implements the core logic of the Arrow routing algorithm for linear layers.
"""
def __init__(self, in_features, arrow_config):
super().__init__()
# extra parameters needed for arrow
self.in_features = in_features
self._protos_ready = False
self.top_k = arrow_config.top_k
self.temperature = arrow_config.router_temperature
self.rng_seed = arrow_config.rng_seed
self.task_adapter_names = (
arrow_config.task_adapter_names.copy()
) # Set in create_arrow_model() with this format: task_0, task_1, ...
self.gks_adapter_names = (
arrow_config.gks_adapter_names
) # Set in create_arrow_model() with this format: gks_0, gks_1, ...
self.use_gks = arrow_config.use_gks
self.gks_done = False
self.gks_added_adapter_names = []
self.in_features = in_features
self.cast_input_dtype_enabled = True
@torch.no_grad()
def on_adapter_change(self, lora_A, lora_B):
"""
Called when adapters are added/removed/renamed so Arrow can refresh its internal state before the next forward
pass.
"""
all_ts_adapter_names = [
k
for k in lora_A.keys()
if k in lora_B and k != "arrow_router" and not (k.startswith("gks_") and k[len("gks_") :].isdigit())
]
if sorted(self.task_adapter_names) == sorted(all_ts_adapter_names): # No changes in the ts_adapters
return
# Getting the name(s) of added adapter(s)
if len(self.task_adapter_names) < len(all_ts_adapter_names): # Adapter(s) are added.
self.gks_added_adapter_names = [x for x in all_ts_adapter_names if x not in self.task_adapter_names]
# Updating the task_adapter_names
self.task_adapter_names = all_ts_adapter_names.copy()
# Invalidate caches so they'll be rebuilt lazily on next forward()
self._protos_ready = False
# GKS will be handled by self.gks_added_adapter_names
def top_right_singular_vec_from_BA(self, A, B, iters=15, eps=1e-8):
"""
Computes the top *right* singular vector of ΔW = B @ A without forming ΔW.
Theory:
For any matrix M, the right singular vectors are the eigenvectors of Mᵀ M. If ΔW = B @ A (with A ∈
ℝ^{r×in}, B ∈ ℝ^{out×r}), then
    ΔWᵀ ΔW = (B @ A)ᵀ (B @ A) = Aᵀ (Bᵀ B) A ∈ ℝ^{in×in}.
Therefore, the dominant right singular vector of ΔW is the dominant eigenvector of M := Aᵀ (Bᵀ B) A. We
find it by *power iteration* on the linear operator
v ↦ Aᵀ (Bᵀ B) (A v),
which avoids materializing ΔW (out×in) or M (in×in). The result lives in the input/token space (size =
in_features), which is exactly what Arrow needs. (Right singular vectors ≡ eigenvectors of MᵀM; power
iteration converges to the dominant eigenvector under mild conditions.)
Practical notes:
- All iterations are performed in float32 for numerical stability; the result is then cast back
  to the LoRA dtype/device before storing/using the prototype.
- The iteration runs for a fixed number of steps (`iters`); `eps` guards against division by zero
  when normalizing.
- The returned vector is unique up to sign (±), as with any singular vector.
  Downstream code should be sign-invariant.
"""
# A: (r, in), B: (out, r)
A32 = A.to(torch.float32)
B32 = B.to(torch.float32)
C = B32.T @ B32 # (r, r)
# Private RNG on A's device
gen = None
if self.rng_seed is not None:
gen = torch.Generator(device=A32.device.type)
gen.manual_seed(int(self.rng_seed))
# init vector in input space
v = torch.randn(A32.size(1), dtype=A32.dtype, device=A32.device, generator=gen)
v = v / (v.norm() + eps)
for _ in range(iters):
# w = (ΔWᵀΔW) v = Aᵀ (BᵀB) (A v)
w = A32.T @ (C @ (A32 @ v))
v = w / (w.norm() + eps)
return v # fp32
@torch.no_grad()
def build_prototypes(self, lora_A, lora_B):
"""
Computes a prototype vector for each LoRA adapter by extracting the top right singular vector of its update
ΔW = B @ A (via power iteration, without materializing ΔW).
These prototypes are later used to calculate the cosine similarity between each input token and each expert.
The resulting similarity scores serve as coefficients to compute a weighted average of the corresponding LoRA
modules, effectively routing each token through its most relevant experts.
**This prototype computation is done once for all experts and is re-done for newly added adapters.**
Args:
    lora_A: Dictionary of LoRA A matrices, keyed by adapter name.
    lora_B: Dictionary of LoRA B matrices, keyed by adapter name.
"""
if self._protos_ready:
return
protos = []
for name in self.task_adapter_names:
A = lora_A[name].weight # (r, in_features)
B = lora_B[name].weight # (out_features, r)
# Efficiently computing right singular vector of A @ B
proto32 = self.top_right_singular_vec_from_BA(A, B)
proto = proto32.to(dtype=A.dtype, device=A.device)
protos.append(proto)
proto_stack = torch.stack(protos, dim=0) # (E, in_features)
# Register the prototypes buffer with correct dtype/device consistent with A and B weights
self.register_buffer("prototypes", proto_stack, persistent=False)
self._protos_ready = True
@torch.no_grad()
def gen_know_sub(self, lora_A, lora_B):
"""
This function performs General Knowledge Subtraction (GKS). It averages the provided general adapters and
subtracts that average from each task adapter. This subtraction purifies the task adapters, based on the
"forgetting-via-negation" principle, a task-arithmetic operation explained in https://arxiv.org/abs/2212.04089.
The task adapters become more focused and isolated, enhancing performance on new tasks.
Args:
    lora_A: Dictionary of LoRA A matrices, keyed by adapter name.
    lora_B: Dictionary of LoRA B matrices, keyed by adapter name.
"""
if not self.use_gks:
return
elif self.gks_done and not self.gks_added_adapter_names:
return
else:
# 1) compute average A/B over gks_adapter_names
avg_A = torch.stack([lora_A[n].weight for n in self.gks_adapter_names], dim=0).mean(
0
) # shape (r, in_features)
avg_B = torch.stack([lora_B[n].weight for n in self.gks_adapter_names], dim=0).mean(
0
) # shape (out_features, r)
# 2) Subtract the average from task-specific experts
if self.gks_done is False:  # GKS hasn't been done yet, so apply it to all task experts.
for name in self.task_adapter_names:
lora_A[name].weight.data.sub_(avg_A)
lora_B[name].weight.data.sub_(avg_B)
else:  # GKS was already applied, so only apply it to the newly added experts.
for name in self.gks_added_adapter_names:
lora_A[name].weight.data.sub_(avg_A)
lora_B[name].weight.data.sub_(avg_B)
# 3) Set gks_done flag as true, so we won't do it again in ArrowLinearVariant.forward().
self.gks_done = True
# Clearing the self.gks_added_adapter_names
self.gks_added_adapter_names = []
def _cast_input_dtype(self, x, dtype: torch.dtype):
"""
Whether to cast the dtype of the input of the forward method.
Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting
layer.cast_input_dtype=False, this can be disabled if necessary.
Enabling or disabling can be managed via the peft.helpers.disable_lora_input_dtype_casting context manager.
"""
if x is None: # useful e.g. if x is the bias, which can be None
return None
cast_input_dtype_enabled = getattr(self, "cast_input_dtype_enabled", True)
if (not cast_input_dtype_enabled) or (x.dtype == dtype):
return x
return x.to(dtype=dtype)
def forward(self, x, lora_A, lora_B, dropout, scaling):
"""
Applies Arrow routing inside a LoRA layer.
Steps:
1. Compute cosine similarity between each token representation and all adapter prototypes.
2. Select the top-k experts per token and normalize their scores with a softmax.
3. Project tokens into each selected expert's low-rank space (A weights).
4. Map back to the output space (B weights).
5. Aggregate expert outputs via the weighted sum of their contributions.
6. Apply dropout, scaling, and return the reshaped delta.
- Conceptually, this is a Mixture-of-Experts (MoE) over LoRA adapters,
where coefficients are derived from prototype similarity.
Returns:
delta: LoRA output adjustment computed by Arrow routing.
"""
x = self._cast_input_dtype(x, lora_A[self.task_adapter_names[0]].weight.dtype)
B, *rest, F_in = x.shape
tok = x.view(-1, F_in) # (t, F_in)
t, E = tok.size(0), self.prototypes.size(0)
# Convert `scaling` (a dict keyed by adapter name) into a tensor for per-expert scaling below
scales_tens = torch.tensor(
[scaling[n] for n in self.task_adapter_names],
device=tok.device,
dtype=tok.dtype,
) # shape (E,)
# 1) similarity — sign-agnostic
sim = torch.abs(tok @ self.prototypes.T) # (t, E)
# 2) top-k + softmax over full E (non-top-k = -inf)
top_v, idx = torch.topk(sim, self.top_k, dim=1)
full_score = tok.new_full((t, E), float("-inf"))
full_score.scatter_(1, idx, top_v)
coeff = torch.softmax(full_score / self.temperature, dim=1) # (t, E)
# 3) stack all A and B weights once
# A_stack: (E, r, in_features), B_stack: (E, out_features, r)
A_stack = torch.stack([lora_A[n].weight for n in self.task_adapter_names], dim=0)
B_stack = torch.stack([lora_B[n].weight for n in self.task_adapter_names], dim=0)
# 4) project tokens into each expert's low-rank space:
# z[e] = tok @ A_e.T → shape (t, E, r)
z = torch.einsum("tf, erf -> ter", tok, A_stack)
# 5) lift back each expert's output:
# y[e] = z[e] @ B_e.T → shape (t, E, out_features)
y = torch.einsum("ter, eor -> teo", z, B_stack)
# 6) apply per-expert scaling before the weighted sum
# y_scaled[t, e, o] = scales[e] * y[t, e, o]
y = y * scales_tens.view(1, -1, 1)
# 7) weighted sum over experts:
# delta_flat[t,o] = Σ_e coeff[t,e] * y[t,e,o]
delta_flat = torch.einsum("te, teo -> to", coeff, y) # (t, out_features)
# 8) apply dropout and reshape
delta = dropout(delta_flat)
out_dim = delta_flat.size(-1)
return delta.view(B, *rest, out_dim)
def check_loaded_lora_compatibility_arrow(model, adapter_names: list[str]):
"""
After loading all adapters into `model`, check they share:
- the same LoRA rank (r)
- identical weight shapes
- identical sets of target_modules
Returns (sorted list of target module names, agreed rank r).
"""
reference = None # {'r':…, 'shapes':(Ashape,Bshape), 'modules':set([...])}
for name in adapter_names:
curr_modules = set()
curr_r = None
curr_shapes = None
for full_name, module in model.named_modules():
if hasattr(module, "lora_A") and name in module.lora_A:
A = module.lora_A[name].weight
B = module.lora_B[name].weight
mod_name = full_name.split(".")[-1]
curr_modules.add(mod_name)
# A has shape (r, in_features); B has shape (out_features, r)
curr_r = A.shape[0]
curr_shapes = (A.shape, B.shape)
if reference is None:
reference = {"r": curr_r, "shapes": curr_shapes, "modules": curr_modules}
else:
if curr_r != reference["r"]:
raise ValueError(f"[{name}] rank mismatch: {curr_r} != {reference['r']}")
if curr_shapes != reference["shapes"]:
raise ValueError(f"[{name}] shape mismatch: {curr_shapes} != {reference['shapes']}")
if curr_modules != reference["modules"]:
raise ValueError(
f"[{name}] target_modules mismatch:\n"
f" this adapter -> {sorted(curr_modules)}\n"
f" reference -> {sorted(reference['modules'])}"
)
agreed_modules = sorted(reference["modules"])
return agreed_modules, int(reference["r"])
def ensure_adapters_target_linear_layers_only(model, adapter_names: list[str]):
"""
Validate that every module holding LoRA weights for any of `adapter_names` is Linear-like: nn.Linear,
bitsandbytes.nn.Linear4bit, nn.Conv1d, or transformers.models.gpt2.modeling_gpt2.Conv1D. If not, raise.
"""
import torch.nn as nn
Linear4bit = None
try:
import bitsandbytes as bnb # type: ignore
Linear4bit = bnb.nn.Linear4bit
except ImportError:
pass
HFConv1D = None
try:
from transformers.models.gpt2.modeling_gpt2 import Conv1D as HFConv1D # type: ignore
except ImportError:
pass
allowed_types = (nn.Linear, nn.Conv1d)
if Linear4bit is not None:
allowed_types = allowed_types + (Linear4bit,)
if HFConv1D is not None:
allowed_types = allowed_types + (HFConv1D,)
offenders = []
for full_name, module in model.named_modules():
if hasattr(module, "lora_A"):
for name in adapter_names:
if name in getattr(module, "lora_A", {}):
base = getattr(module, "base_layer", None) or getattr(module, "original_module", None)
layer_to_check = base if base is not None else module
if not isinstance(layer_to_check, allowed_types):
offenders.append((name, full_name, type(layer_to_check).__name__))
if offenders:
lines = [
"LoRA adapters must only target Linear-like layers "
"(nn.Linear, nn.Conv1d, HF Conv1D, or bitsandbytes.nn.Linear4bit). Found:"
]
for name, full_name, tname in offenders:
lines.append(f" - adapter '{name}' on module '{full_name}' of type {tname}")
raise TypeError("\n".join(lines))
def _resolve_adapter_source(path: str) -> tuple[str, str | None]:
"""
Resolve a user-provided adapter `path` into (model_id, subfolder).
Supports:
- Local path to a folder that contains `adapter_config.json`
- Hub path with subfolder, e.g. "user/repo/ts_expert_0[/more/...]", which becomes:
model_id="user/repo", subfolder="ts_expert_0[/more/...]"
- Plain Hub repo id "user/repo" (no subfolder)
"""
if os.path.isdir(path):
if not os.path.isfile(os.path.join(path, "adapter_config.json")):
raise ValueError(f"Local adapter path '{path}' does not contain 'adapter_config.json'.")
return path, None
parts = path.strip("/").split("/")
if len(parts) >= 2:
model_id = "/".join(parts[:2])
if len(parts) > 2:
subfolder = "/".join(parts[2:])
return model_id, subfolder
return model_id, None
return path, None
def create_arrow_model(
base_model: PreTrainedModel,
task_specific_adapter_paths: list[str],
arrow_config: ArrowConfig,
general_adapter_paths: list[str] | None = None,
**adapter_kwargs: Any,
):
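    """
    Load task-specific (and optionally general) LoRA adapters onto `base_model` and attach an "arrow_router"
    adapter that performs Arrow routing (and GenKnowSub, when `arrow_config.use_gks` is True).

    Args:
        base_model (`PreTrainedModel`): The base transformers model to load the adapters on.
        task_specific_adapter_paths (`list[str]`): Local paths or hub ids (optionally with a subfolder) of the
            task-specific LoRA adapters.
        arrow_config (`ArrowConfig`): Configuration for Arrow routing (and GenKnowSub, if enabled).
        general_adapter_paths (`list[str]`, *optional*): Paths of general-domain LoRA adapters, required when
            `arrow_config.use_gks` is True.
        **adapter_kwargs: Additional keyword arguments passed to `PeftModel.from_pretrained` and `load_adapter`.

    Returns:
        A `PeftModel` with all adapters loaded and the "arrow_router" adapter set as active.
    """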
if task_specific_adapter_paths is None or len(task_specific_adapter_paths) == 0:
raise ValueError("`task_specific_adapter_paths` should contain at least one adapter path")
from peft import LoraConfig, PeftModel
model_id0, sub0 = _resolve_adapter_source(task_specific_adapter_paths[0])
initial_ts_expert_name = f"{TASK_ADAPTER_PREFIX}0"
first_kwargs = dict(adapter_kwargs)
if sub0 is not None and "subfolder" not in first_kwargs:
first_kwargs["subfolder"] = sub0
model = PeftModel.from_pretrained(
base_model,
model_id=model_id0,
adapter_name=initial_ts_expert_name,
**first_kwargs,
)
for i in range(1, len(task_specific_adapter_paths)):
ts_expert_name = f"{TASK_ADAPTER_PREFIX}{i}"
mid, sub = _resolve_adapter_source(task_specific_adapter_paths[i])
more_kwargs = dict(adapter_kwargs)
if sub is not None and "subfolder" not in more_kwargs:
more_kwargs["subfolder"] = sub
model.load_adapter(
model_id=mid,
adapter_name=ts_expert_name,
**more_kwargs,
)
arrow_config.task_adapter_names = [f"{TASK_ADAPTER_PREFIX}{i}" for i in range(len(task_specific_adapter_paths))]
if arrow_config.use_gks:
if general_adapter_paths is None or len(general_adapter_paths) == 0:
raise ValueError("You should provide general LoRA paths if you want to use GenKnowSub.")
for i in range(len(general_adapter_paths)):
gen_expert_name = f"{GKS_ADAPTER_PREFIX}{i}"
mid, sub = _resolve_adapter_source(general_adapter_paths[i])
gks_kwargs = dict(adapter_kwargs)
if sub is not None and "subfolder" not in gks_kwargs:
gks_kwargs["subfolder"] = sub
model.load_adapter(
model_id=mid,
adapter_name=gen_expert_name,
**gks_kwargs,
)
arrow_config.gks_adapter_names = [f"{GKS_ADAPTER_PREFIX}{i}" for i in range(len(general_adapter_paths))]
else:
arrow_config.gks_adapter_names = []
target_modules, r = check_loaded_lora_compatibility_arrow(
model, adapter_names=arrow_config.task_adapter_names + arrow_config.gks_adapter_names
)
ensure_adapters_target_linear_layers_only(
model, adapter_names=arrow_config.task_adapter_names + arrow_config.gks_adapter_names
)
router_cfg = LoraConfig(
arrow_config=arrow_config,
target_modules=target_modules,
r=r,
)
model.add_adapter(adapter_name="arrow_router", peft_config=router_cfg)
model.set_adapter("arrow_router")
return model


@ -24,6 +24,7 @@ from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
from peft.utils.integrations import dequantize_bnb_weight
from peft.utils.other import transpose
from .config import ArrowConfig
from .layer import LoraLayer, LoraVariant
@ -44,6 +45,7 @@ if is_bnb_available():
use_rslora: bool = False,
use_alora: bool = False,
use_dora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
) -> None:
@ -62,9 +64,17 @@ if is_bnb_available():
use_dora=use_dora,
use_alora=use_alora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
def resolve_lora_variant(self, *, use_dora: bool, use_alora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
) -> Optional[LoraVariant]:
if arrow_config is not None:
from .variants import ArrowLinearVariant
return ArrowLinearVariant()
if not use_dora and not use_alora:
return None
@ -323,6 +333,7 @@ if is_bnb_4bit_available():
init_lora_weights: bool = True,
use_rslora: bool = False,
use_dora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
) -> None:
@ -340,9 +351,17 @@ if is_bnb_4bit_available():
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
def resolve_lora_variant(self, *, use_dora: bool, use_alora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
) -> Optional[LoraVariant]:
if arrow_config is not None:
from .variants import ArrowLinearVariant
return ArrowLinearVariant()
if not use_dora and not use_alora:
return None


@ -69,6 +69,56 @@ class LoftQConfig:
loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
@dataclass
class ArrowConfig:
"""
This is the sub-configuration class to store the configuration for the Arrow and GenKnowSub algorithms. Arrow is a
routing algorithm that combines trained LoRA modules to solve new tasks, proposed in
'https://arxiv.org/pdf/2405.11157'. GenKnowSub is a refinement of the trained modules before they are combined via
Arrow, introduced in 'https://aclanthology.org/2025.acl-short.54/'.
"""
top_k: int = field(
default=3,
metadata={"help": "Number of top LoRA modules to combine in Arrow routing."},
)
router_temperature: float = field(
default=1.0,
metadata={"help": "Softmax temperature for computing Arrow expert coefficients."},
)
use_gks: bool = field(
default=False,
metadata={"help": "Enable GenKnowSub."},
)
task_adapter_names: Optional[list[str]] = field(
default=None,
init=False,
metadata={"help": "list of task-specific LoRA adapter names. It will be set in create_arrow_model()."},
)
gks_adapter_names: Optional[list[str]] = field(
default=None,
init=False,
metadata={
"help": "list of general LoRA adapter names for GenKnowSub. It will be set in create_arrow_model()."
},
)
rng_seed: Optional[int] = field(
default=None,
metadata={"help": "Optional RNG seed for reproducibility. If None, sampling is non-deterministic."},
)
def __post_init__(self):
if self.top_k <= 0:
raise ValueError("top_k cannot be negative.")
if self.router_temperature <= 0:
raise ValueError("router_temperature must be greater than 0.")
@dataclass
class EvaConfig:
"""
@ -610,6 +660,9 @@ class LoraConfig(PeftConfig):
)
},
)
arrow_config: Optional[ArrowConfig] = field(
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
)
def to_dict(self):
"""
@ -628,7 +681,6 @@ class LoraConfig(PeftConfig):
self.exclude_modules = (
set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
)
if isinstance(self.target_parameters, str):
raise TypeError("`target_parameters` must be a list of strings or None.")


@ -36,7 +36,7 @@ from peft.utils.integrations import (
from peft.utils.other import transpose
from peft.utils.warning import PeftWarning
from .config import LoraConfig
from .config import ArrowConfig, LoraConfig
VARIANT_KWARG_KEYS = ["alora_offsets"]
@ -206,6 +206,7 @@ class LoraLayer(BaseTunerLayer):
use_alora: bool = False,
use_qalora: bool = False,
lora_bias: bool = False,
arrow_config: ArrowConfig = None,
qalora_group_size: int = 32,
**kwargs,
):
@ -225,7 +226,11 @@ class LoraLayer(BaseTunerLayer):
)
lora_variant = self.resolve_lora_variant(
use_dora=use_dora, use_alora=use_alora, use_qalora=use_qalora, qalora_group_size=qalora_group_size
use_dora=use_dora,
use_alora=use_alora,
use_qalora=use_qalora,
qalora_group_size=qalora_group_size,
arrow_config=arrow_config,
)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant
@ -279,6 +284,17 @@ class LoraLayer(BaseTunerLayer):
self.set_adapter(self.active_adapters)
# Check for adapters that were added or removed from the arrow_model.
# The arrow model may be modified after creation by adding new experts
# (pre-trained or trainable) or by removing existing ones. Whenever such
# a change occurs, on_adapter_change() is called to update the set of
# active task-specific experts and, if needed, to handle recomputing prototypes
# and doing general knowledge subtraction (GKS) again.
if hasattr(self, "lora_arrow"):
for adapter in self.lora_variant:
if adapter in self.lora_arrow:
self.lora_arrow[adapter].on_adapter_change(self.lora_A, self.lora_B)
def reset_lora_parameters(self, adapter_name, init_lora_weights):
if init_lora_weights is False:
return
@ -639,6 +655,7 @@ class Linear(nn.Module, LoraLayer):
use_rslora: bool = False,
use_dora: bool = False,
use_alora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
) -> None:
@ -657,10 +674,18 @@ class Linear(nn.Module, LoraLayer):
use_dora=use_dora,
use_alora=use_alora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer
def resolve_lora_variant(self, *, use_dora: bool, use_alora: bool, **kwargs) -> Optional[LoraVariant]:
def resolve_lora_variant(
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
) -> Optional[LoraVariant]:
if arrow_config is not None:
from .variants import ArrowLinearVariant
return ArrowLinearVariant()
if not use_dora and not use_alora:
return None
@ -742,6 +767,7 @@ class Linear(nn.Module, LoraLayer):
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
@ -855,6 +881,7 @@ class Embedding(nn.Module, LoraLayer):
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
) -> None:
@ -876,6 +903,7 @@ class Embedding(nn.Module, LoraLayer):
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
@ -887,7 +915,17 @@ class Embedding(nn.Module, LoraLayer):
return DoraEmbeddingVariant()
def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias, **kwargs
self,
adapter_name,
r,
lora_alpha,
lora_dropout,
init_lora_weights,
use_rslora,
use_dora,
lora_bias,
arrow_config: ArrowConfig = None,
**kwargs,
):
# collect the kwargs
kwargs = locals().copy()
@ -896,7 +934,7 @@ class Embedding(nn.Module, LoraLayer):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant
@ -1129,6 +1167,7 @@ class _ConvNd(nn.Module, LoraLayer):
init_lora_weights: Union[bool, str] = True,
use_rslora: bool = False,
use_dora: bool = False,
arrow_config: ArrowConfig = None,
lora_bias: bool = False,
**kwargs,
) -> None:
@ -1158,10 +1197,21 @@ class _ConvNd(nn.Module, LoraLayer):
use_rslora=use_rslora,
use_dora=use_dora,
lora_bias=lora_bias,
arrow_config=arrow_config,
)
def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias, **kwargs
self,
adapter_name,
r,
lora_alpha,
lora_dropout,
init_lora_weights,
use_rslora,
use_dora,
lora_bias,
arrow_config: ArrowConfig = None,
**kwargs,
):
# collect the kwargs
kwargs = locals().copy()
@ -1177,7 +1227,7 @@ class _ConvNd(nn.Module, LoraLayer):
PeftWarning,
)
lora_variant = self.resolve_lora_variant(use_dora=use_dora)
lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
if lora_variant is not None:
self.lora_variant[adapter_name] = lora_variant
@ -1840,7 +1890,6 @@ class MultiheadAttention(nn.Module, LoraLayer):
class _LoraParameterProxy(nn.Module):
"""This proxies an `nn.Parameter` that is targeted with LoRA.
Intended to be used in conjunction with `nn.utils.parametrize`, see `ParamWrapper`.
"""
@ -1865,13 +1914,12 @@ def _register_parameter_or_buffer(module, name, X):
class ParamWrapper(nn.Module, LoraLayer):
"""A LoRA wrapper for `nn.Parameter`. This layer is dispatched if users target a parameter directly with
`lora_config.target_parameters`
Note:
- When accessing the wrapped nn.Parameter directly, e.g. via `module.weight`, the LoRA weights are *not* applied.
- It is currently not implemented to target multiple parameters on the same module. To achieve this, it is
currently required to create a separate LoRA adapter (with another adapter name) and activate both at the same
time.
Note:
- When accessing the wrapped nn.Parameter directly, e.g. via `module.weight`, the LoRA weights are *not*
applied.
- It is currently not implemented to target multiple parameters on the same module. To achieve this, it is
currently required to create a separate LoRA adapter (with another adapter name) and activate both at the
same time.
"""
def __init__(


@ -221,10 +221,12 @@ class LoraModel(BaseTuner):
"qalora_group_size": lora_config.qalora_group_size,
"ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
"lora_bias": lora_config.lora_bias,
"arrow_config": lora_config.arrow_config,
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
"parameter_name": parameter_name,
}
# for torchao merging, we need the get_apply_tensor_subclass from the quantization config
try:
kwargs["get_apply_tensor_subclass"] = operator.attrgetter(
@ -254,6 +256,7 @@ class LoraModel(BaseTuner):
use_rslora=lora_config.use_rslora,
use_dora=lora_config.use_dora,
lora_bias=lora_config.lora_bias,
arrow_config=lora_config.arrow_config,
)
else:
if isinstance(target, ParamWrapper) and (parameter_name == target.parameter_name):


@ -23,11 +23,112 @@ from torch import nn
from peft.utils.other import transpose
from .arrow import ArrowLoraLinearLayer
from .config import PeftConfig
from .dora import DoraConv1dLayer, DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer
from .layer import Conv1d, Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd
class ArrowLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs):
"""
Initialise the ArrowLoraLinearLayer() inside lora_arrow. lora_arrow is nn.ModuleDict(), serving as a container
for ArrowLoraLinearLayer(). A layer of the base model with LoRA adapter loaded on it will be like:
----------------------------------------------------
(qkv_proj): lora.Linear4bit or lora.Linear(
    (base_layer): Linear4bit or Linear
    (lora_dropout): ModuleDict( ... )
    (lora_A): ModuleDict( ... )
    (lora_B): ModuleDict( ... )
    (lora_embedding_A): ParameterDict( ... )
    (lora_embedding_B): ParameterDict( ... )
    (lora_magnitude_vector): ModuleDict( ... )
    (lora_arrow): ModuleDict(
        (arrow_router): ArrowLoraLinearLayer()
    )
)
----------------------------------------------------
Args:
module (Linear): LoRA Layer of the model, containing base_layer, lora_A, lora_B, etc.
adapter_name (str): name of the adapter that will be put in lora_arrow.
The adapter_name is "arrow_router" by default, set in create_arrow_model() in ./arrow.py
"""
# Checking for arrow necessary config
arrow_config = kwargs.get("arrow_config")
if arrow_config is None:
raise ValueError("ArrowLinearVariant.init() did not receive an arrow_config")
# 1-a) build the ArrowLoRALayer
arrow_layer = ArrowLoraLinearLayer(
in_features=module.in_features,
arrow_config=arrow_config,
).to(module.weight.device)
# 1-b) register a container if it doesn't exist yet
if not hasattr(module, "lora_arrow"):
module.lora_arrow = nn.ModuleDict()
module.lora_arrow[adapter_name] = arrow_layer
@staticmethod
def forward(
module: Linear,
*,
active_adapter: str,
x: torch.Tensor,
result: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""
Parameters mirror those in PEFT's `LoraVariant.forward`. Called every time the host Linear runs a forward pass.
build_prototypes() and gen_know_sub() should run only once before routing. Both are implemented in
ArrowLoraLinearLayer (see ./arrow.py). They are lazily invoked in the forward pass below. Attributes of
ArrowLoraLinearLayer() class ensure they execute only a single time.
Args:
module (Linear): LoRA Layer of the model
active_adapter (str): name of the Arrow router adapter, which must be active for Arrow routing to be applied.
x (torch.Tensor): input to the layer
result (torch.Tensor): output of the base layer.
Return value:
output of the base model + delta weight computed by arrow layer.
"""
arrow = module.lora_arrow[active_adapter] # ArrowLoraLinearLayer
# Apply GenKnowSub the first time, if applicable. Via ArrowLoraLinearLayer.on_adapter_change(),
# gen_know_sub() is re-run for adapters added after create_arrow_model().
arrow.gen_know_sub(module.lora_A, module.lora_B)
# Lazily build prototypes the first time, after GenKnowSub. Via ArrowLoraLinearLayer.on_adapter_change(),
# build_prototypes() is re-run for adapters added after create_arrow_model().
arrow.build_prototypes(module.lora_A, module.lora_B)
# Call the forward pass of ArrowLoraLinearLayer to perform the routing.
# Accept and ignore extra variant kwargs (e.g., 'alora_offsets') for compatibility
delta = arrow(
x,
lora_A=module.lora_A,
lora_B=module.lora_B,
dropout=module.lora_dropout[active_adapter],
scaling=module.scaling,
)
return result + delta
"""
Since Arrow is a Mixture-of-Experts (MoE) approach, merging adapters is not meaningful or even possible: for each
token, the top-k LoRA experts are dynamically selected and routed. Because of this per-token routing, there is no
single set of weights that can represent a merged adapter.
"""
@staticmethod
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")
@staticmethod
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")
@staticmethod
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
raise RuntimeError("Cannot unmerge an active Arrow router adapter. Remove it first.")
class DoraLinearVariant(LoraVariant):
@staticmethod
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:

tests/test_arrow.py

@ -0,0 +1,509 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os
from pathlib import Path
from unittest.mock import patch
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoModelForImageClassification
from peft import LoraConfig, get_peft_model
from peft.tuners.lora import ArrowConfig, create_arrow_model
from peft.tuners.lora.arrow import _resolve_adapter_source
from tests.testing_utils import hub_online_once
# ─── Fixtures ──────────────────────────────────────────────────────────
@pytest.fixture(scope="module")
def workdir(tmp_path_factory):
"""
Create a temp directory and chdir into it for the duration of the module.
"""
wd = tmp_path_factory.mktemp("arrow_workdir")
old_cwd = os.getcwd()
os.chdir(wd)
yield Path(wd)
os.chdir(old_cwd)
# (pytest will auto-delete wd)
def _create_and_save_adapter(out_dir: Path, rank: int = 4):
"""Helper: build a LoRA adapter around `model` and save into `out_dir`."""
# fan_in_fan_out is set to True for the GPT2 test model, to avoid a warning
cfg = LoraConfig(r=rank, target_modules=["c_attn"], fan_in_fan_out=True, init_lora_weights=False)
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = get_peft_model(model, cfg)
peft_model.save_pretrained(out_dir)
@pytest.fixture(scope="module")
def ts_adapters(workdir: Path):
"""
Build 3 task-specific adapters and return their absolute paths
"""
abs_paths = []
for i in range(3):
sub = f"{workdir}/ts{i}"
_create_and_save_adapter(sub)
abs_paths.append(sub)
return abs_paths
@pytest.fixture(scope="module")
def gen_adapter(workdir: Path):
"""Build 1 general-knowledge adapter and return its absolute path list."""
sub = f"{workdir}/gen0"
_create_and_save_adapter(sub)
return [sub] # list because create_arrow_model expects list
class TestArrowRouting:
def test_incompatible_rank_raises(self, workdir: Path):
"""
Adding adapters with different ranks must raise a ValueError.
"""
# Create two adapters with different ranks targeting the same modules
sub_r4 = workdir / "rank4"
sub_r8 = workdir / "rank8"
_create_and_save_adapter(sub_r4, rank=4)
_create_and_save_adapter(sub_r8, rank=8)
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base = AutoModelForCausalLM.from_pretrained(model_id)
# Expect create_arrow_model to raise due to rank mismatch
with pytest.raises(ValueError, match=r"rank mismatch"):
_ = create_arrow_model(
base_model=base,
task_specific_adapter_paths=[str(sub_r4), str(sub_r8)],
arrow_config=ArrowConfig(top_k=1),
)
def test_arrow_differs_with_extra_expert(self, ts_adapters):
"""
Arrow with 2 experts vs Arrow with 3 experts must produce different logits.
"""
# Arrow over first 2 experts
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
base_model_2 = copy.deepcopy(base_model_1)
cfg_small = ArrowConfig(top_k=2)
m_small = create_arrow_model(
base_model=base_model_1,
task_specific_adapter_paths=ts_adapters[:2],
arrow_config=cfg_small,
).eval()
# Arrow over all 3 experts
cfg_big = ArrowConfig(top_k=2)
m_big = create_arrow_model(
base_model=base_model_2,
task_specific_adapter_paths=ts_adapters,
arrow_config=cfg_big,
).eval()
x = torch.ones(1, 4, dtype=torch.long)
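            # A different expert pool changes the per-token routing mixture, so the logits should differ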
assert not torch.allclose(m_small(x).logits, m_big(x).logits)
def test_arrow_gks_with_load_adapter_later_with_forward(self, ts_adapters, gen_adapter):
"""
        Loading the last expert after creating the Arrow model should produce the same result as loading all the
        experts at once in create_arrow_model(), when a forward pass is run before adding the new adapter.
"""
# Arrow over all three experts
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
base_model_2 = copy.deepcopy(base_model_1)
cfg_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
m_big = create_arrow_model(
base_model=base_model_1,
task_specific_adapter_paths=ts_adapters,
general_adapter_paths=gen_adapter,
arrow_config=cfg_big,
).eval()
            # Arrow over the first 2 experts + loading the third expert later
cfg_small_later_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
m_small_later_big = create_arrow_model(
base_model=base_model_2,
task_specific_adapter_paths=ts_adapters[:2],
general_adapter_paths=gen_adapter,
arrow_config=cfg_small_later_big,
)
            # Ensure that prototype computation and GenKnowSub are performed once by running a forward pass
x = torch.ones(1, 4, dtype=torch.long)
m_small_later_big(x)
# Now loading the third expert
m_small_later_big.load_adapter(
model_id=ts_adapters[-1],
adapter_name="new_added_ts_expert",
)
            # Activate the new adapter and run a forward pass with it
m_small_later_big.set_adapter("new_added_ts_expert")
x = torch.ones(3, 5, dtype=torch.long)
m_small_later_big(x)
# Now we switch back to the arrow_router
m_small_later_big.set_adapter("arrow_router")
m_small_later_big.eval()
x = torch.ones(1, 4, dtype=torch.long)
assert torch.allclose(m_big(x).logits, m_small_later_big(x).logits)
def test_arrow_with_load_adapter_later_with_forward_activate_new(self, ts_adapters, gen_adapter):
"""
        Loading the last expert after creating the Arrow model and activating it should produce a different result
        compared to the case where arrow_router is active and the model routes tokens with Arrow.
"""
# Arrow over all three experts
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
base_model_2 = copy.deepcopy(base_model_1)
cfg_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
m_big = create_arrow_model(
base_model=base_model_1,
task_specific_adapter_paths=ts_adapters,
general_adapter_paths=gen_adapter,
arrow_config=cfg_big,
).eval()
            # Arrow over the first 2 experts + loading the third expert later
cfg_small_later_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
m_small_later_big = create_arrow_model(
base_model=base_model_2,
task_specific_adapter_paths=ts_adapters[:2],
general_adapter_paths=gen_adapter,
arrow_config=cfg_small_later_big,
)
            # Ensure that prototype computation and GenKnowSub are performed once by running a forward pass
x = torch.ones(1, 4, dtype=torch.long)
m_small_later_big(x)
# Now loading the third expert
m_small_later_big.load_adapter(
model_id=ts_adapters[-1],
adapter_name="new_added_ts_expert",
)
# The new adapter is activated
m_small_later_big.set_adapter("new_added_ts_expert")
m_small_later_big.eval()
x = torch.ones(1, 4, dtype=torch.long)
assert not torch.allclose(m_big(x).logits, m_small_later_big(x).logits)
def test_arrow_gks_with_load_adapter_later_without_forward(self, ts_adapters, gen_adapter):
"""
        Loading the last expert after creating the Arrow model should produce the same result as loading all the
        experts at once in create_arrow_model(), even when no forward pass is run before adding the new adapter.
"""
# Arrow over all three experts
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
base_model_2 = copy.deepcopy(base_model_1)
cfg_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
m_big = create_arrow_model(
base_model=base_model_1,
task_specific_adapter_paths=ts_adapters,
general_adapter_paths=gen_adapter,
arrow_config=cfg_big,
).eval()
            # Arrow over the first 2 experts + loading the third expert later
cfg_small_later_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
m_small_later_big = create_arrow_model(
base_model=base_model_2,
task_specific_adapter_paths=ts_adapters[:2],
general_adapter_paths=gen_adapter,
arrow_config=cfg_small_later_big,
)
# Now loading the third expert
m_small_later_big.load_adapter(
model_id=ts_adapters[-1],
adapter_name="new_added_ts_expert",
)
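            # No forward pass has run yet, so prototypes and GenKnowSub are applied lazily on the first forward
            # below, covering all three experts at once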
m_small_later_big.eval()
x = torch.ones(1, 4, dtype=torch.long)
assert torch.allclose(m_big(x).logits, m_small_later_big(x).logits)
def test_genknowsub_changes_output(self, ts_adapters, gen_adapter):
"""
Arrow+GenKnowSub vs plain Arrow must change logits.
"""
# Plain Arrow
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
base_model_2 = copy.deepcopy(base_model_1)
cfg_plain = ArrowConfig(top_k=2)
m_plain = create_arrow_model(
base_model=base_model_1,
task_specific_adapter_paths=ts_adapters,
arrow_config=cfg_plain,
).eval()
# Arrow + GenKnowSub
cfg_gks = ArrowConfig(top_k=2, use_gks=True)
m_gks = create_arrow_model(
base_model=base_model_2,
task_specific_adapter_paths=ts_adapters,
general_adapter_paths=gen_adapter,
arrow_config=cfg_gks,
).eval()
x = torch.ones(1, 4, dtype=torch.long)
assert not torch.allclose(m_plain(x).logits, m_gks(x).logits)
def test_merging_adapters_raise_error_in_arrow(self, ts_adapters):
"""
        Merging/unmerging is not allowed while an ArrowLoraLinearLayer is loaded on the model and active.
"""
# Arrow over first 2 experts
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base_model = AutoModelForCausalLM.from_pretrained(model_id)
cfg_small = ArrowConfig(top_k=2)
m_small = create_arrow_model(
base_model=base_model,
task_specific_adapter_paths=ts_adapters[:2],
arrow_config=cfg_small,
).eval()
with pytest.raises(RuntimeError, match=r"Cannot merge an active Arrow router adapter"):
m_small.merge_and_unload()
def test_conv2d_targets_raise_typeerror_in_arrow(self, workdir):
"""
Adapters applied to Conv2d must be rejected by create_arrow_model() which enforces Linear/Linear4bit-only
targets.
"""
model_id = "hf-internal-testing/tiny-random-ResNetForImageClassification"
with hub_online_once(model_id):
base = AutoModelForImageClassification.from_pretrained(model_id)
# Build a LoRA adapter targeting a Conv2d
cfg = LoraConfig(r=4, target_modules=["convolution"], init_lora_weights=False)
peft_model = get_peft_model(copy.deepcopy(base), cfg)
conv_dir = workdir / "cv0"
peft_model.save_pretrained(conv_dir)
# Expect create_arrow_model to raise TypeError
with pytest.raises(TypeError, match=r"LoRA adapters must only target Linear"):
_ = create_arrow_model(
base_model=base,
task_specific_adapter_paths=[str(conv_dir)],
arrow_config=ArrowConfig(top_k=1),
)
def test_arrow_forward_float16_no_autocast_with_merging(self, ts_adapters):
"""
Run Arrow in float16 with autocast disabled; forward should work, while merge/unmerge operations must raise for
Arrow models.
"""
import platform
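        # Probe whether the current backend supports float16 tensor creation; skip the test otherwise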
try:
_ = torch.zeros(1, dtype=torch.float16)
except Exception:
pytest.skip(reason="Test requires float16 support")
if platform.system() == "Darwin":
pytest.skip(reason="MacOS does not support multiple ops in float16")
model_id = "hf-internal-testing/tiny-random-gpt2"
# Create base in fp16 (no manual assignment to .dtype)
with hub_online_once(model_id):
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
cfg = ArrowConfig(top_k=2)
# Build Arrow model and disable adapter dtype autocast
model = create_arrow_model(
base_model=base,
task_specific_adapter_paths=ts_adapters,
arrow_config=cfg,
autocast_adapter_dtype=False,
torch_dtype=torch.float16,
).eval()
X = {
"input_ids": torch.ones(1, 4, dtype=torch.long),
"attention_mask": torch.ones(1, 4, dtype=torch.long),
}
# Forward should work in fp16
_ = model(**X)
# Merge must fail on Arrow models
with pytest.raises(RuntimeError, match=r"Cannot merge an active Arrow router adapter"):
model.merge_adapter(safe_merge=False)
with pytest.raises(RuntimeError, match=r"Cannot merge an active Arrow router adapter"):
_ = model.merge_and_unload()
def test_prototypes_not_recomputed_on_repeated_forward(self, ts_adapters):
"""
Repeated calls to forward should not recompute prototypes. We verify by spying on
ArrowLoraLinearLayer.top_right_singular_vec_from_BA(), which is only called when prototypes are (re)built.
"""
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base = AutoModelForCausalLM.from_pretrained(model_id)
cfg = ArrowConfig(top_k=2)
model = create_arrow_model(
base_model=base,
task_specific_adapter_paths=ts_adapters,
arrow_config=cfg,
).eval()
# Find one Arrow layer instance on the model
arrow_layer = None
for _, module in model.named_modules():
if hasattr(module, "lora_arrow") and "arrow_router" in module.lora_arrow:
arrow_layer = module.lora_arrow["arrow_router"]
break
assert arrow_layer is not None, "Arrow router layer not found on model"
x = torch.ones(1, 4, dtype=torch.long)
            # Spy on the internal prototype computation; it should run once per expert (E calls for E experts)
with patch.object(
arrow_layer,
"top_right_singular_vec_from_BA",
wraps=arrow_layer.top_right_singular_vec_from_BA,
) as spy:
_ = model(x)
first_calls = spy.call_count
assert first_calls == len(arrow_layer.task_adapter_names)
# Call forward again; prototypes should be cached, so no extra calls
_ = model(x)
assert spy.call_count == first_calls
def test_training_updates_when_task_adapter_active(ts_adapters):
"""
Ensure a simple training step works: compute a dummy loss, backward, and take an optimizer step. Verify that
task-adapter parameters update.
"""
model_id = "hf-internal-testing/tiny-random-gpt2"
with hub_online_once(model_id):
base = AutoModelForCausalLM.from_pretrained(model_id)
# Build Arrow model over two experts
cfg = ArrowConfig(top_k=2)
model = create_arrow_model(
base_model=base,
task_specific_adapter_paths=ts_adapters[:2],
arrow_config=cfg,
)
model.train()
# Switch to a specific task adapter for training (vanilla LoRA)
model.set_adapter("task_0")
# Choose a representative parameter to check updates (task_0 A weight)
rep_name = None
for n, _ in model.named_parameters():
if ".lora_A.task_0.weight" in n:
rep_name = n
break
assert rep_name is not None, "task_0 LoRA A weight not found"
rep_param = dict(model.named_parameters())[rep_name]
before = rep_param.detach().clone()
# Optimizer over trainable params (task_0 now active and trainable)
opt = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=1e-2)
# Dummy batch
vocab = model.config.vocab_size
input_ids = torch.randint(0, vocab, (2, 8))
attention_mask = torch.ones_like(input_ids)
# Compute loss and update
opt.zero_grad()
out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
assert hasattr(out, "loss") and out.loss is not None
out.loss.backward()
opt.step()
after = rep_param.detach().clone()
assert not torch.allclose(before, after), "Active task adapter parameters did not update after optimizer step"
@pytest.mark.parametrize(
"case",
[
"local_root",
"local_nested",
"hub_repo",
"hub_with_sub",
],
)
def test_resolve_adapter_source_variants(tmp_path: Path, case: str):
"""
Ensure `_resolve_adapter_source` correctly handles:
- Local dir (containing adapter_config.json)
- Local nested subfolder
- Hub repo id "user/repo"
- Hub repo with subfolder "user/repo/sub/folder"
"""
if case == "local_root":
d = tmp_path / "adapter_local_root"
d.mkdir(parents=True, exist_ok=True)
(d / "adapter_config.json").write_text("{}")
model_id, sub = _resolve_adapter_source(str(d))
assert model_id == str(d)
assert sub is None
elif case == "local_nested":
d = tmp_path / "repo_like" / "sub" / "folder"
d.mkdir(parents=True, exist_ok=True)
(d / "adapter_config.json").write_text("{}")
model_id, sub = _resolve_adapter_source(str(d))
assert model_id == str(d)
assert sub is None
elif case == "hub_repo":
model_id, sub = _resolve_adapter_source("user/repo")
assert model_id == "user/repo"
assert sub is None
elif case == "hub_with_sub":
model_id, sub = _resolve_adapter_source("user/repo/sub/folder")
assert model_id == "user/repo"
assert sub == "sub/folder"
else:
raise AssertionError(f"unknown case: {case}")

View File

@ -20,6 +20,7 @@ import unittest
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union
import numpy as np
@ -57,6 +58,7 @@ from transformers.pytorch_utils import Conv1D
from peft import (
AdaLoraConfig,
ArrowConfig,
EvaConfig,
LoftQConfig,
LoraConfig,
@ -67,6 +69,7 @@ from peft import (
RoadConfig,
TaskType,
VeraConfig,
create_arrow_model,
get_peft_model,
get_peft_model_state_dict,
initialize_lora_eva_weights,
@ -82,6 +85,7 @@ from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap
from peft.utils.loftq_utils import NFQuantizer
from peft.utils.other import fsdp_auto_wrap_policy
from tests.testing_utils import hub_online_once
from .testing_utils import (
device_count,
@ -5271,3 +5275,88 @@ class TestHotSwapping:
inductor_config_ctx = torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
with dynamo_config_ctx, inductor_config_ctx:
self.check_hotswap_diffusion(ranks=ranks, alpha_scalings=ranks, target_modules=target_modules)
# Test: 4-bit load + Arrow + generate
class TestArrowQuantized:
@pytest.fixture(scope="class")
def workdir(self, tmp_path_factory):
"""Create and return a temp directory path for this class (no chdir)."""
wd = tmp_path_factory.mktemp("arrow_workdir")
return Path(wd)
def _create_and_save_adapter_opt(self, out_dir: Path, rank: int = 4):
"""
        Build a randomly initialized LoRA adapter for OPT-125M and save it into `out_dir`. The adapter weights are
        randomly initialized (init_lora_weights=False), so no training is required here.
"""
model_id = "facebook/opt-125m"
# Target all linear layers so the adapter matches whatever we later quantize/load.
lora_cfg = LoraConfig(
r=rank,
target_modules="all-linear",
task_type="CAUSAL_LM",
init_lora_weights=False,
)
        # Attach the LoRA adapter to the base model and save it
with hub_online_once(model_id):
model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = get_peft_model(model, lora_cfg)
peft_model.save_pretrained(out_dir)
@pytest.fixture(scope="class")
def ts_adapters_opt(self, workdir: Path):
"""
Build 3 locally-saved task-specific adapters for OPT-125M and return their absolute paths.
"""
paths = []
for i in range(3):
sub = workdir / f"ts_expert_{i}"
self._create_and_save_adapter_opt(sub)
paths.append(str(sub))
return paths
@require_bitsandbytes
@pytest.mark.single_gpu_tests
def test_arrow_4bit_opt125m_load_and_generate_with_local_adapters(self, ts_adapters_opt):
        # Skip if CUDA isn't available (bitsandbytes is required via @require_bitsandbytes above)
if not torch.cuda.is_available():
pytest.skip("CUDA required for 4-bit bitsandbytes test.")
model_id = "facebook/opt-125m"
# Quantization config (nf4, bf16 compute)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=False,
)
with hub_online_once(model_id):
# Load quantized base model
base_model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config,
)
with hub_online_once(model_id + "tokenizer"):
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
# Build Arrow model from the locally created adapters
arrow_cfg = ArrowConfig(top_k=2, router_temperature=1.0, rng_seed=42)
model = create_arrow_model(
base_model=base_model,
task_specific_adapter_paths=ts_adapters_opt, # local dirs (each has adapter_config.json)
arrow_config=arrow_cfg,
).eval()
# Quick generate smoke test
inputs = tok("Hello world", return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
with torch.no_grad():
out = model.generate(**inputs, max_new_tokens=8)
assert out is not None
assert out.shape[0] == 1 # batch size 1