mirror of
https://github.com/huggingface/peft.git
synced 2025-10-20 15:33:48 +08:00
Add Arrow + GenKnowSub to LoRA (#2644)
This PR adds support for Arrow, a modular routing mechanism for LoRA experts introduced here, as well as the refinement method GenKnowSub, proposed in our ACL 2025 Main Conference paper. GenKnowSub enhances Arrow by subtracting a general-domain LoRA from task-specific ones prior to routing, leading to improved generalisation and modularity.
This commit is contained in:
committed by
GitHub
parent
ed5c6eaa1a
commit
42db980676
@ -687,3 +687,148 @@ Using this feature has some drawbacks, namely:
|
||||
- Increase the batch size.
|
||||
- Try to avoid having a large number of different adapters in the same batch, prefer homogeneous batches. This can be achieved by buffering samples with the same adapter and only perform inference with a small handful of different adapters.
|
||||
- Take a look at alternative implementations such as [LoRAX](https://github.com/predibase/lorax), [punica](https://github.com/punica-ai/punica), or [S-LoRA](https://github.com/S-LoRA/S-LoRA), which are specialized to work with a large number of different adapters.
|
||||
|
||||
## Composing and Reusing LoRA Adapters
|
||||
### Arrow
|
||||
[Arrow](https://huggingface.co/papers/2405.11157) is a modular routing algorithm designed to combine multiple pre-trained task-specific LoRA adapters to solve a given task. Rather than merging all adapters naively, Arrow introduces a **gradient-free, token-wise mixture-of-experts (MoE) routing mechanism**. At inference time, it first computes a _prototype_ for each LoRA by extracting the top right singular vector from its SVD decomposition. Each token representation is then compared to these prototypes via cosine similarity to obtain routing coefficients. Tokens are assigned to the top-k most relevant LoRA adapters, with the coefficients normalized through softmax, and their outputs linearly combined. This allows effective reuse of existing LoRA modules for new tasks and leads to stronger zero-shot generalization.
|
||||
|
||||
In PEFT, Arrow is enabled through ```ArrowConfig``` and ```create_arrow_model```. You can also configure parameters such as ```top_k``` (the number of LoRA adapters combined per token), ```router_temperature``` (the softmax temperature applied to the routing coefficients), and ```rng_seed``` (for reproducibility).
|
||||
|
||||
```py
|
||||
from peft import create_arrow_model, ArrowConfig
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Loading the model
|
||||
base_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Creating the Arrow config
|
||||
arrow_config = ArrowConfig(
|
||||
top_k=3,
|
||||
router_temperature=1.0,
|
||||
rng_seed=42,
|
||||
)
|
||||
|
||||
# The LoRA adapters below were trained on a clustered FLAN dataset.
|
||||
# Task clustering was performed using the Model-Based Clustering (MBC) method,
|
||||
# as described in the Arrow paper.
|
||||
# While one could train a separate LoRA for each task and let Arrow route tokens among them,
|
||||
# training LoRAs on clusters of tasks instead provides an indirect optimization for
|
||||
# transfer across the multi-task dataset.
|
||||
task_specific_adapter_paths = [
|
||||
f"TahaBa/phi3-mini-clustered-flan/ts_expert_{i}" for i in range(10)
|
||||
]
|
||||
|
||||
# Creating the Arrow model
|
||||
model = create_arrow_model(
|
||||
base_model=base_model,
|
||||
task_specific_adapter_paths=task_specific_adapter_paths,
|
||||
arrow_config=arrow_config,
|
||||
)
|
||||
|
||||
# Now the forward path could be called on this model, like a normal PeftModel.
|
||||
```
|
||||
|
||||
Furthermore, you can add or remove adapters after calling ```create_arrow_model```—for example, to fine-tune a new adapter or discard an unnecessary one. Once the adapters are in place, you can activate the ```"arrow_router"``` for inference to use Arrow. Note that if you add a new LoRA adapter after ```create_arrow_model``` and want to fine-tune it, you must explicitly set the new adapter as active, since ```"arrow_router"``` is activated by default in ```create_arrow_model```.
|
||||
|
||||
```py
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
# Adding a new adapter and activating it
|
||||
model.add_adapter(adapter_name='new_adapter')
|
||||
model.set_adapter('new_adapter')
|
||||
|
||||
# Now the model could be trained along the `new_adapter`.
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=SFTConfig(...),
|
||||
...
|
||||
)
|
||||
|
||||
# Once the training is done, you can activate `arrow_router` and use it in inference
|
||||
model.set_adapter('arrow_router') # Model is ready to be used at inference time now
|
||||
```
|
||||
|
||||
### GenKnowSub
|
||||
[GenKnowSub](https://aclanthology.org/2025.acl-short.54/) augments Arrow by purifying task-specific LoRA adapters before routing. The key idea is to subtract general knowledge encoded in LoRA space—based on the [forgetting-via-negation principle](https://huggingface.co/papers/2212.04089)—so that task adapters become more isolated and focused on task-relevant signals. Concretely, GenKnowSub estimates a low-dimensional “general” subspace from a set of general (non task-specific) LoRA adapters and removes this component from each task adapter’s LoRA update prior to Arrow’s token-wise routing. This typically improves compositionality and reduces interference when combining many task adapters.
|
||||
|
||||
In PEFT, enable GenKnowSub by setting ```use_gks=True``` in ArrowConfig, and providing ```general_adapter_paths``` in ```create_arrow_model```:
|
||||
|
||||
```py
|
||||
from peft import create_arrow_model, ArrowConfig
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
# Loading the model
|
||||
base_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
|
||||
|
||||
# Creating the Arrow config
|
||||
arrow_config = ArrowConfig(
|
||||
top_k=3,
|
||||
router_temperature=1.0,
|
||||
use_gks=True,
|
||||
rng_seed=42,
|
||||
)
|
||||
|
||||
# Path to task-specific, trained on flan clustered dataset (as we explained before.)
|
||||
task_specific_adapter_paths = [
|
||||
f"TahaBa/phi3-mini-clustered-flan/ts_expert_{i}" for i in range(10)
|
||||
]
|
||||
# These general adapters are trained on English, German, and French Wikipedia dataset,
|
||||
# with causal language modelling objective, each pair like: (507 token tsentence, 5 token completion), and the loss computed on the completion
|
||||
general_adapter_paths = [
|
||||
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langen/checkpoint-17",
|
||||
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langfr/checkpoint-35",
|
||||
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langger/checkpoint-17"
|
||||
]
|
||||
|
||||
# Creating the Arrow model
|
||||
model = create_arrow_model(
|
||||
base_model=base_model,
|
||||
task_specific_adapter_paths=task_specific_adapter_paths,
|
||||
general_adapter_paths=general_adapter_paths,
|
||||
arrow_config=arrow_config,
|
||||
)
|
||||
|
||||
# Now the forward path could be called on this model, like a normal PeftModel.
|
||||
```
|
||||
To encode general knowledge, GenKnowSub subtracts the average of the provided general adapters from each task-specific adapter once, before routing begins. Furthermore, the ability to add or remove adapters after calling ```create_arrow_model``` (as described in the Arrow section) is still supported in this case.
|
||||
|
||||
<Tip>
|
||||
|
||||
**Things to keep in mind when using Arrow + GenKnowSub:**
|
||||
|
||||
- All LoRA adapters (task-specific and general) must share the same ```rank``` and ```target_modules```.
|
||||
|
||||
- Any inconsistency in these settings will raise an error in ```create_arrow_model```.
|
||||
|
||||
- Having different scaling factors (```lora_alpha```) across task adapters is supported — Arrow handles them automatically.
|
||||
|
||||
- Merging the ```"arrow_router"``` is not supported, due to its dynamic routing behavior.
|
||||
|
||||
- In create_arrow_model, task adapters are loaded as ```task_i``` and general adapters as ```gks_j``` (where ```i``` and ```j``` are indices). The function ensures consistency of ```target_modules```, ```rank```, and whether adapters are applied to ```Linear``` or ```Linear4bit``` layers. It then adds the ```"arrow_router"``` module and activates it. Any customization of this process requires overriding ```create_arrow_model```.
|
||||
|
||||
- This implementation is compatible with 4-bit quantization (via bitsandbytes):
|
||||
|
||||
```py
|
||||
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
|
||||
import torch
|
||||
|
||||
# Quantisation config
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=False,
|
||||
)
|
||||
|
||||
# Loading the model
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config,
|
||||
)
|
||||
|
||||
# Now call create_arrow_model() as we explained before.
|
||||
```
|
||||
|
||||
</Tip>
|
@ -32,6 +32,10 @@ The abstract from the paper is:
|
||||
|
||||
## Utility
|
||||
|
||||
### ArrowConfig
|
||||
|
||||
[[autodoc]] tuners.lora.config.ArrowConfig
|
||||
|
||||
### LoftQ
|
||||
|
||||
[[autodoc]] utils.loftq_utils.replace_lora_weights_loftq
|
||||
|
375
examples/arrow_multitask/arrow_phi3_mini.py
Normal file
375
examples/arrow_multitask/arrow_phi3_mini.py
Normal file
@ -0,0 +1,375 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script provides a simple evaluation pipeline for multiple-choice reasoning datasets
|
||||
(e.g., BoolQ, HellaSwag, ARC, OpenBookQA, Winogrande) with different composition strategies.
|
||||
|
||||
Usage examples:
|
||||
python arrow_phi3_mini.py --strategy base --ds_name arc-challenge
|
||||
python arrow_phi3_mini.py --strategy arrow --ds_name boolq
|
||||
python arrow_phi3_mini.py --strategy gks --ds_name hswag
|
||||
|
||||
Key features:
|
||||
- Supports three strategies:
|
||||
• "base" → Evaluate the quantized base model directly
|
||||
• "arrow" → Use Arrow modular routing with task-specific adapters
|
||||
• "gks" → Use Arrow + GenKnowSub (subtracting general-domain knowledge)
|
||||
- Loads evaluation datasets from the Hugging Face Hub
|
||||
- Implements a batched evaluation loop that computes per-option likelihoods and selects
|
||||
the answer with the lowest average loss
|
||||
- Reports simple accuracy
|
||||
|
||||
Implementation details:
|
||||
- The base model is quantized to 4-bit using `BitsAndBytesConfig` (nf4, bf16 compute).
|
||||
- For Arrow and GKS, task-specific adapters are loaded from the Hugging Face Hub:
|
||||
TahaBa/phi3-mini-clustered-flan/ts_expert_i
|
||||
- Task-specific adapters were trained on 10 clusters of FLAN tasks.
|
||||
- The clusters were created using Model-Based Clustering (MBC):
|
||||
1. Train a LoRA adapter for each individual task.
|
||||
2. Apply k-means clustering to group tasks based on these adapters.
|
||||
3. Train a LoRA adapter for each resulting cluster.
|
||||
For more details, see the Arrow paper: https://huggingface.co/papers/2405.11157
|
||||
|
||||
- For GKS, general adapters are loaded from:
|
||||
TahaBa/phi3-mini-general-adapters/...
|
||||
- These adapters were trained on English, French, and German Wikipedia data
|
||||
using a causal language modeling objective with (507-token context → 5-token completion) pairs.
|
||||
- This setup encodes general knowledge into the LoRA space, which can then be
|
||||
subtracted from task-specific adapters during inference to isolate and purify them.
|
||||
For more details, see the GenKnowSub paper: https://huggingface.co/papers/2505.10939
|
||||
|
||||
- `evaluate_on_multi_choice_batched` handles tokenization, masking context tokens,
|
||||
and computing per-choice log-likelihoods for fair comparison.
|
||||
- Accuracy is printed at the end for the selected dataset.
|
||||
|
||||
This script is mainly meant for demonstration purposes and lightweight evaluation,
|
||||
not full-scale benchmarking (batch size / max length can be tuned).
|
||||
|
||||
=======================================================================================
|
||||
|
||||
Results (evaluated with microsoft/Phi-3-mini-4k-instruct, 4-bit quantization):
|
||||
|
||||
| Dataset | Base Acc. | Arrow Acc. | Arrow+GKS Acc. |
|
||||
|--------------|-----------|------------|----------------|
|
||||
| ARC-Challenge| 0.4515 | 0.5418 | 0.5585 |
|
||||
| ARC-Easy | 0.6894 | 0.8404 | 0.8473 |
|
||||
| Winogrande | 0.5769 | 0.6550 | 0.6724 |
|
||||
| BoolQ | 0.8146 | 0.8030 | 0.8247 |
|
||||
| OpenBookQA | 0.43 | 0.448 | 0.472 |
|
||||
| HellaSwag | 0.7318 | 0.7150 | 0.7376 |
|
||||
|
||||
Observations:
|
||||
- Arrow generally improves over the base model by routing tokens to the most relevant task adapters.
|
||||
- Applying GKS (general knowledge subtraction) consistently gives further gains compared to Arrow and Base.
|
||||
|
||||
These numbers are not meant as leaderboard results, but as a sanity check
|
||||
to verify that the implementation works as expected and demonstrates
|
||||
the benefits of Arrow and GenKnowSub.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from sklearn.metrics import accuracy_score
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
|
||||
from peft import ArrowConfig, create_arrow_model
|
||||
|
||||
|
||||
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
|
||||
MODEL_MAX_LEN = 2048
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Training script with strategy selection")
|
||||
|
||||
parser.add_argument(
|
||||
"--strategy",
|
||||
type=str,
|
||||
choices=["base", "arrow", "gks"],
|
||||
default="base",
|
||||
help="Training strategy to use: base, arrow, or gks",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ds_name",
|
||||
type=str,
|
||||
choices=["boolq", "hswag", "arc-easy", "arc-challenge", "oqa", "wg"],
|
||||
default="arc-challenge",
|
||||
help="Dataset to use: boolq, hswag, arc-easy, arc-challenge, oqa, wg",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def read_test_dataset(ds_name):
|
||||
if ds_name == "boolq":
|
||||
ds = load_dataset("google/boolq", split="validation", trust_remote_code=True)
|
||||
elif ds_name == "hswag":
|
||||
ds = load_dataset("Rowan/hellaswag", split="validation", trust_remote_code=True)
|
||||
elif ds_name == "arc-challenge":
|
||||
ds = load_dataset("allenai/ai2_arc", "ARC-Challenge", split="validation", trust_remote_code=True)
|
||||
elif ds_name == "arc-easy":
|
||||
ds = load_dataset("allenai/ai2_arc", "ARC-Easy", split="validation", trust_remote_code=True)
|
||||
elif ds_name == "oqa":
|
||||
ds = load_dataset("allenai/openbookqa", split="validation", trust_remote_code=True)
|
||||
elif ds_name == "wg":
|
||||
ds = load_dataset("allenai/winogrande", "winogrande_xl", split="validation", trust_remote_code=True)
|
||||
else:
|
||||
raise f"Dataset {ds_name} is not supported yet."
|
||||
|
||||
return ds
|
||||
|
||||
|
||||
def extract_input_content(ds_name, row):
|
||||
if ds_name == "boolq":
|
||||
return f"[passage]{row['passage']}[question]{row['question']}"
|
||||
if ds_name == "hswag":
|
||||
return row["ctx"]
|
||||
if (ds_name == "arc-challenge") or (ds_name == "arc-easy"):
|
||||
return row["question"]
|
||||
if ds_name == "oqa":
|
||||
return row["question_stem"]
|
||||
if ds_name == "wg":
|
||||
return row["sentence"]
|
||||
|
||||
|
||||
def create_multi_choice_options(row, ds_name):
|
||||
options_texts = []
|
||||
content = extract_input_content(ds_name, row)
|
||||
if ds_name == "boolq":
|
||||
choices = ["true", "false"]
|
||||
if ds_name == "hswag":
|
||||
choices = row["endings"]
|
||||
if (ds_name == "arc-challenge") or (ds_name == "arc-easy"):
|
||||
choices = row["choices"]["text"]
|
||||
if ds_name == "wg":
|
||||
choices = [row["option1"], row["option2"]]
|
||||
if ds_name == "oqa":
|
||||
choices = row["choices"]["text"]
|
||||
|
||||
for choice in choices:
|
||||
options_texts.append(f"<|user|>\n{content}<|end|>\n<|assistant|>{choice}<|end|>\n")
|
||||
|
||||
return options_texts
|
||||
|
||||
|
||||
def extract_multi_choice_target_index(row, ds_name):
|
||||
if ds_name == "boolq":
|
||||
return 0 if row["answer"] is True else 1
|
||||
if ds_name == "hswag":
|
||||
return int(row["label"])
|
||||
if (ds_name == "arc-challenge") or (ds_name == "arc-easy"):
|
||||
return row["choices"]["label"].index(row["answerKey"])
|
||||
if ds_name == "wg":
|
||||
return int(row["answer"]) - 1
|
||||
if ds_name == "oqa":
|
||||
return row["choices"]["label"].index(row["answerKey"])
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def compute_loglike_loss(logits, labels, reduction="none"):
|
||||
bs = logits.size(0)
|
||||
vocab_size = logits.size(-1)
|
||||
labels = labels.squeeze(-1)
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
# Flatten the tokens
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction=reduction)
|
||||
shift_logits = shift_logits.view(-1, vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
# reshape back
|
||||
if reduction == "none":
|
||||
loss = loss.view((bs, -1))
|
||||
non_zero_loss = (loss != 0).sum(dim=-1)
|
||||
non_zero_loss[non_zero_loss == 0] = 1
|
||||
loss = loss.sum(dim=-1) / non_zero_loss
|
||||
|
||||
return loss.float() # Convert to float32 before returning
|
||||
|
||||
|
||||
def evaluate_on_multi_choice_batched(
|
||||
eval_dataset, model, tokenizer, ds_name, labels, predictions, args, batch_size=32, max_length=512, device="cuda"
|
||||
):
|
||||
# Local import to mirror your original function
|
||||
model.eval()
|
||||
|
||||
for start in tqdm(
|
||||
range(0, len(eval_dataset), batch_size), total=(len(eval_dataset) + batch_size - 1) // batch_size
|
||||
):
|
||||
rows = [eval_dataset[i] for i in range(start, min(start + batch_size, len(eval_dataset)))]
|
||||
|
||||
# Build the flattened option texts for this batch
|
||||
all_texts = []
|
||||
options_per_sample = [] # number of options for each sample
|
||||
ctx_lens_per_option = [] # context length replicated per option
|
||||
|
||||
for row in rows:
|
||||
# options: ["<|user|>...<|assistant|>choiceA<|end|>", ...]
|
||||
options = create_multi_choice_options(row, ds_name)
|
||||
options_per_sample.append(len(options))
|
||||
|
||||
# compute context length once per sample (align with your -1 shift)
|
||||
content = extract_input_content(ds_name, row)
|
||||
context_prompt = f"<|user|>\n{content}<|end|>\n<|assistant|>"
|
||||
ctx_len = len(tokenizer.encode(context_prompt)) - 1
|
||||
|
||||
all_texts.extend(options)
|
||||
ctx_lens_per_option.extend([ctx_len] * len(options))
|
||||
|
||||
# collect gold label
|
||||
labels.append(extract_multi_choice_target_index(row, ds_name))
|
||||
|
||||
# Tokenize all options in one go
|
||||
tokenized = tokenizer(
|
||||
all_texts,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_length,
|
||||
)
|
||||
tokenized = {k: v.to(device) for k, v in tokenized.items()}
|
||||
|
||||
# Create masked labels: ignore context and padding
|
||||
masked_labels = tokenized["input_ids"].clone()
|
||||
for i, ctx_len in enumerate(ctx_lens_per_option):
|
||||
masked_labels[i, :ctx_len] = -100
|
||||
masked_labels[tokenized["attention_mask"] == 0] = -100
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_ids=tokenized["input_ids"], attention_mask=tokenized["attention_mask"]).logits
|
||||
# per-sequence losses
|
||||
losses = compute_loglike_loss(logits, masked_labels, reduction="none").detach().cpu()
|
||||
|
||||
# Reduce per sample (argmin across its options)
|
||||
idx = 0
|
||||
for n_opt in options_per_sample:
|
||||
pred = torch.argmin(losses[idx : idx + n_opt]).item()
|
||||
predictions.append(pred)
|
||||
idx += n_opt
|
||||
|
||||
print(
|
||||
f"Accuracy for dataset {args.ds_name} and strategy {args.strategy} is: {accuracy_score(labels, predictions)}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
print(f"Selected strategy: {args.strategy}")
|
||||
print(f"Dataset name: {args.ds_name}")
|
||||
|
||||
# Loading the tokeniser
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
MODEL_NAME,
|
||||
use_fast=True,
|
||||
padding_side="right",
|
||||
model_max_length=MODEL_MAX_LEN,
|
||||
)
|
||||
|
||||
# Quantisation config
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=False,
|
||||
)
|
||||
|
||||
# Loading the model
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_NAME,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config,
|
||||
)
|
||||
|
||||
# Loading the test dataset
|
||||
test_dataset = read_test_dataset(args.ds_name)
|
||||
print(f"{args.ds_name} is loaded with size: {len(test_dataset)}.")
|
||||
|
||||
labels, predictions = [], []
|
||||
if args.strategy == "base":
|
||||
# Batch-wise inference
|
||||
with torch.no_grad():
|
||||
evaluate_on_multi_choice_batched(
|
||||
test_dataset,
|
||||
base_model,
|
||||
tokenizer,
|
||||
args.ds_name,
|
||||
labels,
|
||||
predictions,
|
||||
args,
|
||||
batch_size=64, # tune this
|
||||
max_length=512, # tune if options are long
|
||||
device="cuda",
|
||||
)
|
||||
else:
|
||||
general_adapter_paths = []
|
||||
if args.strategy == "gks":
|
||||
arrow_config = ArrowConfig(
|
||||
top_k=3,
|
||||
router_temperature=1.0,
|
||||
use_gks=True,
|
||||
)
|
||||
# General adapter paths from the hub
|
||||
general_adapter_paths = [
|
||||
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langen/checkpoint-17",
|
||||
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langfr/checkpoint-35",
|
||||
"TahaBa/phi3-mini-general-adapters/cluster0_batch16_prop1.0_langger/checkpoint-17",
|
||||
]
|
||||
else:
|
||||
arrow_config = ArrowConfig(
|
||||
top_k=3,
|
||||
router_temperature=1.0,
|
||||
)
|
||||
|
||||
# Task-specific adapter paths from the hub
|
||||
task_specific_adapter_paths = [f"TahaBa/phi3-mini-clustered-flan/ts_expert_{i}" for i in range(10)]
|
||||
|
||||
# Creating the Arrow model
|
||||
model = create_arrow_model(
|
||||
base_model=base_model,
|
||||
task_specific_adapter_paths=task_specific_adapter_paths,
|
||||
general_adapter_paths=general_adapter_paths,
|
||||
arrow_config=arrow_config,
|
||||
)
|
||||
|
||||
# Batch-wise inference
|
||||
with torch.no_grad():
|
||||
evaluate_on_multi_choice_batched(
|
||||
test_dataset,
|
||||
model,
|
||||
tokenizer,
|
||||
args.ds_name,
|
||||
labels,
|
||||
predictions,
|
||||
args,
|
||||
batch_size=32, # tune this
|
||||
max_length=512, # tune if options are long
|
||||
device="cuda",
|
||||
)
|
8
examples/arrow_multitask/requirements.txt
Normal file
8
examples/arrow_multitask/requirements.txt
Normal file
@ -0,0 +1,8 @@
|
||||
torch
|
||||
transformers
|
||||
accelerate
|
||||
datasets
|
||||
scikit-learn
|
||||
tqdm
|
||||
numpy
|
||||
bitsandbytes
|
@ -50,6 +50,7 @@ from .tuners import (
|
||||
AdaLoraModel,
|
||||
AdaptionPromptConfig,
|
||||
AdaptionPromptModel,
|
||||
ArrowConfig,
|
||||
BOFTConfig,
|
||||
BOFTModel,
|
||||
BoneConfig,
|
||||
@ -105,6 +106,7 @@ from .tuners import (
|
||||
VeraModel,
|
||||
XLoraConfig,
|
||||
XLoraModel,
|
||||
create_arrow_model,
|
||||
get_eva_state_dict,
|
||||
initialize_lora_eva_weights,
|
||||
)
|
||||
@ -134,6 +136,7 @@ __all__ = [
|
||||
"AdaLoraModel",
|
||||
"AdaptionPromptConfig",
|
||||
"AdaptionPromptModel",
|
||||
"ArrowConfig",
|
||||
"AutoPeftModel",
|
||||
"AutoPeftModelForCausalLM",
|
||||
"AutoPeftModelForFeatureExtraction",
|
||||
@ -212,6 +215,7 @@ __all__ = [
|
||||
"XLoraModel",
|
||||
"bloom_model_postprocess_past_key_value",
|
||||
"cast_mixed_precision_params",
|
||||
"create_arrow_model",
|
||||
"get_eva_state_dict",
|
||||
"get_layer_status",
|
||||
"get_model_status",
|
||||
|
@ -25,11 +25,13 @@ from .ln_tuning import LNTuningConfig, LNTuningModel
|
||||
from .loha import LoHaConfig, LoHaModel
|
||||
from .lokr import LoKrConfig, LoKrModel
|
||||
from .lora import (
|
||||
ArrowConfig,
|
||||
EvaConfig,
|
||||
LoftQConfig,
|
||||
LoraConfig,
|
||||
LoraModel,
|
||||
LoraRuntimeConfig,
|
||||
create_arrow_model,
|
||||
get_eva_state_dict,
|
||||
initialize_lora_eva_weights,
|
||||
)
|
||||
@ -55,6 +57,7 @@ __all__ = [
|
||||
"AdaLoraModel",
|
||||
"AdaptionPromptConfig",
|
||||
"AdaptionPromptModel",
|
||||
"ArrowConfig",
|
||||
"BOFTConfig",
|
||||
"BOFTModel",
|
||||
"BoneConfig",
|
||||
@ -112,6 +115,7 @@ __all__ = [
|
||||
"VeraModel",
|
||||
"XLoraConfig",
|
||||
"XLoraModel",
|
||||
"create_arrow_model",
|
||||
"get_eva_state_dict",
|
||||
"initialize_lora_eva_weights",
|
||||
]
|
||||
|
@ -15,7 +15,8 @@
|
||||
from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available
|
||||
from peft.utils import register_peft_method
|
||||
|
||||
from .config import EvaConfig, LoftQConfig, LoraConfig, LoraRuntimeConfig
|
||||
from .arrow import create_arrow_model
|
||||
from .config import ArrowConfig, EvaConfig, LoftQConfig, LoraConfig, LoraRuntimeConfig
|
||||
from .eva import get_eva_state_dict, initialize_lora_eva_weights
|
||||
from .gptq import GPTQLoraLinear
|
||||
from .layer import Conv2d, Conv3d, Embedding, Linear, LoraLayer, ParamWrapper
|
||||
@ -23,6 +24,7 @@ from .model import LoraModel
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ArrowConfig",
|
||||
"Conv2d",
|
||||
"Conv3d",
|
||||
"Embedding",
|
||||
@ -35,6 +37,7 @@ __all__ = [
|
||||
"LoraModel",
|
||||
"LoraRuntimeConfig",
|
||||
"ParamWrapper",
|
||||
"create_arrow_model",
|
||||
"get_eva_state_dict",
|
||||
"initialize_lora_eva_weights",
|
||||
]
|
||||
|
476
src/peft/tuners/lora/arrow.py
Normal file
476
src/peft/tuners/lora/arrow.py
Normal file
@ -0,0 +1,476 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from .config import ArrowConfig
|
||||
|
||||
|
||||
TASK_ADAPTER_PREFIX = "task_"
|
||||
GKS_ADAPTER_PREFIX = "gks_"
|
||||
|
||||
|
||||
class ArrowLoraLinearLayer(nn.Module):
|
||||
"""
|
||||
This class represent the main logic of the arrow routing algorithm for linear layers.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, arrow_config):
|
||||
super().__init__()
|
||||
# extra parameters needed for arrow
|
||||
self.in_features = in_features
|
||||
self._protos_ready = False
|
||||
self.top_k = arrow_config.top_k
|
||||
self.temperature = arrow_config.router_temperature
|
||||
self.rng_seed = arrow_config.rng_seed
|
||||
self.task_adapter_names = (
|
||||
arrow_config.task_adapter_names.copy()
|
||||
) # Set in create_arrow_model() with this format: task_0, task_1, ...
|
||||
self.gks_adapter_names = (
|
||||
arrow_config.gks_adapter_names
|
||||
) # Set in create_arrow_model() with this format: gks_0, gks_1, ...
|
||||
self.use_gks = arrow_config.use_gks
|
||||
self.gks_done = False
|
||||
self.gks_added_adapter_names = []
|
||||
self.in_features = in_features
|
||||
self.cast_input_dtype_enabled = True
|
||||
|
||||
@torch.no_grad()
|
||||
def on_adapter_change(self, lora_A, lora_B):
|
||||
"""
|
||||
Called when adapters are added/removed/renamed so Arrow can refresh its internal state before the next forward
|
||||
pass.
|
||||
"""
|
||||
all_ts_adapter_names = [
|
||||
k
|
||||
for k in lora_A.keys()
|
||||
if k in lora_B and k != "arrow_router" and not (k.startswith("gks_") and k[len("gks_") :].isdigit())
|
||||
]
|
||||
|
||||
if sorted(self.task_adapter_names) == sorted(all_ts_adapter_names): # No changes in the ts_adapters
|
||||
return
|
||||
|
||||
# Getting the name(s) of added adapter(s)
|
||||
if len(self.task_adapter_names) < len(all_ts_adapter_names): # Adapter(s) are added.
|
||||
self.gks_added_adapter_names = [x for x in all_ts_adapter_names if x not in self.task_adapter_names]
|
||||
|
||||
# Updating the task_adapter_names
|
||||
self.task_adapter_names = all_ts_adapter_names.copy()
|
||||
# Invalidate caches so they’ll be rebuilt lazily on next forward()
|
||||
self._protos_ready = False
|
||||
# GKS will be handled by self.gks_added_adapter_names
|
||||
|
||||
def top_right_singular_vec_from_BA(self, A, B, iters=15, eps=1e-8):
|
||||
"""
|
||||
Computes the top *right* singular vector of ΔW = B @ A without forming ΔW.
|
||||
|
||||
Theory:
|
||||
For any matrix M, the right singular vectors are the eigenvectors of Mᵀ M. If ΔW = B @ A (with A ∈
|
||||
ℝ^{r×in}, B ∈ ℝ^{out×r}), then
|
||||
ΔWᵀ ΔW = (B @ A)ᵀ (B @ A) = Aᵀ (Bᵀ B) A ∈ ℝ^{in×in}.
|
||||
Therefore, the dominant right singular vector of ΔW is the dominant eigenvector of M := Aᵀ (Bᵀ B) A. We
|
||||
find it by *power iteration* on the linear operator
|
||||
v ↦ Aᵀ (Bᵀ B) (A v),
|
||||
which avoids materializing ΔW (out×in) or M (in×in). The result lives in the input/token space (size =
|
||||
in_features), which is exactly what Arrow needs. (Right singular vectors ≡ eigenvectors of MᵀM; power
|
||||
iteration converges to the dominant eigenvector under mild conditions.)
|
||||
=============================== Practical notes:
|
||||
- We perform all iteration in float32 for numerical stability, then cast back
|
||||
to the LoRA dtype/device before storing/using the prototype.
|
||||
- Convergence is checked with a simple fixed-iter cap (`iters`) and/or
|
||||
`allclose` tolerance (`tol`).
|
||||
- The returned vector is unique up to sign (±), as with any singular vector.
|
||||
Downstream code should be sign-invariant.
|
||||
"""
|
||||
|
||||
# A: (r, in), B: (out, r)
|
||||
A32 = A.to(torch.float32)
|
||||
B32 = B.to(torch.float32)
|
||||
C = B32.T @ B32 # (r, r)
|
||||
|
||||
# Private RNG on A's device
|
||||
gen = None
|
||||
if self.rng_seed is not None:
|
||||
gen = torch.Generator(device=A32.device.type)
|
||||
gen.manual_seed(int(self.rng_seed))
|
||||
|
||||
# init vector in input space
|
||||
v = torch.randn(A32.size(1), dtype=A32.dtype, device=A32.device, generator=gen)
|
||||
v = v / (v.norm() + eps)
|
||||
|
||||
for _ in range(iters):
|
||||
# w = (ΔWᵀΔW) v = Aᵀ (BᵀB) (A v)
|
||||
w = A32.T @ (C @ (A32 @ v))
|
||||
v = w / (w.norm() + eps)
|
||||
|
||||
return v # fp32
|
||||
|
||||
@torch.no_grad()
|
||||
def build_prototypes(self, lora_A, lora_B):
|
||||
"""
|
||||
Computes a prototype vector for each LoRA module in every layer by applying Singular Value Decomposition (SVD)
|
||||
to the `lora_A` matrix and extracting the top right singular vector.
|
||||
|
||||
These prototypes are later used to calculate the cosine similarity between each input token and each expert.
|
||||
The resulting similarity scores serve as coefficients to compute a weighted average of the corresponding LoRA
|
||||
modules, effectively routing each token through its most relevant experts.
|
||||
|
||||
** This prototype computation is done is done once for all experts and is re-done on newly added adapters.**
|
||||
|
||||
Args:
|
||||
lora_A : Matrices A in LoRA layer.
|
||||
lora_B (optional): Matrices B in LoRA layer. Defaults to None.
|
||||
"""
|
||||
|
||||
if self._protos_ready:
|
||||
return
|
||||
protos = []
|
||||
for name in self.task_adapter_names:
|
||||
A = lora_A[name].weight # (r, in_features)
|
||||
B = lora_B[name].weight # (out_features, r)
|
||||
|
||||
# Efficiently computing right singular vector of A @ B
|
||||
proto32 = self.top_right_singular_vec_from_BA(A, B)
|
||||
|
||||
proto = proto32.to(dtype=A.dtype, device=A.device)
|
||||
protos.append(proto)
|
||||
|
||||
proto_stack = torch.stack(protos, dim=0) # (E, in_features)
|
||||
|
||||
# Register the prototypes buffer with correct dtype/device consistent with A and B weights
|
||||
self.register_buffer("prototypes", proto_stack, persistent=False)
|
||||
self._protos_ready = True
|
||||
|
||||
@torch.no_grad()
|
||||
def gen_know_sub(self, lora_A, lora_B):
|
||||
"""
|
||||
This function performs General Knowledge Subtraction. It takes an average of provided general_adapters, and
|
||||
subtract it from each task_adapter. This subtraction tries to purify the task adapters, based on
|
||||
"forgetting-via-negation" principle. Forgetting-via-negation is a task-arithmetic operation, explained in:
|
||||
https://arxiv.org/abs/2212.04089 The task adapters will be more focused and isolated, enhancing the performance
|
||||
on new tasks.
|
||||
|
||||
Args:
|
||||
lora_A : Matrices A in LoRA layer.
|
||||
lora_B : Matrices A in LoRA layer.
|
||||
"""
|
||||
if not self.use_gks:
|
||||
return
|
||||
elif self.gks_done and not self.gks_added_adapter_names:
|
||||
return
|
||||
else:
|
||||
# 1) compute average A/B over gks_adapter_names
|
||||
avg_A = torch.stack([lora_A[n].weight for n in self.gks_adapter_names], dim=0).mean(
|
||||
0
|
||||
) # shape (r, in_features)
|
||||
avg_B = torch.stack([lora_B[n].weight for n in self.gks_adapter_names], dim=0).mean(
|
||||
0
|
||||
) # shape (out_features, r)
|
||||
|
||||
# 2) Subtract the average from task-specific experts
|
||||
if self.gks_done is False: # GKS is done for all the experts, since it hasn't been done yet.
|
||||
for name in self.task_adapter_names:
|
||||
lora_A[name].weight.data.sub_(avg_A)
|
||||
lora_B[name].weight.data.sub_(avg_B)
|
||||
else: # GKS is only done on new added experts, since GKS has been done previously.
|
||||
for name in self.gks_added_adapter_names:
|
||||
lora_A[name].weight.data.sub_(avg_A)
|
||||
lora_B[name].weight.data.sub_(avg_B)
|
||||
|
||||
# 3) Set gks_done flag as true, so we won't do it again in ArrowLinearVariant.forward().
|
||||
self.gks_done = True
|
||||
# Clearing the self.gks_added_adapter_names
|
||||
self.gks_added_adapter_names = []
|
||||
|
||||
def _cast_input_dtype(self, x, dtype: torch.dtype):
|
||||
"""
|
||||
Whether to cast the dtype of the input of the forward method.
|
||||
|
||||
Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting
|
||||
layer.cast_input_dtype=False, this can be disabled if necessary.
|
||||
|
||||
Enabling or disabling can be managed via the peft.helpers.disable_lora_input_dtype_casting context manager.
|
||||
"""
|
||||
if x is None: # useful e.g. if x is the bias, which can be None
|
||||
return None
|
||||
|
||||
cast_input_dtype_enabled = getattr(self, "cast_input_dtype_enabled", True)
|
||||
if (not cast_input_dtype_enabled) or (x.dtype == dtype):
|
||||
return x
|
||||
return x.to(dtype=dtype)
|
||||
|
||||
def forward(self, x, lora_A, lora_B, dropout, scaling):
|
||||
"""
|
||||
Applies Arrow routing inside a LoRA layer.
|
||||
|
||||
Steps:
|
||||
1. Compute cosine similarity between each token representation and all adapter prototypes.
|
||||
2. Select the top-k experts per token and normalize their scores with a softmax.
|
||||
3. Project tokens into each selected expert’s low-rank space (A weights).
|
||||
4. Map back to the output space (B weights).
|
||||
5. Aggregate expert outputs via the weighted sum of their contributions.
|
||||
6. Apply dropout, scaling, and return the reshaped delta.
|
||||
|
||||
- Conceptually, this is a Mixture-of-Experts (MoE) over LoRA adapters,
|
||||
where coefficients are derived from prototype similarity.
|
||||
|
||||
Returns:
|
||||
delta: LoRA output adjustment computed by Arrow routing.
|
||||
"""
|
||||
x = self._cast_input_dtype(x, lora_A[self.task_adapter_names[0]].weight.dtype)
|
||||
B, *rest, F_in = x.shape
|
||||
tok = x.view(-1, F_in) # (t, F_in)
|
||||
t, E = tok.size(0), self.prototypes.size(0)
|
||||
|
||||
# We now turn scaling, which is a dict, to tensors in order to use them later
|
||||
scales_tens = torch.tensor(
|
||||
[scaling[n] for n in self.task_adapter_names],
|
||||
device=tok.device,
|
||||
dtype=tok.dtype,
|
||||
) # shape (E,)
|
||||
|
||||
# 1) similarity — sign-agnostic
|
||||
sim = torch.abs(tok @ self.prototypes.T) # (t, E)
|
||||
|
||||
# 2) top-k + softmax over full E (non-top-k = -inf)
|
||||
top_v, idx = torch.topk(sim, self.top_k, dim=1)
|
||||
full_score = tok.new_full((t, E), float("-inf"))
|
||||
full_score.scatter_(1, idx, top_v)
|
||||
coeff = torch.softmax(full_score / self.temperature, dim=1) # (t, E)
|
||||
|
||||
# 3) stack all A and B weights once
|
||||
# A_stack: (E, r, in_features), B_stack: (E, out_features, r)
|
||||
A_stack = torch.stack([lora_A[n].weight for n in self.task_adapter_names], dim=0)
|
||||
B_stack = torch.stack([lora_B[n].weight for n in self.task_adapter_names], dim=0)
|
||||
|
||||
# 4) project tokens into each expert’s low‑rank space:
|
||||
# z[e] = tok @ A_e.T → shape (t, E, r)
|
||||
z = torch.einsum("tf, erf -> ter", tok, A_stack)
|
||||
|
||||
# 5) lift back each expert’s output:
|
||||
# y[e] = z[e] @ B_e.T → shape (t, E, out_features)
|
||||
y = torch.einsum("ter, eor -> teo", z, B_stack)
|
||||
|
||||
# 6) apply per-expert scaling before the weighted sum
|
||||
# y_scaled[t, e, o] = scales[e] * y[t, e, o]
|
||||
y = y * scales_tens.view(1, -1, 1)
|
||||
|
||||
# 6) weighted sum over experts:
|
||||
# delta_flat[t,o] = Σ_e coeff[t,e] * y[t,e,o]
|
||||
delta_flat = torch.einsum("te, teo -> to", coeff, y) # (t, out_features)
|
||||
|
||||
# 7) dropout, scale, and reshape
|
||||
delta = dropout(delta_flat)
|
||||
out_dim = delta_flat.size(-1)
|
||||
return delta.view(B, *rest, out_dim)
|
||||
|
||||
|
||||
def check_loaded_lora_compatibility_arrow(model, adapter_names: list[str]):
|
||||
"""
|
||||
After loading all adapters into `model`, check they share:
|
||||
- the same LoRA rank (r)
|
||||
- identical weight shapes
|
||||
- identical sets of target_modules
|
||||
Returns (sorted list of target module names, agreed rank r).
|
||||
"""
|
||||
reference = None # {'r':…, 'shapes':(Ashape,Bshape), 'modules':set([...])}
|
||||
|
||||
for name in adapter_names:
|
||||
curr_modules = set()
|
||||
curr_r = None
|
||||
curr_shapes = None
|
||||
|
||||
for full_name, module in model.named_modules():
|
||||
if hasattr(module, "lora_A") and name in module.lora_A:
|
||||
A = module.lora_A[name].weight
|
||||
B = module.lora_B[name].weight
|
||||
mod_name = full_name.split(".")[-1]
|
||||
curr_modules.add(mod_name)
|
||||
# A has shape (r, in_features); B has shape (out_features, r)
|
||||
curr_r = A.shape[0]
|
||||
curr_shapes = (A.shape, B.shape)
|
||||
|
||||
if reference is None:
|
||||
reference = {"r": curr_r, "shapes": curr_shapes, "modules": curr_modules}
|
||||
else:
|
||||
if curr_r != reference["r"]:
|
||||
raise ValueError(f"[{name}] rank mismatch: {curr_r} != {reference['r']}")
|
||||
if curr_shapes != reference["shapes"]:
|
||||
raise ValueError(f"[{name}] shape mismatch: {curr_shapes} != {reference['shapes']}")
|
||||
if curr_modules != reference["modules"]:
|
||||
raise ValueError(
|
||||
f"[{name}] target_modules mismatch:\n"
|
||||
f" this adapter -> {sorted(curr_modules)}\n"
|
||||
f" reference -> {sorted(reference['modules'])}"
|
||||
)
|
||||
|
||||
agreed_modules = sorted(reference["modules"])
|
||||
return agreed_modules, int(reference["r"])
|
||||
|
||||
|
||||
def ensure_adapters_target_linear_layers_only(model, adapter_names: list[str]):
|
||||
"""
|
||||
Validate that every module holding LoRA weights for any of `adapter_names` is Linear-like: nn.Linear,
|
||||
bitsandbytes.nn.Linear4bit, nn.Conv1d, or transformers.models.gpt2.modeling_gpt2.Conv1D. If not, raise.
|
||||
"""
|
||||
import torch.nn as nn
|
||||
|
||||
Linear4bit = None
|
||||
try:
|
||||
import bitsandbytes as bnb # type: ignore
|
||||
|
||||
Linear4bit = bnb.nn.Linear4bit
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
HFConv1D = None
|
||||
try:
|
||||
from transformers.models.gpt2.modeling_gpt2 import Conv1D as HFConv1D # type: ignore
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
allowed_types = (nn.Linear, nn.Conv1d)
|
||||
if Linear4bit is not None:
|
||||
allowed_types = allowed_types + (Linear4bit,)
|
||||
if HFConv1D is not None:
|
||||
allowed_types = allowed_types + (HFConv1D,)
|
||||
|
||||
offenders = []
|
||||
|
||||
for full_name, module in model.named_modules():
|
||||
if hasattr(module, "lora_A"):
|
||||
for name in adapter_names:
|
||||
if name in getattr(module, "lora_A", {}):
|
||||
base = getattr(module, "base_layer", None) or getattr(module, "original_module", None)
|
||||
layer_to_check = base if base is not None else module
|
||||
|
||||
if not isinstance(layer_to_check, allowed_types):
|
||||
offenders.append((name, full_name, type(layer_to_check).__name__))
|
||||
|
||||
if offenders:
|
||||
lines = [
|
||||
"LoRA adapters must only target Linear-like layers "
|
||||
"(nn.Linear, nn.Conv1d, HF Conv1D, or bitsandbytes.nn.Linear4bit). Found:"
|
||||
]
|
||||
for name, full_name, tname in offenders:
|
||||
lines.append(f" - adapter '{name}' on module '{full_name}' of type {tname}")
|
||||
raise TypeError("\n".join(lines))
|
||||
|
||||
|
||||
def _resolve_adapter_source(path: str) -> tuple[str, str | None]:
|
||||
"""
|
||||
Resolve a user-provided adapter `path` into (model_id, subfolder).
|
||||
|
||||
Supports:
|
||||
- Local path to a folder that contains `adapter_config.json`
|
||||
- Hub path with subfolder, e.g. "user/repo/ts_expert_0[/more/...]", which becomes:
|
||||
model_id="user/repo", subfolder="ts_expert_0[/more/...]"
|
||||
- Plain Hub repo id "user/repo" (no subfolder)
|
||||
"""
|
||||
if os.path.isdir(path):
|
||||
if not os.path.isfile(os.path.join(path, "adapter_config.json")):
|
||||
raise ValueError(f"Local adapter path '{path}' does not contain 'adapter_config.json'.")
|
||||
return path, None
|
||||
|
||||
parts = path.strip("/").split("/")
|
||||
if len(parts) >= 2:
|
||||
model_id = "/".join(parts[:2])
|
||||
if len(parts) > 2:
|
||||
subfolder = "/".join(parts[2:])
|
||||
return model_id, subfolder
|
||||
return model_id, None
|
||||
|
||||
return path, None
|
||||
|
||||
|
||||
def create_arrow_model(
|
||||
base_model: PreTrainedModel,
|
||||
task_specific_adapter_paths: list[str],
|
||||
arrow_config: ArrowConfig,
|
||||
general_adapter_paths: list[str] | None = None,
|
||||
**adapter_kwargs: Any,
|
||||
):
|
||||
if task_specific_adapter_paths is None or len(task_specific_adapter_paths) == 0:
|
||||
raise ValueError("`task_specific_adapter_paths` should contain at least one adapter path")
|
||||
|
||||
from peft import LoraConfig, PeftModel
|
||||
|
||||
model_id0, sub0 = _resolve_adapter_source(task_specific_adapter_paths[0])
|
||||
initial_ts_expert_name = f"{TASK_ADAPTER_PREFIX}0"
|
||||
|
||||
first_kwargs = dict(adapter_kwargs)
|
||||
if sub0 is not None and "subfolder" not in first_kwargs:
|
||||
first_kwargs["subfolder"] = sub0
|
||||
|
||||
model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
model_id=model_id0,
|
||||
adapter_name=initial_ts_expert_name,
|
||||
**first_kwargs,
|
||||
)
|
||||
|
||||
for i in range(1, len(task_specific_adapter_paths)):
|
||||
ts_expert_name = f"{TASK_ADAPTER_PREFIX}{i}"
|
||||
mid, sub = _resolve_adapter_source(task_specific_adapter_paths[i])
|
||||
more_kwargs = dict(adapter_kwargs)
|
||||
if sub is not None and "subfolder" not in more_kwargs:
|
||||
more_kwargs["subfolder"] = sub
|
||||
model.load_adapter(
|
||||
model_id=mid,
|
||||
adapter_name=ts_expert_name,
|
||||
**more_kwargs,
|
||||
)
|
||||
arrow_config.task_adapter_names = [f"{TASK_ADAPTER_PREFIX}{i}" for i in range(len(task_specific_adapter_paths))]
|
||||
|
||||
if arrow_config.use_gks:
|
||||
if general_adapter_paths is None or len(general_adapter_paths) == 0:
|
||||
raise ValueError("You should provide general LoRA paths if you want to use GenKnowSub.")
|
||||
for i in range(len(general_adapter_paths)):
|
||||
gen_expert_name = f"{GKS_ADAPTER_PREFIX}{i}"
|
||||
mid, sub = _resolve_adapter_source(general_adapter_paths[i])
|
||||
gks_kwargs = dict(adapter_kwargs)
|
||||
if sub is not None and "subfolder" not in gks_kwargs:
|
||||
gks_kwargs["subfolder"] = sub
|
||||
model.load_adapter(
|
||||
model_id=mid,
|
||||
adapter_name=gen_expert_name,
|
||||
**gks_kwargs,
|
||||
)
|
||||
arrow_config.gks_adapter_names = [f"{GKS_ADAPTER_PREFIX}{i}" for i in range(len(general_adapter_paths))]
|
||||
else:
|
||||
arrow_config.gks_adapter_names = []
|
||||
|
||||
target_modules, r = check_loaded_lora_compatibility_arrow(
|
||||
model, adapter_names=arrow_config.task_adapter_names + arrow_config.gks_adapter_names
|
||||
)
|
||||
|
||||
ensure_adapters_target_linear_layers_only(
|
||||
model, adapter_names=arrow_config.task_adapter_names + arrow_config.gks_adapter_names
|
||||
)
|
||||
|
||||
router_cfg = LoraConfig(
|
||||
arrow_config=arrow_config,
|
||||
target_modules=target_modules,
|
||||
r=r,
|
||||
)
|
||||
model.add_adapter(adapter_name="arrow_router", peft_config=router_cfg)
|
||||
model.set_adapter("arrow_router")
|
||||
|
||||
return model
|
@ -24,6 +24,7 @@ from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
||||
from peft.utils.integrations import dequantize_bnb_weight
|
||||
from peft.utils.other import transpose
|
||||
|
||||
from .config import ArrowConfig
|
||||
from .layer import LoraLayer, LoraVariant
|
||||
|
||||
|
||||
@ -44,6 +45,7 @@ if is_bnb_available():
|
||||
use_rslora: bool = False,
|
||||
use_alora: bool = False,
|
||||
use_dora: bool = False,
|
||||
arrow_config: ArrowConfig = None,
|
||||
lora_bias: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -62,9 +64,17 @@ if is_bnb_available():
|
||||
use_dora=use_dora,
|
||||
use_alora=use_alora,
|
||||
lora_bias=lora_bias,
|
||||
arrow_config=arrow_config,
|
||||
)
|
||||
|
||||
def resolve_lora_variant(self, *, use_dora: bool, use_alora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
def resolve_lora_variant(
|
||||
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
|
||||
) -> Optional[LoraVariant]:
|
||||
if arrow_config is not None:
|
||||
from .variants import ArrowLinearVariant
|
||||
|
||||
return ArrowLinearVariant()
|
||||
|
||||
if not use_dora and not use_alora:
|
||||
return None
|
||||
|
||||
@ -323,6 +333,7 @@ if is_bnb_4bit_available():
|
||||
init_lora_weights: bool = True,
|
||||
use_rslora: bool = False,
|
||||
use_dora: bool = False,
|
||||
arrow_config: ArrowConfig = None,
|
||||
lora_bias: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -340,9 +351,17 @@ if is_bnb_4bit_available():
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
lora_bias=lora_bias,
|
||||
arrow_config=arrow_config,
|
||||
)
|
||||
|
||||
def resolve_lora_variant(self, *, use_dora: bool, use_alora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
def resolve_lora_variant(
|
||||
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
|
||||
) -> Optional[LoraVariant]:
|
||||
if arrow_config is not None:
|
||||
from .variants import ArrowLinearVariant
|
||||
|
||||
return ArrowLinearVariant()
|
||||
|
||||
if not use_dora and not use_alora:
|
||||
return None
|
||||
|
||||
|
@ -69,6 +69,56 @@ class LoftQConfig:
|
||||
loftq_iter: int = field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArrowConfig:
|
||||
"""
|
||||
This is the sub-configuration class to store the configuration for Arrow and GenKnowSub algorithm. Arrow is a
|
||||
routing algorithm to combine the trained LoRA modules to solve new tasks, proposed in
|
||||
'https://arxiv.org/pdf/2405.11157'. GenKnowSub is a refinement on the trained modules before being combined via
|
||||
Arrow, introduced in 'https://aclanthology.org/2025.acl-short.54/'
|
||||
"""
|
||||
|
||||
top_k: int = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of top LoRA modules to combine in Arrow routing."},
|
||||
)
|
||||
|
||||
router_temperature: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Softmax temperature for computing Arrow expert coefficients."},
|
||||
)
|
||||
|
||||
use_gks: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable GenKnowSub."},
|
||||
)
|
||||
|
||||
task_adapter_names: Optional[list[str]] = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={"help": "list of task-specific LoRA adapter names. It will be set in create_arrow_model()."},
|
||||
)
|
||||
|
||||
gks_adapter_names: Optional[list[str]] = field(
|
||||
default=None,
|
||||
init=False,
|
||||
metadata={
|
||||
"help": "list of general LoRA adapter names for GenKnowSub. It will be set in create_arrow_model()."
|
||||
},
|
||||
)
|
||||
|
||||
rng_seed: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "Optional RNG seed for reproducibility. If None, sampling is non-deterministic."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.top_k <= 0:
|
||||
raise ValueError("top_k cannot be negative.")
|
||||
if self.router_temperature <= 0:
|
||||
raise ValueError("router_temperature must be greater than 0.")
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaConfig:
|
||||
"""
|
||||
@ -610,6 +660,9 @@ class LoraConfig(PeftConfig):
|
||||
)
|
||||
},
|
||||
)
|
||||
arrow_config: Optional[ArrowConfig] = field(
|
||||
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
@ -628,7 +681,6 @@ class LoraConfig(PeftConfig):
|
||||
self.exclude_modules = (
|
||||
set(self.exclude_modules) if isinstance(self.exclude_modules, list) else self.exclude_modules
|
||||
)
|
||||
|
||||
if isinstance(self.target_parameters, str):
|
||||
raise TypeError("`target_parameters` must be a list of strings or None.")
|
||||
|
||||
|
@ -36,7 +36,7 @@ from peft.utils.integrations import (
|
||||
from peft.utils.other import transpose
|
||||
from peft.utils.warning import PeftWarning
|
||||
|
||||
from .config import LoraConfig
|
||||
from .config import ArrowConfig, LoraConfig
|
||||
|
||||
|
||||
VARIANT_KWARG_KEYS = ["alora_offsets"]
|
||||
@ -206,6 +206,7 @@ class LoraLayer(BaseTunerLayer):
|
||||
use_alora: bool = False,
|
||||
use_qalora: bool = False,
|
||||
lora_bias: bool = False,
|
||||
arrow_config: ArrowConfig = None,
|
||||
qalora_group_size: int = 32,
|
||||
**kwargs,
|
||||
):
|
||||
@ -225,7 +226,11 @@ class LoraLayer(BaseTunerLayer):
|
||||
)
|
||||
|
||||
lora_variant = self.resolve_lora_variant(
|
||||
use_dora=use_dora, use_alora=use_alora, use_qalora=use_qalora, qalora_group_size=qalora_group_size
|
||||
use_dora=use_dora,
|
||||
use_alora=use_alora,
|
||||
use_qalora=use_qalora,
|
||||
qalora_group_size=qalora_group_size,
|
||||
arrow_config=arrow_config,
|
||||
)
|
||||
if lora_variant is not None:
|
||||
self.lora_variant[adapter_name] = lora_variant
|
||||
@ -279,6 +284,17 @@ class LoraLayer(BaseTunerLayer):
|
||||
|
||||
self.set_adapter(self.active_adapters)
|
||||
|
||||
# Check for adapters that were added or removed from the arrow_model.
|
||||
# The arrow model may be modified after creation by adding new experts
|
||||
# (pre-trained or trainable) or by removing existing ones. Whenever such
|
||||
# a change occurs, on_adapter_change() is called to update the set of
|
||||
# active task-specific experts and, if needed, to handle recomputing prototypes
|
||||
# and doing general knowledge subtraction (GKS) again.
|
||||
if hasattr(self, "lora_arrow"):
|
||||
for adapter in self.lora_variant:
|
||||
if adapter in self.lora_arrow:
|
||||
self.lora_arrow[adapter].on_adapter_change(self.lora_A, self.lora_B)
|
||||
|
||||
def reset_lora_parameters(self, adapter_name, init_lora_weights):
|
||||
if init_lora_weights is False:
|
||||
return
|
||||
@ -639,6 +655,7 @@ class Linear(nn.Module, LoraLayer):
|
||||
use_rslora: bool = False,
|
||||
use_dora: bool = False,
|
||||
use_alora: bool = False,
|
||||
arrow_config: ArrowConfig = None,
|
||||
lora_bias: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -657,10 +674,18 @@ class Linear(nn.Module, LoraLayer):
|
||||
use_dora=use_dora,
|
||||
use_alora=use_alora,
|
||||
lora_bias=lora_bias,
|
||||
arrow_config=arrow_config,
|
||||
)
|
||||
self.is_target_conv_1d_layer = is_target_conv_1d_layer
|
||||
|
||||
def resolve_lora_variant(self, *, use_dora: bool, use_alora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
def resolve_lora_variant(
|
||||
self, *, arrow_config: ArrowConfig, use_dora: bool, use_alora: bool, **kwargs
|
||||
) -> Optional[LoraVariant]:
|
||||
if arrow_config is not None:
|
||||
from .variants import ArrowLinearVariant
|
||||
|
||||
return ArrowLinearVariant()
|
||||
|
||||
if not use_dora and not use_alora:
|
||||
return None
|
||||
|
||||
@ -742,6 +767,7 @@ class Linear(nn.Module, LoraLayer):
|
||||
"""
|
||||
This method unmerges all merged adapter layers from the base weights.
|
||||
"""
|
||||
|
||||
if not self.merged:
|
||||
warnings.warn("Already unmerged. Nothing to do.")
|
||||
return
|
||||
@ -855,6 +881,7 @@ class Embedding(nn.Module, LoraLayer):
|
||||
init_lora_weights: Union[bool, str] = True,
|
||||
use_rslora: bool = False,
|
||||
use_dora: bool = False,
|
||||
arrow_config: ArrowConfig = None,
|
||||
lora_bias: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -876,6 +903,7 @@ class Embedding(nn.Module, LoraLayer):
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
lora_bias=lora_bias,
|
||||
arrow_config=arrow_config,
|
||||
)
|
||||
|
||||
def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
@ -887,7 +915,17 @@ class Embedding(nn.Module, LoraLayer):
|
||||
return DoraEmbeddingVariant()
|
||||
|
||||
def update_layer(
|
||||
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias, **kwargs
|
||||
self,
|
||||
adapter_name,
|
||||
r,
|
||||
lora_alpha,
|
||||
lora_dropout,
|
||||
init_lora_weights,
|
||||
use_rslora,
|
||||
use_dora,
|
||||
lora_bias,
|
||||
arrow_config: ArrowConfig = None,
|
||||
**kwargs,
|
||||
):
|
||||
# collect the kwargs
|
||||
kwargs = locals().copy()
|
||||
@ -896,7 +934,7 @@ class Embedding(nn.Module, LoraLayer):
|
||||
if r <= 0:
|
||||
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
||||
|
||||
lora_variant = self.resolve_lora_variant(use_dora=use_dora)
|
||||
lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
|
||||
if lora_variant is not None:
|
||||
self.lora_variant[adapter_name] = lora_variant
|
||||
|
||||
@ -1129,6 +1167,7 @@ class _ConvNd(nn.Module, LoraLayer):
|
||||
init_lora_weights: Union[bool, str] = True,
|
||||
use_rslora: bool = False,
|
||||
use_dora: bool = False,
|
||||
arrow_config: ArrowConfig = None,
|
||||
lora_bias: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -1158,10 +1197,21 @@ class _ConvNd(nn.Module, LoraLayer):
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
lora_bias=lora_bias,
|
||||
arrow_config=arrow_config,
|
||||
)
|
||||
|
||||
def update_layer(
|
||||
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias, **kwargs
|
||||
self,
|
||||
adapter_name,
|
||||
r,
|
||||
lora_alpha,
|
||||
lora_dropout,
|
||||
init_lora_weights,
|
||||
use_rslora,
|
||||
use_dora,
|
||||
lora_bias,
|
||||
arrow_config: ArrowConfig = None,
|
||||
**kwargs,
|
||||
):
|
||||
# collect the kwargs
|
||||
kwargs = locals().copy()
|
||||
@ -1177,7 +1227,7 @@ class _ConvNd(nn.Module, LoraLayer):
|
||||
PeftWarning,
|
||||
)
|
||||
|
||||
lora_variant = self.resolve_lora_variant(use_dora=use_dora)
|
||||
lora_variant = self.resolve_lora_variant(use_dora=use_dora, arrow_config=arrow_config)
|
||||
if lora_variant is not None:
|
||||
self.lora_variant[adapter_name] = lora_variant
|
||||
|
||||
@ -1840,7 +1890,6 @@ class MultiheadAttention(nn.Module, LoraLayer):
|
||||
|
||||
class _LoraParameterProxy(nn.Module):
|
||||
"""This proxies an `nn.Parameter` that is targeted with LoRA.
|
||||
|
||||
Intended to be used in conjunction with `nn.utils.parametrize`, see `ParamWrapper`.
|
||||
"""
|
||||
|
||||
@ -1865,13 +1914,12 @@ def _register_parameter_or_buffer(module, name, X):
|
||||
class ParamWrapper(nn.Module, LoraLayer):
|
||||
"""A LoRA wrapper for `nn.Parameter`. This layer is dispatched if users target a parameter directly with
|
||||
`lora_config.target_parameters`
|
||||
|
||||
Note:
|
||||
|
||||
- When accessing the wrapped nn.Parameter directly, e.g. via `module.weight`, the LoRA weights are *not* applied.
|
||||
- It is currently not implemented to target multiple parameters on the same module. To achieve this, it is
|
||||
currently required to create a separate LoRA adapter (with another adapter name) and activate both at the same
|
||||
time.
|
||||
Note:
|
||||
- When accessing the wrapped nn.Parameter directly, e.g. via `module.weight`, the LoRA weights are *not*
|
||||
applied.
|
||||
- It is currently not implemented to target multiple parameters on the same module. To achieve this, it is
|
||||
currently required to create a separate LoRA adapter (with another adapter name) and activate both at the
|
||||
same time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -221,10 +221,12 @@ class LoraModel(BaseTuner):
|
||||
"qalora_group_size": lora_config.qalora_group_size,
|
||||
"ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
|
||||
"lora_bias": lora_config.lora_bias,
|
||||
"arrow_config": lora_config.arrow_config,
|
||||
"loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
|
||||
"loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
|
||||
"parameter_name": parameter_name,
|
||||
}
|
||||
|
||||
# for torchao merging, we need the get_apply_tensor_subclass from the quantization config
|
||||
try:
|
||||
kwargs["get_apply_tensor_subclass"] = operator.attrgetter(
|
||||
@ -254,6 +256,7 @@ class LoraModel(BaseTuner):
|
||||
use_rslora=lora_config.use_rslora,
|
||||
use_dora=lora_config.use_dora,
|
||||
lora_bias=lora_config.lora_bias,
|
||||
arrow_config=lora_config.arrow_config,
|
||||
)
|
||||
else:
|
||||
if isinstance(target, ParamWrapper) and (parameter_name == target.parameter_name):
|
||||
|
@ -23,11 +23,112 @@ from torch import nn
|
||||
|
||||
from peft.utils.other import transpose
|
||||
|
||||
from .arrow import ArrowLoraLinearLayer
|
||||
from .config import PeftConfig
|
||||
from .dora import DoraConv1dLayer, DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer
|
||||
from .layer import Conv1d, Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd
|
||||
|
||||
|
||||
class ArrowLinearVariant(LoraVariant):
|
||||
@staticmethod
|
||||
def init(module: Linear, adapter_name: str, **kwargs):
|
||||
"""
|
||||
Initialise the ArrowLoraLinearLayer() inside lora_arrow. lora_arrow is nn.ModuleDict(), serving as a container
|
||||
for ArrowLoraLinearLayer(). A layer of the base model with LoRA adapter loaded on it will be like:
|
||||
----------------------------------------------------
|
||||
(qkv_proj): lora.Linear4bit or lora.Linear(
|
||||
(base_layer): Linear4bit or Linear (lora_dropout): ModuleDict( ... ) (lora_A): ModuleDict( ... )
|
||||
(lora_B): ModuleDict( ... ) (lora_embedding_A): ParameterDict( ... ) (lora_embedding_B): ParameterDict(
|
||||
... ) (lora_magnitude_vector): ModuleDict( ... ) (lora_arrow): ModuleDict(
|
||||
(arrow_router): ArrowLoraLinearLayer() )
|
||||
)
|
||||
----------------------------------------------------
|
||||
|
||||
Args:
|
||||
module (Linear): LoRA Layer of the model, containing base_layer, lora_A, lora_B, etc.
|
||||
adapter_name (str): name of the adapter that will be put in lora_arrow.
|
||||
The adapter_name is "arrow_router" by default, set in create_arrow_model() in ./arrow.py
|
||||
"""
|
||||
# Checking for arrow necessary config
|
||||
arrow_config = kwargs.get("arrow_config")
|
||||
if arrow_config is None:
|
||||
raise ValueError("ArrowLinearVariant.init() did not receive an arrow_config")
|
||||
|
||||
# 1-a) build the ArrowLoRALayer
|
||||
arrow_layer = ArrowLoraLinearLayer(
|
||||
in_features=module.in_features,
|
||||
arrow_config=arrow_config,
|
||||
).to(module.weight.device)
|
||||
|
||||
# 1-b) register a container if it doesn’t exist yet
|
||||
if not hasattr(module, "lora_arrow"):
|
||||
module.lora_arrow = nn.ModuleDict()
|
||||
|
||||
module.lora_arrow[adapter_name] = arrow_layer
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
module: Linear,
|
||||
*,
|
||||
active_adapter: str,
|
||||
x: torch.Tensor,
|
||||
result: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Parameters mirror those in PEFT’s `LoraVariant.forward`. Called every time the host Linear does a fwd pass.
|
||||
|
||||
build_prototypes() and gen_know_sub() should run only once before routing. Both are implemented in
|
||||
ArrowLoraLinearLayer (see ./arrow.py). They are lazily invoked in the forward pass below. Attributes of
|
||||
ArrowLoraLinearLayer() class ensure they execute only a single time.
|
||||
|
||||
Args:
|
||||
module (Linear): LoRA Layer of the model
|
||||
active_adapter (str): name of the arrow route, which should be active to perform arrow.
|
||||
x (torch.Tensor): input to the layer
|
||||
result (torch.Tensor): output of the base layer.
|
||||
|
||||
Return value:
|
||||
output of the base model + delta weight computed by arrow layer.
|
||||
"""
|
||||
arrow = module.lora_arrow[active_adapter] # ArrowLoraLinearLayer
|
||||
# Apply GenKnowSub the 1st time if applcable. By calling arrow/on_adapter_change(),
|
||||
# gen_know_sub() is redone for newly added adapters after arrow.create_arrow_model().
|
||||
arrow.gen_know_sub(module.lora_A, module.lora_B)
|
||||
# lazily build prototypes the 1st time after GenKnowSub. By calling arrow/on_adapter_change(),
|
||||
# build_prototypes() is redone for newly added adapters after arrow.create_arrow_model().
|
||||
arrow.build_prototypes(module.lora_A, module.lora_B)
|
||||
|
||||
# A forward path of ArrowLoraLinearLayer is called so routing performs.
|
||||
# Accept and ignore extra variant kwargs (e.g., 'alora_offsets') for compatibility
|
||||
delta = arrow(
|
||||
x,
|
||||
lora_A=module.lora_A,
|
||||
lora_B=module.lora_B,
|
||||
dropout=module.lora_dropout[active_adapter],
|
||||
scaling=module.scaling,
|
||||
)
|
||||
return result + delta
|
||||
|
||||
"""
|
||||
Since Arrow is a Mixture-of-Experts (MoE) approach, merging adapters is not meaningful or even possible: for each
|
||||
token, the top-k LoRA experts are dynamically selected and routed. Because of this per-token routing, there is no
|
||||
single set of weights that can represent a merged adapter.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")
|
||||
|
||||
@staticmethod
|
||||
def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
|
||||
raise RuntimeError("Cannot merge an active Arrow router adapter. Remove it first.")
|
||||
|
||||
@staticmethod
|
||||
def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise RuntimeError("Cannot unmerge an active Arrow router adapter. Remove it first.")
|
||||
|
||||
|
||||
class DoraLinearVariant(LoraVariant):
|
||||
@staticmethod
|
||||
def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
|
||||
|
509
tests/test_arrow.py
Normal file
509
tests/test_arrow.py
Normal file
@ -0,0 +1,509 @@
|
||||
# Copyright 2025-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoModelForImageClassification
|
||||
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from peft.tuners.lora import ArrowConfig, create_arrow_model
|
||||
from peft.tuners.lora.arrow import _resolve_adapter_source
|
||||
from tests.testing_utils import hub_online_once
|
||||
|
||||
|
||||
# ─── Fixtures ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def workdir(tmp_path_factory):
|
||||
"""
|
||||
Create a temp directory and chdir into it for the duration of the module.
|
||||
"""
|
||||
wd = tmp_path_factory.mktemp("arrow_workdir")
|
||||
old_cwd = os.getcwd()
|
||||
os.chdir(wd)
|
||||
yield Path(wd)
|
||||
os.chdir(old_cwd)
|
||||
# (pytest will auto-delete wd)
|
||||
|
||||
|
||||
def _create_and_save_adapter(out_dir: Path, rank: int = 4):
|
||||
"""Helper: build a LoRA adapter around `model` and save into `out_dir`."""
|
||||
# fan_in_fan_out is set to True because of GPT2 model that we use to avoid warning
|
||||
cfg = LoraConfig(r=rank, target_modules=["c_attn"], fan_in_fan_out=True, init_lora_weights=False)
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
peft_model = get_peft_model(model, cfg)
|
||||
peft_model.save_pretrained(out_dir)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ts_adapters(workdir: Path):
|
||||
"""
|
||||
Build 3 task-specific adapters and return their absolute paths
|
||||
"""
|
||||
abs_paths = []
|
||||
for i in range(3):
|
||||
sub = f"{workdir}/ts{i}"
|
||||
_create_and_save_adapter(sub)
|
||||
abs_paths.append(sub)
|
||||
return abs_paths
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gen_adapter(workdir: Path):
|
||||
"""Build 1 general-knowledge adapter and return its absolute path list."""
|
||||
sub = f"{workdir}/gen0"
|
||||
_create_and_save_adapter(sub)
|
||||
return [sub] # list because create_arrow_model expects list
|
||||
|
||||
|
||||
class TestArrowRouting:
|
||||
def test_incompatible_rank_raises(self, workdir: Path):
|
||||
"""
|
||||
Adding adapters with different ranks must raise a ValueError.
|
||||
"""
|
||||
# Create two adapters with different ranks targeting the same modules
|
||||
sub_r4 = workdir / "rank4"
|
||||
sub_r8 = workdir / "rank8"
|
||||
_create_and_save_adapter(sub_r4, rank=4)
|
||||
_create_and_save_adapter(sub_r8, rank=8)
|
||||
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
base = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
# Expect create_arrow_model to raise due to rank mismatch
|
||||
with pytest.raises(ValueError, match=r"rank mismatch"):
|
||||
_ = create_arrow_model(
|
||||
base_model=base,
|
||||
task_specific_adapter_paths=[str(sub_r4), str(sub_r8)],
|
||||
arrow_config=ArrowConfig(top_k=1),
|
||||
)
|
||||
|
||||
def test_arrow_differs_with_extra_expert(self, ts_adapters):
|
||||
"""
|
||||
Arrow with 2 experts vs Arrow with 3 experts must produce different logits.
|
||||
"""
|
||||
# Arrow over first 2 experts
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
base_model_2 = copy.deepcopy(base_model_1)
|
||||
cfg_small = ArrowConfig(top_k=2)
|
||||
m_small = create_arrow_model(
|
||||
base_model=base_model_1,
|
||||
task_specific_adapter_paths=ts_adapters[:2],
|
||||
arrow_config=cfg_small,
|
||||
).eval()
|
||||
|
||||
# Arrow over all 3 experts
|
||||
cfg_big = ArrowConfig(top_k=2)
|
||||
m_big = create_arrow_model(
|
||||
base_model=base_model_2,
|
||||
task_specific_adapter_paths=ts_adapters,
|
||||
arrow_config=cfg_big,
|
||||
).eval()
|
||||
|
||||
x = torch.ones(1, 4, dtype=torch.long)
|
||||
assert not torch.allclose(m_small(x).logits, m_big(x).logits)
|
||||
|
||||
def test_arrow_gks_with_load_adapter_later_with_forward(self, ts_adapters, gen_adapter):
|
||||
"""
|
||||
Loading the last expert after creating the arrow model should produce the same result as loading all the
|
||||
experts at once in create_arrow_model(), when forward path is called before adding the new adapter.
|
||||
"""
|
||||
# Arrow over all three experts
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
base_model_2 = copy.deepcopy(base_model_1)
|
||||
cfg_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
|
||||
m_big = create_arrow_model(
|
||||
base_model=base_model_1,
|
||||
task_specific_adapter_paths=ts_adapters,
|
||||
general_adapter_paths=gen_adapter,
|
||||
arrow_config=cfg_big,
|
||||
).eval()
|
||||
|
||||
# Arrow over all 2 experts + loading the third expert later
|
||||
cfg_small_later_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
|
||||
m_small_later_big = create_arrow_model(
|
||||
base_model=base_model_2,
|
||||
task_specific_adapter_paths=ts_adapters[:2],
|
||||
general_adapter_paths=gen_adapter,
|
||||
arrow_config=cfg_small_later_big,
|
||||
)
|
||||
|
||||
# Ensuring that the prototypes and gks are done one time by running a forward path
|
||||
x = torch.ones(1, 4, dtype=torch.long)
|
||||
m_small_later_big(x)
|
||||
|
||||
# Now loading the third expert
|
||||
m_small_later_big.load_adapter(
|
||||
model_id=ts_adapters[-1],
|
||||
adapter_name="new_added_ts_expert",
|
||||
)
|
||||
# Activating the new adapter and run forward path on it
|
||||
m_small_later_big.set_adapter("new_added_ts_expert")
|
||||
x = torch.ones(3, 5, dtype=torch.long)
|
||||
m_small_later_big(x)
|
||||
|
||||
# Now we switch back to the arrow_router
|
||||
m_small_later_big.set_adapter("arrow_router")
|
||||
m_small_later_big.eval()
|
||||
|
||||
x = torch.ones(1, 4, dtype=torch.long)
|
||||
assert torch.allclose(m_big(x).logits, m_small_later_big(x).logits)
|
||||
|
||||
def test_arrow_with_load_adapter_later_with_forward_activate_new(self, ts_adapters, gen_adapter):
|
||||
"""
|
||||
Loading the last expert after creating the arrow model and activate it should produce different result compared
|
||||
to the case where arrow_router is activate, and the model's using arrow.
|
||||
"""
|
||||
# Arrow over all three experts
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
base_model_2 = copy.deepcopy(base_model_1)
|
||||
cfg_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
|
||||
m_big = create_arrow_model(
|
||||
base_model=base_model_1,
|
||||
task_specific_adapter_paths=ts_adapters,
|
||||
general_adapter_paths=gen_adapter,
|
||||
arrow_config=cfg_big,
|
||||
).eval()
|
||||
|
||||
# Arrow over all 2 experts + loading the third expert later
|
||||
cfg_small_later_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
|
||||
m_small_later_big = create_arrow_model(
|
||||
base_model=base_model_2,
|
||||
task_specific_adapter_paths=ts_adapters[:2],
|
||||
general_adapter_paths=gen_adapter,
|
||||
arrow_config=cfg_small_later_big,
|
||||
)
|
||||
|
||||
# Ensuring that the prototypes and gks are done one time by running a forward path
|
||||
x = torch.ones(1, 4, dtype=torch.long)
|
||||
m_small_later_big(x)
|
||||
|
||||
# Now loading the third expert
|
||||
m_small_later_big.load_adapter(
|
||||
model_id=ts_adapters[-1],
|
||||
adapter_name="new_added_ts_expert",
|
||||
)
|
||||
# The new adapter is activated
|
||||
m_small_later_big.set_adapter("new_added_ts_expert")
|
||||
m_small_later_big.eval()
|
||||
|
||||
x = torch.ones(1, 4, dtype=torch.long)
|
||||
assert not torch.allclose(m_big(x).logits, m_small_later_big(x).logits)
|
||||
|
||||
def test_arrow_gks_with_load_adapter_later_without_forward(self, ts_adapters, gen_adapter):
|
||||
"""
|
||||
Loading the last expert after creating the arrow model should produce the same result as loading all the
|
||||
experts at once in create_arrow_model()
|
||||
"""
|
||||
# Arrow over all three experts
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
base_model_2 = copy.deepcopy(base_model_1)
|
||||
cfg_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
|
||||
m_big = create_arrow_model(
|
||||
base_model=base_model_1,
|
||||
task_specific_adapter_paths=ts_adapters,
|
||||
general_adapter_paths=gen_adapter,
|
||||
arrow_config=cfg_big,
|
||||
).eval()
|
||||
|
||||
# Arrow over all 2 experts + loading the third expert later
|
||||
cfg_small_later_big = ArrowConfig(top_k=2, use_gks=True, rng_seed=42)
|
||||
m_small_later_big = create_arrow_model(
|
||||
base_model=base_model_2,
|
||||
task_specific_adapter_paths=ts_adapters[:2],
|
||||
general_adapter_paths=gen_adapter,
|
||||
arrow_config=cfg_small_later_big,
|
||||
)
|
||||
|
||||
# Now loading the third expert
|
||||
m_small_later_big.load_adapter(
|
||||
model_id=ts_adapters[-1],
|
||||
adapter_name="new_added_ts_expert",
|
||||
)
|
||||
m_small_later_big.eval()
|
||||
|
||||
x = torch.ones(1, 4, dtype=torch.long)
|
||||
assert torch.allclose(m_big(x).logits, m_small_later_big(x).logits)
|
||||
|
||||
def test_genknowsub_changes_output(self, ts_adapters, gen_adapter):
|
||||
"""
|
||||
Arrow+GenKnowSub vs plain Arrow must change logits.
|
||||
"""
|
||||
# Plain Arrow
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
base_model_1 = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
base_model_2 = copy.deepcopy(base_model_1)
|
||||
cfg_plain = ArrowConfig(top_k=2)
|
||||
m_plain = create_arrow_model(
|
||||
base_model=base_model_1,
|
||||
task_specific_adapter_paths=ts_adapters,
|
||||
arrow_config=cfg_plain,
|
||||
).eval()
|
||||
|
||||
# Arrow + GenKnowSub
|
||||
cfg_gks = ArrowConfig(top_k=2, use_gks=True)
|
||||
m_gks = create_arrow_model(
|
||||
base_model=base_model_2,
|
||||
task_specific_adapter_paths=ts_adapters,
|
||||
general_adapter_paths=gen_adapter,
|
||||
arrow_config=cfg_gks,
|
||||
).eval()
|
||||
|
||||
x = torch.ones(1, 4, dtype=torch.long)
|
||||
assert not torch.allclose(m_plain(x).logits, m_gks(x).logits)
|
||||
|
||||
def test_merging_adapters_raise_error_in_arrow(self, ts_adapters):
|
||||
"""
|
||||
Merging/unmerging is not allowed while an ArrowLinearLayer is loaded on the model and active.
|
||||
"""
|
||||
# Arrow over first 2 experts
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
base_model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
cfg_small = ArrowConfig(top_k=2)
|
||||
m_small = create_arrow_model(
|
||||
base_model=base_model,
|
||||
task_specific_adapter_paths=ts_adapters[:2],
|
||||
arrow_config=cfg_small,
|
||||
).eval()
|
||||
|
||||
with pytest.raises(RuntimeError, match=r"Cannot merge an active Arrow router adapter"):
|
||||
m_small.merge_and_unload()
|
||||
|
||||
def test_conv2d_targets_raise_typeerror_in_arrow(self, workdir):
|
||||
"""
|
||||
Adapters applied to Conv2d must be rejected by create_arrow_model() which enforces Linear/Linear4bit-only
|
||||
targets.
|
||||
"""
|
||||
|
||||
model_id = "hf-internal-testing/tiny-random-ResNetForImageClassification"
|
||||
with hub_online_once(model_id):
|
||||
base = AutoModelForImageClassification.from_pretrained(model_id)
|
||||
|
||||
# Build a LoRA adapter targeting a Conv2d
|
||||
cfg = LoraConfig(r=4, target_modules=["convolution"], init_lora_weights=False)
|
||||
peft_model = get_peft_model(copy.deepcopy(base), cfg)
|
||||
|
||||
conv_dir = workdir / "cv0"
|
||||
peft_model.save_pretrained(conv_dir)
|
||||
|
||||
# Expect create_arrow_model to raise TypeError
|
||||
with pytest.raises(TypeError, match=r"LoRA adapters must only target Linear"):
|
||||
_ = create_arrow_model(
|
||||
base_model=base,
|
||||
task_specific_adapter_paths=[str(conv_dir)],
|
||||
arrow_config=ArrowConfig(top_k=1),
|
||||
)
|
||||
|
||||
def test_arrow_forward_float16_no_autocast_with_merging(self, ts_adapters):
|
||||
"""
|
||||
Run Arrow in float16 with autocast disabled; forward should work, while merge/unmerge operations must raise for
|
||||
Arrow models.
|
||||
"""
|
||||
import platform
|
||||
|
||||
try:
|
||||
_ = torch.zeros(1, dtype=torch.float16)
|
||||
except Exception:
|
||||
pytest.skip(reason="Test requires float16 support")
|
||||
|
||||
if platform.system() == "Darwin":
|
||||
pytest.skip(reason="MacOS does not support multiple ops in float16")
|
||||
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
|
||||
# Create base in fp16 (no manual assignment to .dtype)
|
||||
with hub_online_once(model_id):
|
||||
base = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
|
||||
|
||||
cfg = ArrowConfig(top_k=2)
|
||||
|
||||
# Build Arrow model and disable adapter dtype autocast
|
||||
model = create_arrow_model(
|
||||
base_model=base,
|
||||
task_specific_adapter_paths=ts_adapters,
|
||||
arrow_config=cfg,
|
||||
autocast_adapter_dtype=False,
|
||||
torch_dtype=torch.float16,
|
||||
).eval()
|
||||
|
||||
X = {
|
||||
"input_ids": torch.ones(1, 4, dtype=torch.long),
|
||||
"attention_mask": torch.ones(1, 4, dtype=torch.long),
|
||||
}
|
||||
|
||||
# Forward should work in fp16
|
||||
_ = model(**X)
|
||||
|
||||
# Merge must fail on Arrow models
|
||||
with pytest.raises(RuntimeError, match=r"Cannot merge an active Arrow router adapter"):
|
||||
model.merge_adapter(safe_merge=False)
|
||||
|
||||
with pytest.raises(RuntimeError, match=r"Cannot merge an active Arrow router adapter"):
|
||||
_ = model.merge_and_unload()
|
||||
|
||||
def test_prototypes_not_recomputed_on_repeated_forward(self, ts_adapters):
|
||||
"""
|
||||
Repeated calls to forward should not recompute prototypes. We verify by spying on
|
||||
ArrowLoraLinearLayer.top_right_singular_vec_from_BA(), which is only called when prototypes are (re)built.
|
||||
"""
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
base = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
cfg = ArrowConfig(top_k=2)
|
||||
model = create_arrow_model(
|
||||
base_model=base,
|
||||
task_specific_adapter_paths=ts_adapters,
|
||||
arrow_config=cfg,
|
||||
).eval()
|
||||
|
||||
# Find one Arrow layer instance on the model
|
||||
arrow_layer = None
|
||||
for _, module in model.named_modules():
|
||||
if hasattr(module, "lora_arrow") and "arrow_router" in module.lora_arrow:
|
||||
arrow_layer = module.lora_arrow["arrow_router"]
|
||||
break
|
||||
assert arrow_layer is not None, "Arrow router layer not found on model"
|
||||
|
||||
x = torch.ones(1, 4, dtype=torch.long)
|
||||
|
||||
# Spy on the internal proto computation; should run once (E calls for E experts)
|
||||
with patch.object(
|
||||
arrow_layer,
|
||||
"top_right_singular_vec_from_BA",
|
||||
wraps=arrow_layer.top_right_singular_vec_from_BA,
|
||||
) as spy:
|
||||
_ = model(x)
|
||||
first_calls = spy.call_count
|
||||
assert first_calls == len(arrow_layer.task_adapter_names)
|
||||
|
||||
# Call forward again; prototypes should be cached, so no extra calls
|
||||
_ = model(x)
|
||||
assert spy.call_count == first_calls
|
||||
|
||||
|
||||
def test_training_updates_when_task_adapter_active(ts_adapters):
|
||||
"""
|
||||
Ensure a simple training step works: compute a dummy loss, backward, and take an optimizer step. Verify that
|
||||
task-adapter parameters update.
|
||||
"""
|
||||
model_id = "hf-internal-testing/tiny-random-gpt2"
|
||||
with hub_online_once(model_id):
|
||||
base = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
|
||||
# Build Arrow model over two experts
|
||||
cfg = ArrowConfig(top_k=2)
|
||||
model = create_arrow_model(
|
||||
base_model=base,
|
||||
task_specific_adapter_paths=ts_adapters[:2],
|
||||
arrow_config=cfg,
|
||||
)
|
||||
model.train()
|
||||
|
||||
# Switch to a specific task adapter for training (vanilla LoRA)
|
||||
model.set_adapter("task_0")
|
||||
|
||||
# Choose a representative parameter to check updates (task_0 A weight)
|
||||
rep_name = None
|
||||
for n, _ in model.named_parameters():
|
||||
if ".lora_A.task_0.weight" in n:
|
||||
rep_name = n
|
||||
break
|
||||
assert rep_name is not None, "task_0 LoRA A weight not found"
|
||||
rep_param = dict(model.named_parameters())[rep_name]
|
||||
before = rep_param.detach().clone()
|
||||
|
||||
# Optimizer over trainable params (task_0 now active and trainable)
|
||||
opt = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=1e-2)
|
||||
|
||||
# Dummy batch
|
||||
vocab = model.config.vocab_size
|
||||
input_ids = torch.randint(0, vocab, (2, 8))
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# Compute loss and update
|
||||
opt.zero_grad()
|
||||
out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
|
||||
assert hasattr(out, "loss") and out.loss is not None
|
||||
out.loss.backward()
|
||||
opt.step()
|
||||
|
||||
after = rep_param.detach().clone()
|
||||
assert not torch.allclose(before, after), "Active task adapter parameters did not update after optimizer step"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
"local_root",
|
||||
"local_nested",
|
||||
"hub_repo",
|
||||
"hub_with_sub",
|
||||
],
|
||||
)
|
||||
def test_resolve_adapter_source_variants(tmp_path: Path, case: str):
|
||||
"""
|
||||
Ensure `_resolve_adapter_source` correctly handles:
|
||||
- Local dir (containing adapter_config.json)
|
||||
- Local nested subfolder
|
||||
- Hub repo id "user/repo"
|
||||
- Hub repo with subfolder "user/repo/sub/folder"
|
||||
"""
|
||||
if case == "local_root":
|
||||
d = tmp_path / "adapter_local_root"
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
(d / "adapter_config.json").write_text("{}")
|
||||
model_id, sub = _resolve_adapter_source(str(d))
|
||||
assert model_id == str(d)
|
||||
assert sub is None
|
||||
|
||||
elif case == "local_nested":
|
||||
d = tmp_path / "repo_like" / "sub" / "folder"
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
(d / "adapter_config.json").write_text("{}")
|
||||
model_id, sub = _resolve_adapter_source(str(d))
|
||||
assert model_id == str(d)
|
||||
assert sub is None
|
||||
|
||||
elif case == "hub_repo":
|
||||
model_id, sub = _resolve_adapter_source("user/repo")
|
||||
assert model_id == "user/repo"
|
||||
assert sub is None
|
||||
|
||||
elif case == "hub_with_sub":
|
||||
model_id, sub = _resolve_adapter_source("user/repo/sub/folder")
|
||||
assert model_id == "user/repo"
|
||||
assert sub == "sub/folder"
|
||||
|
||||
else:
|
||||
raise AssertionError(f"unknown case: {case}")
|
@ -20,6 +20,7 @@ import unittest
|
||||
from collections import Counter, defaultdict
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
@ -57,6 +58,7 @@ from transformers.pytorch_utils import Conv1D
|
||||
|
||||
from peft import (
|
||||
AdaLoraConfig,
|
||||
ArrowConfig,
|
||||
EvaConfig,
|
||||
LoftQConfig,
|
||||
LoraConfig,
|
||||
@ -67,6 +69,7 @@ from peft import (
|
||||
RoadConfig,
|
||||
TaskType,
|
||||
VeraConfig,
|
||||
create_arrow_model,
|
||||
get_peft_model,
|
||||
get_peft_model_state_dict,
|
||||
initialize_lora_eva_weights,
|
||||
@ -82,6 +85,7 @@ from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
|
||||
from peft.utils.hotswap import hotswap_adapter, prepare_model_for_compiled_hotswap
|
||||
from peft.utils.loftq_utils import NFQuantizer
|
||||
from peft.utils.other import fsdp_auto_wrap_policy
|
||||
from tests.testing_utils import hub_online_once
|
||||
|
||||
from .testing_utils import (
|
||||
device_count,
|
||||
@ -5271,3 +5275,88 @@ class TestHotSwapping:
|
||||
inductor_config_ctx = torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
|
||||
with dynamo_config_ctx, inductor_config_ctx:
|
||||
self.check_hotswap_diffusion(ranks=ranks, alpha_scalings=ranks, target_modules=target_modules)
|
||||
|
||||
|
||||
# Test: 4-bit load + Arrow + generate
|
||||
class TestArrowQuantized:
|
||||
@pytest.fixture(scope="class")
|
||||
def workdir(self, tmp_path_factory):
|
||||
"""Create and return a temp directory path for this class (no chdir)."""
|
||||
wd = tmp_path_factory.mktemp("arrow_workdir")
|
||||
return Path(wd)
|
||||
|
||||
def _create_and_save_adapter_opt(self, out_dir: Path, rank: int = 4):
|
||||
"""
|
||||
Build a randomly initialized LoRA adapter for OPT-125M and save into `out_dir`. We construct a model from
|
||||
CONFIG (no pretrained weights) to avoid slow downloads here.
|
||||
"""
|
||||
model_id = "facebook/opt-125m"
|
||||
# Target all linear layers so the adapter matches whatever we later quantize/load.
|
||||
lora_cfg = LoraConfig(
|
||||
r=rank,
|
||||
target_modules="all-linear",
|
||||
task_type="CAUSAL_LM",
|
||||
init_lora_weights=False,
|
||||
)
|
||||
# Load the adapter on the model and save it
|
||||
with hub_online_once(model_id):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
peft_model = get_peft_model(model, lora_cfg)
|
||||
peft_model.save_pretrained(out_dir)
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def ts_adapters_opt(self, workdir: Path):
|
||||
"""
|
||||
Build 3 locally-saved task-specific adapters for OPT-125M and return their absolute paths.
|
||||
"""
|
||||
paths = []
|
||||
for i in range(3):
|
||||
sub = workdir / f"ts_expert_{i}"
|
||||
self._create_and_save_adapter_opt(sub)
|
||||
paths.append(str(sub))
|
||||
return paths
|
||||
|
||||
@require_bitsandbytes
|
||||
@pytest.mark.single_gpu_tests
|
||||
def test_arrow_4bit_opt125m_load_and_generate_with_local_adapters(self, ts_adapters_opt):
|
||||
# Skip if CUDA or bitsandbytes isn’t available
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA required for 4-bit bitsandbytes test.")
|
||||
|
||||
model_id = "facebook/opt-125m"
|
||||
|
||||
# Quantization config (nf4, bf16 compute)
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=False,
|
||||
)
|
||||
|
||||
with hub_online_once(model_id):
|
||||
# Load quantized base model
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
quantization_config=bnb_config,
|
||||
)
|
||||
with hub_online_once(model_id + "tokenizer"):
|
||||
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
||||
|
||||
# Build Arrow model from the locally created adapters
|
||||
arrow_cfg = ArrowConfig(top_k=2, router_temperature=1.0, rng_seed=42)
|
||||
model = create_arrow_model(
|
||||
base_model=base_model,
|
||||
task_specific_adapter_paths=ts_adapters_opt, # local dirs (each has adapter_config.json)
|
||||
arrow_config=arrow_cfg,
|
||||
).eval()
|
||||
|
||||
# Quick generate smoke test
|
||||
inputs = tok("Hello world", return_tensors="pt")
|
||||
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
||||
with torch.no_grad():
|
||||
out = model.generate(**inputs, max_new_tokens=8)
|
||||
|
||||
assert out is not None
|
||||
assert out.shape[0] == 1 # batch size 1
|
||||
|
Reference in New Issue
Block a user