Mirror of https://github.com/huggingface/peft.git, synced 2025-10-20 15:33:48 +08:00
Support for Activated LoRA (#2609)
This PR migrates Activated LoRA (aLoRA) support from a standalone GitHub repository (see above) into PEFT itself. Note that there is also an active PR for vLLM inference support for Activated LoRA: vllm-project/vllm#19710. There are also collections of aLoRA models on the Hugging Face Hub (in the ibm-granite org); these preexisting models currently run off the standalone GitHub repo and will be updated to work with this new PEFT feature if merged.

Description of changes: Activated LoRA is a modification of the LoRA architecture that "activates" the adapter weights only on tokens coming after a specified invocation_string. As a result, the KV values for the tokens before the invocation match the KV values of the base model. This makes the KV cache for the input interchangeable between the base model and the adapter model, and allows for major speedups in inference pipelines (e.g. agentic pipelines) that want to use both base models and adapter models. See the paper for a detailed exploration of use cases and further elaboration.

Other notes:
* The crux of the changes is in layer.py. Everything else simply manages the alora_offsets quantity, which defines where the weights start to be activated. This is determined by scanning input strings for the invocation_string defined in the aLoraConfig.
* I believe that aLoRA really only makes sense for CausalLMs, hence I've only implemented this for that model type.
* Merging doesn't make sense for aLoRA adapters, since the weights are not universally applied to all tokens.
* I used the LoRA code as a starting point, but did not implement various seemingly extra features in that code.
* As of now, invocation_string should probably start and end with special tokens, to avoid tokenizer issues at the boundary. Open to suggestions on how to make this more general if needed.

---------

Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
@ -173,6 +173,108 @@ from peft import LoraConfig
config = LoraConfig(use_rslora=True, ...)
```

### Activated LoRA (aLoRA)

Activated LoRA (aLoRA) is a low-rank adapter architecture for Causal LMs that allows reusing the existing base model KV cache for more efficient inference. This approach is best suited for inference pipelines which rely on the base model for most tasks/generations, but use aLoRA adapter(s) to perform specialized task(s) within the chain, for example checking or correcting generated outputs of the base model. In these settings, inference times can be sped up by an order of magnitude or more. For more information on aLoRA and many example use cases, see https://huggingface.co/papers/2504.12397.

This technique scans for the last occurrence of an invocation sequence (`alora_invocation_tokens`) in each input (this can be as short as 1 token), and activates the adapter weights on tokens starting at the beginning of that invocation sequence (any inputs after the invocation sequence are also adapted, and all generated tokens use the adapted weights). Weights on prior tokens are left un-adapted, which makes the cache for those tokens interchangeable with the base model cache thanks to the causal attention mask in Causal LMs. Usage is very similar to standard LoRA, with the key difference that this invocation sequence must be specified when the adapter is created:

```py
from peft import LoraConfig

config = LoraConfig(alora_invocation_tokens=alora_invocation_tokens, task_type="CAUSAL_LM", ...)
```

where `alora_invocation_tokens` is a list of integer token ids. Given a desired invocation string, it can be obtained as

```py
invocation_string = "placeholder"
alora_invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)
```

where the tokenizer is the tokenizer of the base model. Note the `add_special_tokens=False`, which prevents BOS/EOS tokens from being added to the search sequence (they would most likely cause the search to fail).

**Notes**

* aLoRA is only supported for `task_type=CAUSAL_LM` tasks due to its focus on cache reuse.
* Since the weights are adapted on fewer tokens, aLoRA often (though not always) requires a higher rank (`r`) than LoRA. `r=32` can be a good starting point.
* aLoRA weights cannot be merged into the base model by definition, since the adapter weights are selectively applied to a subset of tokens. Attempts to merge will throw errors.
* Beam search is not yet supported.
* It is generally not recommended to add new tokens to the tokenizer that are not present in the base model, as this can complicate the target use case of both the base model and the adapter model operating on overlapping context. That said, a possible workaround is to first efficiently add [trainable tokens](https://huggingface.co/docs/peft/en/package_reference/trainable_tokens) to the base model prior to training the adapter.

#### Choice of invocation sequence and SFT design

Each input must contain the `alora_invocation_tokens` sequence; it is not added automatically. To maximize model performance without compromising cache reuse, it is recommended to activate the adapter weights early, i.e. at the start of any adapter-specific prompting, but after any long inputs such as prior generations or documents. As with any model, formatting should be consistent between train and test.

Consider the following example, where the base model has a chat template and the goal is to train the adapter to generate a desired output.

* Option 1: If there is no task-specific prompt, i.e. the input is a chat history ending with the `assistant` prompt, then the chat template's `assistant` prompt (e.g. `<|start_of_role|>assistant<|end_of_role|>`) is a natural choice for the invocation string. Check the model's chat template to find this prompt; a sketch for recovering it programmatically is shown below.
* Option 2: If there is a task-specific prompt for the adapter that describes the task the adapter is learning, and that prompt is placed as a `user` turn immediately prior to the generation, then the chat template's `user` prompt (e.g. `<|start_of_role|>user<|end_of_role|>`) is a natural choice for the invocation string.
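For Option 1, if you prefer not to read the chat template by hand, the assistant prompt can be recovered by rendering the template with and without a generation prompt and taking the suffix. This is only a minimal sketch, not part of the PEFT API; it assumes the tokenizer defines a chat template and that `add_generation_prompt=True` appends exactly the assistant prompt (here shown for `meta-llama/Llama-3.2-3B-Instruct`). The recovered string may include trailing whitespace or newlines, which is fine as long as the training data uses the same formatting.

```py
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")

chat = [{"role": "user", "content": "placeholder"}]
without_generation_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
with_generation_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

# Whatever the template appends for the assistant turn is a natural invocation string.
invocation_string = with_generation_prompt[len(without_generation_prompt):]
alora_invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)
print(repr(invocation_string), alora_invocation_tokens)
```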
Once you have decided on an invocation string, get the model tokenizer and obtain `alora_invocation_tokens` as

```py
alora_invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)
```

An example finetuning (and inference) setup is at [alora finetuning](https://github.com/huggingface/peft/blob/main/examples/alora_finetuning/alora_finetuning.py).

**Note**: If using a custom invocation string, make sure that it starts and ends with special tokens to avoid issues with tokenization at the boundaries.

To see why, imagine that `a`, `b`, `c`, and `ab` are tokens in your tokenizer (with ids 1, 2, 3, and 4, respectively), and suppose that `alora_invocation_tokens = [2, 3]`, i.e. the invocation string is "bc". Now imagine the input string is "abc". Because "ab" is a single token, the input is tokenized as `[4, 3]`, so `alora_invocation_tokens` is not found even though the string "bc" is present. If the start and end of the invocation string are special tokens, this failure case can never happen, because special tokens are never merged into the same token as other characters. A minimal sketch of this scan, finding the last occurrence of the invocation tokens in a tokenized input, is shown below.
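The following is only an illustration of the scan described above, not PEFT's internal implementation: it searches a tokenized input for the last occurrence of `alora_invocation_tokens` and returns how many trailing tokens receive the adapted weights. The helper name `find_alora_offset` is made up for this sketch.

```py
def find_alora_offset(input_ids: list[int], invocation_tokens: list[int]) -> int | None:
    """Return the number of trailing tokens to adapt, or None if the invocation is not found."""
    n, m = len(input_ids), len(invocation_tokens)
    # Scan from the end so that the *last* occurrence wins.
    for start in range(n - m, -1, -1):
        if input_ids[start:start + m] == invocation_tokens:
            return n - start  # adapter weights apply from `start` to the end of the sequence
    return None

# Toy example mirroring the "abc" discussion above (ids: a=1, b=2, c=3, ab=4):
print(find_alora_offset([4, 3], [2, 3]))     # None -- "ab" was merged into one token, invocation not found
print(find_alora_offset([1, 2, 3], [2, 3]))  # 2 -- adapter active on the last two tokens
```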
#### Using (and reusing) cache for generation

The main purpose of Activated LoRA is to make the KV cache interchangeable between the base model and aLoRA adapter models **prior to the invocation sequence**, since base and adapted KV values are not compatible from the invocation onwards. Specifically, keys and values stored during one model generation can be used in subsequent generations to avoid expensive prefill operations for context tokens. When sharing cache between the base model and aLoRA adapters, there are two main patterns:

1. The base model has generated something, and an aLoRA adapter is then called to do a followup generation. Example: the base model answers a question, and an aLoRA trained to detect hallucinations checks the base model response.
2. An aLoRA adapter has generated something, and the base model or a different aLoRA adapter is called to do a followup generation with partial context overlap with the original aLoRA. Example: the user provides a query, and an aLoRA rewrites the query to be more self-contained and improve retrieval in a RAG system. Then, documents are retrieved and loaded into context, an aLoRA checks if these documents are indeed relevant to the question, and then the base model generates an answer.

To demonstrate the above behaviors when using caching, we're using [DynamicCache](https://huggingface.co/docs/transformers/en/kv_cache) from `transformers`. Care must be taken to ensure that adapted cache values are not mixed with base cache values. In particular, an extra step is required for sharing the cache when there is partial context overlap (pattern 2).

**Pattern 1: Base model followed by aLoRA** Here, the entire input and generation from the base model is fed into the aLoRA adapter, along with the invocation sequence:

```py
from transformers import DynamicCache
...
cache = DynamicCache()
inputs_base = tokenizer(prompt_base, return_tensors="pt").to(device)

# Generate from the base model and save the cache
with model_alora.disable_adapter():
    output = model_alora.generate(
        inputs_base["input_ids"],
        attention_mask=inputs_base["attention_mask"],
        past_key_values=cache,
        return_dict_in_generate=True,
    )
output_text_base = tokenizer.decode(output.sequences[0])
cache = output.past_key_values

# Generate with the aLoRA adapter from the cache
prompt_alora = output_text_base + INVOCATION_STRING
inputs_alora = tokenizer(prompt_alora, return_tensors="pt").to(device)
output = model_alora.generate(**inputs_alora, past_key_values=cache)
output_text_alora = tokenizer.decode(output[0])

# Note: the cache is now tainted with adapter KV values and cannot be used by the base model from here on!
```

**Pattern 2: aLoRA generation followed by the base model (or another aLoRA) with partial context overlap** Here, we prefill the shared context using the base model, and then generate:

```py
from transformers import DynamicCache
import copy
...
cache = DynamicCache()
inputs_shared = tokenizer(prompt_shared, return_tensors="pt").to(device)

# Prefill from the base model and save a copy of the cache
with model_alora.disable_adapter():
    with torch.no_grad():
        model_alora(**inputs_shared, past_key_values=cache)
cache_copy = copy.deepcopy(cache)

# Generate from the aLoRA adapter using the prefilled cache
prompt_alora = prompt_shared + INVOCATION_STRING
inputs_alora = tokenizer(prompt_alora, return_tensors="pt").to(device)
output = model_alora.generate(**inputs_alora, past_key_values=cache)
output_text_alora = tokenizer.decode(output[0])

# Generate from the base model using the saved cache copy, not tainted by aLoRA KV values
prompt_base = prompt_shared
inputs_base = tokenizer(prompt_base, return_tensors="pt").to(device)
with model_alora.disable_adapter():
    output = model_alora.generate(**inputs_base, past_key_values=cache_copy)
output_text_base = tokenizer.decode(output[0])
```
### Weight-Decomposed Low-Rank Adaptation (DoRA)
examples/alora_finetuning/README.md (new file, 76 lines)
@ -0,0 +1,76 @@
# Activated LoRA (aLoRA)
## Introduction

Activated LoRA (aLoRA) is an adapter that selectively activates its weights only after a given invocation sequence, ensuring that hidden states match the base model prior to this point. This allows reusing the base model KVs (stored in the KV cache) for tokens before the invocation, enabling much faster real-world inference (e.g. with vLLM) when switching between generation with the base model and generation with adapters.

See the [paper](https://huggingface.co/papers/2504.12397) for more details.
## Quick start (shown for Mistral 7B)

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from trl import SFTTrainer

from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
dataset = load_dataset("Lots-of-LoRAs/task1660_super_glue_question_generation", split="train")

invocation_string = "[/INST]"  # End of the user turn in the Mistral chat template
invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    alora_invocation_tokens=invocation_tokens,
    r=32,
    target_modules=["q_proj", "k_proj", "v_proj"],
)

peft_model = get_peft_model(model, lora_config)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
trainer = SFTTrainer(
    model=peft_model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=2048,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
trainer.train()
peft_model.save_pretrained("alora-mistral-7b")
```
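After training, the adapter can be loaded back onto the base model for inference. The snippet below is one possible setup (a sketch, assuming the adapter was saved to `alora-mistral-7b` as above and that the user message is a placeholder); the example script's `model_inference` function does essentially the same thing.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
base_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", device_map="cuda")
alora_model = PeftModel.from_pretrained(base_model, "alora-mistral-7b")

# The chat template ends the user turn with "[/INST]", i.e. the invocation string,
# so the adapter weights are activated for all generated tokens.
chat = [{"role": "user", "content": "Generate a question for this passage: ..."}]
text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(alora_model.device)

output = alora_model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```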
### Use the training example script directly
Pass the invocation string with `--invocation_string` when running the training example script. For Mistral 7B, do:

```bash
python examples/alora_finetuning/alora_finetuning.py --base_model mistralai/Mistral-7B-Instruct-v0.3 --data_path Lots-of-LoRAs/task1660_super_glue_question_generation --invocation_string "[/INST]"
```

and similarly for Llama-3.2-3B-Instruct:

```bash
python examples/alora_finetuning/alora_finetuning.py --base_model meta-llama/Llama-3.2-3B-Instruct --data_path Lots-of-LoRAs/task1660_super_glue_question_generation --invocation_string "<|start_header_id|>assistant<|end_header_id|>"
```

### Full example of the script

```bash
python alora_finetuning.py \
    --base_model "PATH_TO_MODEL" \
    --data_path "PATH_TO_DATASET" \
    --output_dir "PATH_TO_OUTPUT_DIR" \
    --batch_size 1 \
    --num_epochs 3 \
    --learning_rate 3e-4 \
    --cutoff_len 512 \
    --val_set_size 500 \
    --invocation_string "[/INST]" \
    --quantize \
    --eval_step 10 \
    --save_step 100 \
    --device "cuda:0" \
    --lora_r 32 \
    --lora_alpha 32 \
    --lora_dropout 0.05 \
    --lora_target_modules "q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj" \
    --hub_model_id "YOUR_HF_REPO" \
    --push_to_hub
```
examples/alora_finetuning/alora_finetuning.py (new file, 251 lines)
@ -0,0 +1,251 @@
import os

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training


def train_model(
    base_model: str,
    data_path: str,
    output_dir: str,
    batch_size: int,
    num_epochs: int,
    learning_rate: float,
    cutoff_len: int,
    val_set_size: int,
    invocation_string: str,
    quantize: bool,
    eval_step: int,
    save_step: int,
    device: str,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    lora_target_modules: str,
    hub_model_id: str,
    push_to_hub: bool,
):
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    hf_token = os.getenv("HF_TOKEN")

    device = torch.device(device)
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
    tokenizer.pad_token = tokenizer.unk_token
    invocation_tokens = tokenizer.encode(invocation_string, add_special_tokens=False)

    if quantize:
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            token=hf_token,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=(
                    torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16
                ),
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
            ),
        )
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    else:
        model = AutoModelForCausalLM.from_pretrained(base_model, token=hf_token)

    lora_config = LoraConfig(
        task_type="CAUSAL_LM",
        alora_invocation_tokens=invocation_tokens,
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=(lora_target_modules.split(",") if lora_target_modules else ["q_proj", "k_proj", "v_proj"]),
        lora_dropout=lora_dropout,
        bias="none",
    )

    model = get_peft_model(model, lora_config)

    model.to(device)
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(data_path)
    def tokenize_function(examples):
        formatted_texts = [
            tokenizer.apply_chat_template(
                [
                    {"role": "user", "content": user_msg},
                    {"role": "assistant", "content": assistant_msg},
                ],
                tokenize=False,  # get plain text first
                add_generation_prompt=False,
            )
            for user_msg, assistant_msg in zip(examples["input"], examples["output"])
        ]

        # Tokenize those texts
        model_inputs = tokenizer(
            formatted_texts,
            padding="max_length",
            truncation=True,
            max_length=cutoff_len,
        )

        labels = []
        for ids in model_inputs["input_ids"]:
            labels.append([(token_id if token_id != tokenizer.pad_token_id else -100) for token_id in ids])
        model_inputs["labels"] = labels

        return model_inputs

    # Tokenize the dataset and prepare for training
    tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

    # Data collator to dynamically pad the batched examples
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=eval_step,
        save_steps=save_step,
        save_total_limit=2,
        push_to_hub=push_to_hub,
        hub_model_id=hub_model_id,
        gradient_accumulation_steps=16,
        fp16=True,
        learning_rate=learning_rate,
        hub_token=hf_token,
    )

    torch.cuda.empty_cache()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["test"],
        data_collator=data_collator,
    )

    trainer.train()

    if push_to_hub:
        trainer.push_to_hub(commit_message="Fine-tuned model")

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
def model_inference(model_path: str, adapter_path: str, prompt: str = None, data_path: str = None):
    """
    Simple inference with the tuned aLoRA adapter. Note that the aLoRA adapter can (but does not
    need to) reuse KV cache created by the base model, perhaps during a prior generation turn.

    Purely for demonstration purposes. See the [paper](https://huggingface.co/papers/2504.12397)
    for realistic multiturn cache reuse examples.
    """
    if prompt is None:
        # Use first row of test data
        dataset = load_dataset(data_path)
        prompt = dataset["test"][0]["input"]
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    base_model = AutoModelForCausalLM.from_pretrained(model_path)
    alora_model = PeftModel.from_pretrained(base_model, adapter_path)
    chat = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(base_model.device)

    # Generate answer with adapter
    output_dict = alora_model.generate(**inputs, return_dict_in_generate=True, max_new_tokens=20)
    alora_outputs = output_dict.sequences

    # Print results
    print(f"Prompt: {text}")
    response = tokenizer.decode(alora_outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True)
    print(f"Trained adapter response: {response}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune Mistral with Activated LoRA")
    parser.add_argument(
        "--base_model", type=str, default="mistralai/Mistral-7B-Instruct-v0.3", help="Base model path or name"
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="Lots-of-LoRAs/task1660_super_glue_question_generation",
        help="Dataset path or name",
    )
    parser.add_argument(
        "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model"
    )
    parser.add_argument("--batch_size", type=int, default=2, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--cutoff_len", type=int, default=2048, help="Cutoff length for tokenization")
    parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size")
    parser.add_argument(
        "--invocation_string",
        type=str,
        default="[/INST]",
        help="String that activates the aLoRA adapter. Model dependent.",
    )
    parser.add_argument("--quantize", action="store_true", help="Use quantization")
    parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval")
    parser.add_argument("--save_step", type=int, default=100, help="Save step interval")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to use for training")
    parser.add_argument("--lora_r", type=int, default=32, help="LoRA rank")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout rate")
    parser.add_argument(
        "--lora_target_modules", type=str, default=None, help="Comma-separated list of target modules for LoRA"
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default="path/to/repo",
        help="Repository name to push the model on the Hugging Face Hub",
    )
    parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub")
    args = parser.parse_args()
    train_model(
        base_model=args.base_model,
        data_path=args.data_path,
        output_dir=args.output_dir,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        cutoff_len=args.cutoff_len,
        val_set_size=args.val_set_size,
        invocation_string=args.invocation_string,
        quantize=args.quantize,
        eval_step=args.eval_step,
        save_step=args.save_step,
        device=args.device,
        lora_r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        lora_target_modules=args.lora_target_modules,
        hub_model_id=args.hub_model_id,
        push_to_hub=args.push_to_hub,
    )
    print("Model trained. Running test inference.")
    model_inference(model_path=args.base_model, adapter_path=args.output_dir, data_path=args.data_path)
@ -38,6 +38,7 @@ from transformers import Cache, DynamicCache, EncoderDecoderCache, PreTrainedMod
|
||||
from transformers.modeling_outputs import QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
|
||||
from transformers.utils import PushToHubMixin
|
||||
|
||||
from peft.tuners.lora.variants import get_alora_offsets_for_forward, get_alora_offsets_for_generate
|
||||
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
|
||||
from peft.utils.constants import DUMMY_MODEL_CONFIG
|
||||
from peft.utils.integrations import init_empty_weights
|
||||
@ -115,7 +116,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
|
||||
self.peft_type = peft_config.peft_type
|
||||
# These args are special PEFT arguments that users can pass. They need to be removed before passing them to
|
||||
# forward.
|
||||
self.special_peft_forward_args = {"adapter_names"}
|
||||
self.special_peft_forward_args = {"adapter_names", "alora_offsets"}
|
||||
|
||||
self._is_prompt_learning = peft_config.is_prompt_learning
|
||||
if self._is_prompt_learning:
|
||||
@ -1838,7 +1839,10 @@ class PeftModelForCausalLM(PeftModel):
|
||||
**kwargs,
|
||||
):
|
||||
peft_config = self.active_peft_config
|
||||
|
||||
if not peft_config.is_prompt_learning:
|
||||
# Adds alora_offsets to kwargs if relevant. No other modifications.
|
||||
kwargs = get_alora_offsets_for_forward(self, input_ids, inputs_embeds, **kwargs)
|
||||
if self.base_model.config.model_type == "mpt":
|
||||
if inputs_embeds is not None:
|
||||
raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds")
|
||||
@ -1978,6 +1982,8 @@ class PeftModelForCausalLM(PeftModel):
|
||||
self.base_model.generation_config = self.generation_config
|
||||
try:
|
||||
if not peft_config.is_prompt_learning:
|
||||
# Adds alora_offsets to kwargs if relevant. No other changes.
|
||||
kwargs = get_alora_offsets_for_generate(self, *args, **kwargs)
|
||||
with self._enable_peft_forward_hooks(*args, **kwargs):
|
||||
kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
|
||||
outputs = self.base_model.generate(*args, **kwargs)
|
||||
|
@ -27,6 +27,8 @@ from peft.utils.other import transpose
|
||||
from .layer import LoraLayer, LoraVariant
|
||||
|
||||
|
||||
VARIANT_KWARG_KEYS = ["alora_offsets"]
|
||||
|
||||
if is_bnb_available():
|
||||
|
||||
class Linear8bitLt(torch.nn.Module, LoraLayer):
|
||||
@ -40,6 +42,7 @@ if is_bnb_available():
|
||||
lora_dropout: float = 0.0,
|
||||
init_lora_weights: bool = True,
|
||||
use_rslora: bool = False,
|
||||
use_alora: bool = False,
|
||||
use_dora: bool = False,
|
||||
lora_bias: bool = False,
|
||||
**kwargs,
|
||||
@ -57,16 +60,20 @@ if is_bnb_available():
|
||||
init_lora_weights=init_lora_weights,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
use_alora=use_alora,
|
||||
lora_bias=lora_bias,
|
||||
)
|
||||
|
||||
def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
if not use_dora:
|
||||
def resolve_lora_variant(self, *, use_dora: bool, use_alora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
if not use_dora and not use_alora:
|
||||
return None
|
||||
|
||||
from .variants import DoraLinearVariant
|
||||
from .variants import ALoraLinearVariant, DoraLinearVariant
|
||||
|
||||
return DoraLinearVariant()
|
||||
if use_alora:
|
||||
return ALoraLinearVariant()
|
||||
else:
|
||||
return DoraLinearVariant()
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
@ -178,6 +185,7 @@ if is_bnb_available():
|
||||
) -> torch.Tensor:
|
||||
# This is a special method that handles the case when users pass the argument `adapter_names`. This is an
|
||||
# extra argument that allows mixing different adapters in the same batch at inference time.
|
||||
variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS} # don't pass these to base_layer
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
|
||||
unique_adapters = set(adapter_names)
|
||||
@ -204,23 +212,40 @@ if is_bnb_available():
|
||||
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
|
||||
# layer output
|
||||
sub_batch = x[sub_batch_indices_list[i]]
|
||||
output = lora_B(lora_A(dropout(sub_batch))) * scaling
|
||||
if requires_conversion:
|
||||
output = output.to(expected_dtype)
|
||||
result[sub_batch_indices_list[i]] += output
|
||||
if active_adapter not in self.lora_variant: # vanilla LoRA:
|
||||
output = lora_B(lora_A(dropout(sub_batch))) * scaling
|
||||
if requires_conversion:
|
||||
output = output.to(expected_dtype)
|
||||
result[sub_batch_indices_list[i]] += output
|
||||
else:
|
||||
alora_offsets = variant_kwargs.get("alora_offsets", None)
|
||||
if alora_offsets is not None:
|
||||
variant_kwargs["alora_offsets"] = [alora_offsets[j] for j in sub_batch_indices_list[i]]
|
||||
output = self.lora_variant[active_adapter].forward(
|
||||
self,
|
||||
active_adapter=active_adapter,
|
||||
x=sub_batch,
|
||||
result=result[sub_batch_indices_list[i]],
|
||||
**variant_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if requires_conversion:
|
||||
output = output.to(expected_dtype)
|
||||
result[sub_batch_indices_list[i]] = output
|
||||
|
||||
return result
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
self._check_forward_args(x, *args, **kwargs)
|
||||
adapter_names = kwargs.pop("adapter_names", None)
|
||||
variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS} # don't pass these to base_layer
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif adapter_names is not None:
|
||||
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
||||
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **variant_kwargs, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
@ -249,6 +274,8 @@ if is_bnb_available():
|
||||
active_adapter=active_adapter,
|
||||
x=x,
|
||||
result=result,
|
||||
**variant_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if requires_conversion:
|
||||
result = result.to(expected_dtype)
|
||||
@ -315,13 +342,16 @@ if is_bnb_4bit_available():
|
||||
lora_bias=lora_bias,
|
||||
)
|
||||
|
||||
def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
if not use_dora:
|
||||
def resolve_lora_variant(self, *, use_dora: bool, use_alora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
if not use_dora and not use_alora:
|
||||
return None
|
||||
|
||||
from .variants import DoraLinearVariant
|
||||
from .variants import ALoraLinearVariant, DoraLinearVariant
|
||||
|
||||
return DoraLinearVariant()
|
||||
if use_alora:
|
||||
return ALoraLinearVariant()
|
||||
else:
|
||||
return DoraLinearVariant()
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
@ -431,6 +461,7 @@ if is_bnb_4bit_available():
|
||||
) -> torch.Tensor:
|
||||
# This is a special method that handles the case when users pass the argument `adapter_names`. This is an
|
||||
# extra argument that allows mixing different adapters in the same batch at inference time.
|
||||
variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS} # don't pass these to base_layer
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
|
||||
unique_adapters = set(adapter_names)
|
||||
@ -457,23 +488,40 @@ if is_bnb_4bit_available():
|
||||
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
|
||||
# layer output
|
||||
sub_batch = x[sub_batch_indices_list[i]]
|
||||
output = lora_B(lora_A(dropout(sub_batch))) * scaling
|
||||
if requires_conversion:
|
||||
output = output.to(expected_dtype)
|
||||
result[sub_batch_indices_list[i]] += output
|
||||
if active_adapter not in self.lora_variant: # vanilla LoRA
|
||||
output = lora_B(lora_A(dropout(sub_batch))) * scaling
|
||||
if requires_conversion:
|
||||
output = output.to(expected_dtype)
|
||||
result[sub_batch_indices_list[i]] += output
|
||||
else:
|
||||
alora_offsets = variant_kwargs.get("alora_offsets", None)
|
||||
if alora_offsets is not None:
|
||||
variant_kwargs["alora_offsets"] = [alora_offsets[j] for j in sub_batch_indices_list[i]]
|
||||
output = self.lora_variant[active_adapter].forward(
|
||||
self,
|
||||
active_adapter=active_adapter,
|
||||
x=sub_batch,
|
||||
result=result[sub_batch_indices_list[i]],
|
||||
**variant_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if requires_conversion:
|
||||
output = output.to(expected_dtype)
|
||||
result[sub_batch_indices_list[i]] = output
|
||||
|
||||
return result
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
self._check_forward_args(x, *args, **kwargs)
|
||||
adapter_names = kwargs.pop("adapter_names", None)
|
||||
variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS} # don't pass these to base_layer
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif adapter_names is not None:
|
||||
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
||||
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **variant_kwargs, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
@ -509,6 +557,8 @@ if is_bnb_4bit_available():
|
||||
active_adapter=active_adapter,
|
||||
x=x,
|
||||
result=result,
|
||||
**variant_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if requires_conversion:
|
||||
result = result.to(expected_dtype)
|
||||
|
@ -301,6 +301,17 @@ class LoraConfig(PeftConfig):
|
||||
ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger overhead than pure
|
||||
LoRA, so it is recommended to merge weights for inference. For more information, see
|
||||
https://huggingface.co/papers/2402.09353.
|
||||
alora_invocation_tokens (`List[int]`):
|
||||
If not None, enable <a href='https://huggingface.co/papers/2504.12397'>'Activated LoRA' (aLoRA)</a>, with
|
||||
alora_invocation_tokens being the tokenized invocation string for the adapter (must be present in all model
|
||||
input strings). This technique selectively activates the adapter weights only on tokens during and after
|
||||
the alora_invocation_tokens. When used in a CausalLM, this means that the KV cache prior to invocation is
|
||||
interchangeable with that of the base model (and other aLoRA adapters operating this way). As a result, in
|
||||
inference pipelines involving switching between base model inference and adapter inference (e.g. agentic
|
||||
pipelines, see paper for examples), significant savings are realized (relative to LoRA) by saving prefill
|
||||
operations. Overall adapter inference speedups of an order of magnitude or more can occur on vLLM,
|
||||
depending on the length of the shared context. Note that merging is not possible due to the selective
|
||||
application of the weights.
|
||||
layer_replication (`List[Tuple[int, int]]`):
|
||||
Build a new stack of layers by stacking the original model layers according to the ranges specified. This
|
||||
allows expanding (or shrinking) the model without duplicating the base model weights. The new layers will
|
||||
@ -510,6 +521,23 @@ class LoraConfig(PeftConfig):
|
||||
)
|
||||
},
|
||||
)
|
||||
alora_invocation_tokens: Optional[list[int]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"If not None, enable <a href='https://huggingface.co/papers/2504.12397'>'Activated LoRA' (aLoRA)</a>, with "
|
||||
"alora_invocation_tokens being the tokenized invocation string for the adapter (must be present in all model "
|
||||
"input strings). This technique selectively activates the adapter weights only on tokens during and after "
|
||||
"the alora_invocation_tokens. When used in a CausalLM, this means that the KV cache prior to invocation is "
|
||||
"interchangeable with that of the base model (and other aLoRA adapters operating this way). As a result, in "
|
||||
"inference pipelines involving switching between base model inference and adapter inference (e.g. agentic "
|
||||
"pipelines, see paper for examples), significant savings are realized (relative to LoRA) by saving prefill "
|
||||
"operations. Overall adapter inference speedups of an order of magnitude or more can occur on vLLM, "
|
||||
"depending on the length of the shared context. Note that merging is not possible due to the selective "
|
||||
"application of the weights."
|
||||
)
|
||||
},
|
||||
)
|
||||
use_qalora: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
@ -657,6 +685,9 @@ class LoraConfig(PeftConfig):
|
||||
if self.use_dora:
|
||||
raise ValueError("The argument lora_bias=True is not supported for DoRA, please pass use_dora=False")
|
||||
|
||||
if self.alora_invocation_tokens is not None and self.task_type != "CAUSAL_LM":
|
||||
warnings.warn("aLoRA is currently only supported for CAUSAL_LM task.")
|
||||
|
||||
# Using post training conversion of modified base weights to restore their initial values PiSSA/CorDA/OLoRA cannot
|
||||
# be correctly done when using rslora + rank_pattern/alpha_pattern. We can't really know if the user intends
|
||||
# this when they'll eventually call save_pretrained (i.e. if they'll pass
|
||||
|
@ -39,6 +39,9 @@ from peft.utils.warning import PeftWarning
|
||||
from .config import LoraConfig
|
||||
|
||||
|
||||
VARIANT_KWARG_KEYS = ["alora_offsets"]
|
||||
|
||||
|
||||
class LoraVariant:
|
||||
"""
|
||||
Base class for LoRA variants, e.g. DoRA.
|
||||
@ -69,7 +72,13 @@ class LoraVariant:
|
||||
"""Remove the adapter weights from the original weights, then return them"""
|
||||
|
||||
@staticmethod
|
||||
def forward(module: LoraLayer, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
module: LoraLayer,
|
||||
active_adapter: str,
|
||||
x: torch.Tensor,
|
||||
result: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
The forward pass of the LoRA variant, should return the overall result (not just the diff)
|
||||
|
||||
@ -78,6 +87,7 @@ class LoraVariant:
|
||||
active_adapter (str): The name of the active adapter
|
||||
x (torch.Tensor): The input to the forward call
|
||||
result (torch.Tensor): The result from the base model
|
||||
**kwargs: Additional arguments passed from [`LoraLayer.forward`].
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -174,7 +184,7 @@ class LoraLayer(BaseTunerLayer):
|
||||
"""Return a matching LoRA variant for this layer type.
|
||||
|
||||
Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
|
||||
method should return the DoRA variant for the given layer.
|
||||
method should return the DoRA variant for the given layer. If `use_alora=True`, same for aLoRA.
|
||||
|
||||
If there is no fitting variant, return None.
|
||||
|
||||
@ -193,6 +203,7 @@ class LoraLayer(BaseTunerLayer):
|
||||
init_lora_weights,
|
||||
use_rslora,
|
||||
use_dora: bool = False,
|
||||
use_alora: bool = False,
|
||||
use_qalora: bool = False,
|
||||
lora_bias: bool = False,
|
||||
qalora_group_size: int = 32,
|
||||
@ -214,7 +225,7 @@ class LoraLayer(BaseTunerLayer):
|
||||
)
|
||||
|
||||
lora_variant = self.resolve_lora_variant(
|
||||
use_dora=use_dora, use_qalora=use_qalora, qalora_group_size=qalora_group_size
|
||||
use_dora=use_dora, use_alora=use_alora, use_qalora=use_qalora, qalora_group_size=qalora_group_size
|
||||
)
|
||||
if lora_variant is not None:
|
||||
self.lora_variant[adapter_name] = lora_variant
|
||||
@ -561,6 +572,7 @@ class LoraLayer(BaseTunerLayer):
|
||||
) -> torch.Tensor:
|
||||
# This is a special method that handles the case when users pass the argument `adapter_names`. This is an
|
||||
# extra argument that allows mixing different adapters in the same batch at inference time.
|
||||
variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS} # don't pass these to base_layer
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
torch_result_dtype = result.dtype
|
||||
|
||||
@ -568,7 +580,7 @@ class LoraLayer(BaseTunerLayer):
|
||||
sub_batch_indices_list = []
|
||||
for adapter in unique_adapters:
|
||||
sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
|
||||
|
||||
alora_offsets = variant_kwargs.get("alora_offsets", None)
|
||||
for i, active_adapter in enumerate(unique_adapters):
|
||||
if active_adapter == "__base__":
|
||||
continue
|
||||
@ -583,8 +595,21 @@ class LoraLayer(BaseTunerLayer):
|
||||
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
|
||||
# layer output
|
||||
sub_batch = x[sub_batch_indices_list[i]].to(lora_A.weight.dtype)
|
||||
lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling
|
||||
result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype)
|
||||
if active_adapter not in self.lora_variant: # vanilla LoRA
|
||||
lora_output = lora_B(lora_A(dropout(sub_batch))) * scaling
|
||||
result[sub_batch_indices_list[i]] += lora_output.to(torch_result_dtype)
|
||||
else:
|
||||
if alora_offsets is not None:
|
||||
variant_kwargs["alora_offsets"] = [alora_offsets[j] for j in sub_batch_indices_list[i]]
|
||||
lora_output = self.lora_variant[active_adapter].forward(
|
||||
self,
|
||||
active_adapter=active_adapter,
|
||||
x=sub_batch,
|
||||
result=result[sub_batch_indices_list[i]],
|
||||
**variant_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
result[sub_batch_indices_list[i]] = lora_output.to(torch_result_dtype)
|
||||
|
||||
return result
|
||||
|
||||
@ -613,6 +638,7 @@ class Linear(nn.Module, LoraLayer):
|
||||
init_lora_weights: Union[bool, str] = True,
|
||||
use_rslora: bool = False,
|
||||
use_dora: bool = False,
|
||||
use_alora: bool = False,
|
||||
lora_bias: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -629,17 +655,21 @@ class Linear(nn.Module, LoraLayer):
|
||||
init_lora_weights=init_lora_weights,
|
||||
use_rslora=use_rslora,
|
||||
use_dora=use_dora,
|
||||
use_alora=use_alora,
|
||||
lora_bias=lora_bias,
|
||||
)
|
||||
self.is_target_conv_1d_layer = is_target_conv_1d_layer
|
||||
|
||||
def resolve_lora_variant(self, *, use_dora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
if not use_dora:
|
||||
def resolve_lora_variant(self, *, use_dora: bool, use_alora: bool, **kwargs) -> Optional[LoraVariant]:
|
||||
if not use_dora and not use_alora:
|
||||
return None
|
||||
|
||||
from .variants import DoraLinearVariant
|
||||
from .variants import ALoraLinearVariant, DoraLinearVariant
|
||||
|
||||
return DoraLinearVariant()
|
||||
if use_alora:
|
||||
return ALoraLinearVariant()
|
||||
else:
|
||||
return DoraLinearVariant()
|
||||
|
||||
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
|
||||
"""
|
||||
@ -767,13 +797,14 @@ class Linear(nn.Module, LoraLayer):
|
||||
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
self._check_forward_args(x, *args, **kwargs)
|
||||
adapter_names = kwargs.pop("adapter_names", None)
|
||||
variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS} # don't pass these to base_layer
|
||||
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
elif adapter_names is not None:
|
||||
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
||||
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **variant_kwargs, **kwargs)
|
||||
elif self.merged:
|
||||
result = self.base_layer(x, *args, **kwargs)
|
||||
else:
|
||||
@ -798,6 +829,8 @@ class Linear(nn.Module, LoraLayer):
|
||||
active_adapter=active_adapter,
|
||||
x=x,
|
||||
result=result,
|
||||
**variant_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
result = result.to(torch_result_dtype)
|
||||
@ -854,7 +887,7 @@ class Embedding(nn.Module, LoraLayer):
|
||||
return DoraEmbeddingVariant()
|
||||
|
||||
def update_layer(
|
||||
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
|
||||
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias, **kwargs
|
||||
):
|
||||
# collect the kwargs
|
||||
kwargs = locals().copy()
|
||||
@ -1044,7 +1077,7 @@ class Embedding(nn.Module, LoraLayer):
|
||||
# TODO: no dtype conversion here, unlike in Linear, is that correct?
|
||||
self._check_forward_args(x, *args, **kwargs)
|
||||
adapter_names = kwargs.pop("adapter_names", None)
|
||||
|
||||
variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS} # don't pass these to base_layer
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
@ -1072,6 +1105,8 @@ class Embedding(nn.Module, LoraLayer):
|
||||
active_adapter=active_adapter,
|
||||
x=x,
|
||||
result=result,
|
||||
**variant_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
result = result.to(torch_result_dtype)
|
||||
|
||||
@ -1099,7 +1134,8 @@ class _ConvNd(nn.Module, LoraLayer):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
LoraLayer.__init__(self, base_layer)
|
||||
|
||||
if kwargs.get("use_alora", False):
|
||||
raise ValueError("aLoRA does not support adapting conv layers.")
|
||||
if base_layer.groups > 1:
|
||||
warnings.warn("LoRA adapter added to ConvNd layer with groups > 1. Merging is not supported.")
|
||||
|
||||
@ -1125,7 +1161,7 @@ class _ConvNd(nn.Module, LoraLayer):
|
||||
)
|
||||
|
||||
def update_layer(
|
||||
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias
|
||||
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora, lora_bias, **kwargs
|
||||
):
|
||||
# collect the kwargs
|
||||
kwargs = locals().copy()
|
||||
@ -1332,7 +1368,7 @@ class _ConvNd(nn.Module, LoraLayer):
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
self._check_forward_args(x, *args, **kwargs)
|
||||
adapter_names = kwargs.pop("adapter_names", None)
|
||||
|
||||
variant_kwargs = {k: kwargs.pop(k, None) for k in VARIANT_KWARG_KEYS} # don't pass these to base_layer
|
||||
if self.disable_adapters:
|
||||
if self.merged:
|
||||
self.unmerge()
|
||||
@ -1363,6 +1399,8 @@ class _ConvNd(nn.Module, LoraLayer):
|
||||
active_adapter=active_adapter,
|
||||
x=x,
|
||||
result=result,
|
||||
**variant_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
result = result.to(torch_result_dtype)
|
||||
@ -1462,7 +1500,8 @@ class MultiheadAttention(nn.Module, LoraLayer):
|
||||
if use_dora:
|
||||
# TODO: probably not so hard to implement
|
||||
raise ValueError(f"{self.__class__.__name__} does not support DoRA (yet), please set use_dora to False")
|
||||
|
||||
if kwargs.get("use_alora", False):
|
||||
raise ValueError(f"{self.__class__.__name__} does not support aLoRA (yet), please set use_alora to False")
|
||||
super().__init__()
|
||||
LoraLayer.__init__(self, base_layer, **kwargs)
|
||||
|
||||
|
@ -64,6 +64,11 @@ def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _alora_offsets_pre_forward_hook(target, args, kwargs, alora_offsets):
|
||||
kwargs["alora_offsets"] = alora_offsets
|
||||
return args, kwargs
|
||||
|
||||
|
||||
class LoraModel(BaseTuner):
|
||||
"""
|
||||
Creates Low Rank Adapter (LoRA) model from a pretrained transformers model.
|
||||
@ -211,6 +216,7 @@ class LoraModel(BaseTuner):
|
||||
"init_lora_weights": lora_config.init_lora_weights,
|
||||
"use_rslora": lora_config.use_rslora,
|
||||
"use_dora": lora_config.use_dora,
|
||||
"use_alora": lora_config.alora_invocation_tokens is not None,
|
||||
"use_qalora": lora_config.use_qalora,
|
||||
"qalora_group_size": lora_config.qalora_group_size,
|
||||
"ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
|
||||
@ -449,57 +455,69 @@ class LoraModel(BaseTuner):
|
||||
def _enable_peft_forward_hooks(self, *args, **kwargs):
|
||||
# If adapter_names is passed as an argument, we inject it into the forward arguments.
|
||||
adapter_names = kwargs.pop("adapter_names", None)
|
||||
if adapter_names is None:
|
||||
alora_offsets = kwargs.pop("alora_offsets", None)
|
||||
if adapter_names is None and alora_offsets is None:
|
||||
# nothing to do
|
||||
yield
|
||||
return
|
||||
|
||||
if self.training:
|
||||
raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
|
||||
|
||||
# Check that users only passed actually existing adapters.
|
||||
# Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
|
||||
# to check that there is at least one layer with the given name, or else something like typos can easily slip.
|
||||
expected_adapters = set()
|
||||
for layer in self.modules():
|
||||
if isinstance(layer, LoraLayer):
|
||||
expected_adapters |= layer.lora_A.keys()
|
||||
expected_adapters |= layer.lora_embedding_A.keys()
|
||||
unique_adapters = {name for name in adapter_names if name != "__base__"}
|
||||
unexpected_adapters = unique_adapters - expected_adapters
|
||||
if unexpected_adapters:
|
||||
raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")
|
||||
|
||||
# deal with beam search
|
||||
hook_handles = []
|
||||
if alora_offsets is not None:
|
||||
for layer in self.modules():
|
||||
if isinstance(layer, LoraLayer):
|
||||
pre_forward = partial(_alora_offsets_pre_forward_hook, alora_offsets=alora_offsets)
|
||||
handle = layer.register_forward_pre_hook(pre_forward, with_kwargs=True)
|
||||
hook_handles.append(handle)
|
||||
num_beams = kwargs.get("num_beams", None)
|
||||
uses_beam_search = isinstance(num_beams, int) and (num_beams > 1)
|
||||
original_adapter_names = adapter_names[:]
|
||||
if uses_beam_search:
|
||||
if not isinstance(adapter_names, (list, tuple)):
|
||||
raise TypeError(f"Got adapter names of type {type(adapter_names)}, expected a list of str.")
|
||||
# When there is beam search, the inputs are repeated n times, thus we repeat each adapter name n times and
|
||||
# then flatten the nested list. For encoder-decoder models, this extended list should not be applied to the
|
||||
# encoder part. Further below, the original argument is thus restored for the encoder.
|
||||
adapter_names = sum(([n] * kwargs["num_beams"] for n in adapter_names), [])
|
||||
if alora_offsets is not None:
|
||||
raise ValueError("Beam search not yet supported for aLoRA.")
|
||||
if adapter_names is not None:
|
||||
if self.training:
|
||||
raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
|
||||
|
||||
hook_handles = []
|
||||
for module in self.modules():
|
||||
if isinstance(module, LoraLayer) or isinstance(module, AuxiliaryTrainingWrapper):
|
||||
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
|
||||
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
|
||||
hook_handles.append(handle)
|
||||
# Check that users only passed actually existing adapters.
|
||||
# Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
|
||||
# to check that there is at least one layer with the given name, or else something like typos can easily slip.
|
||||
expected_adapters = set()
|
||||
for layer in self.modules():
|
||||
if isinstance(layer, LoraLayer):
|
||||
expected_adapters |= layer.lora_A.keys()
|
||||
expected_adapters |= layer.lora_embedding_A.keys()
|
||||
unique_adapters = {name for name in adapter_names if name != "__base__"}
|
||||
unexpected_adapters = unique_adapters - expected_adapters
|
||||
if unexpected_adapters:
|
||||
raise ValueError(
|
||||
f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}"
|
||||
)
|
||||
|
||||
if uses_beam_search and hasattr(self.model, "get_encoder"):
|
||||
# For encoder-decoder models, even when applying beam search, the encoder part of the model should not use
|
||||
# the extended adapter_names. This is because the encoder still uses the original, non-extended samples.
|
||||
for module in self.model.get_encoder().modules():
|
||||
# deal with beam search
|
||||
original_adapter_names = adapter_names[:]
|
||||
if uses_beam_search:
|
||||
if not isinstance(adapter_names, (list, tuple)):
|
||||
raise TypeError(f"Got adapter names of type {type(adapter_names)}, expected a list of str.")
|
||||
# When there is beam search, the inputs are repeated n times, thus we repeat each adapter name n times and
|
||||
# then flatten the nested list. For encoder-decoder models, this extended list should not be applied to the
|
||||
# encoder part. Further below, the original argument is thus restored for the encoder.
|
||||
adapter_names = sum(([n] * kwargs["num_beams"] for n in adapter_names), [])
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, LoraLayer) or isinstance(module, AuxiliaryTrainingWrapper):
|
||||
# Add another hook to overwrite the kwargs with the original adapter names -- this is easier than
|
||||
# trying to exclude the encoder.
|
||||
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=original_adapter_names)
|
||||
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
|
||||
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
|
||||
hook_handles.append(handle)
|
||||
|
||||
if uses_beam_search and hasattr(self.model, "get_encoder"):
|
||||
# For encoder-decoder models, even when applying beam search, the encoder part of the model should not use
|
||||
# the extended adapter_names. This is because the encoder still uses the original, non-extended samples.
|
||||
for module in self.model.get_encoder().modules():
|
||||
if isinstance(module, LoraLayer) or isinstance(module, AuxiliaryTrainingWrapper):
|
||||
# Add another hook to overwrite the kwargs with the original adapter names -- this is easier than
|
||||
# trying to exclude the encoder.
|
||||
pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=original_adapter_names)
|
||||
handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
|
||||
hook_handles.append(handle)
|
||||
|
||||
yield
|
||||
|
||||
for handle in hook_handles:
|
||||
|
@ -13,7 +13,9 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
import collections
|
||||
import warnings
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from accelerate.utils.imports import is_xpu_available
|
||||
@ -21,6 +23,7 @@ from torch import nn
|
||||
|
||||
from peft.utils.other import transpose
|
||||
|
||||
from .config import PeftConfig
|
||||
from .dora import DoraConv1dLayer, DoraConv2dLayer, DoraConv3dLayer, DoraEmbeddingLayer, DoraLinearLayer
|
||||
from .layer import Conv1d, Conv2d, Conv3d, Embedding, Linear, LoraVariant, _ConvNd
|
||||
|
||||
@ -107,7 +110,13 @@ class DoraLinearVariant(LoraVariant):
        return new_weight

    @staticmethod
    def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    def forward(
        module: Linear,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        dropout = module.lora_dropout[active_adapter]
@ -197,7 +206,13 @@ class DoraEmbeddingVariant(DoraLinearVariant):
        return new_weight

    @staticmethod
    def forward(module: Embedding, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    def forward(
        module: Embedding,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        embedding_A = module.lora_embedding_A[active_adapter].T
        embedding_B = module.lora_embedding_B[active_adapter].T
        scaling = module.scaling[active_adapter]
@ -273,7 +288,13 @@ class _DoraConvNdVariant(LoraVariant):
        return new_weight

    @staticmethod
    def forward(module: _ConvNd, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    def forward(
        module: _ConvNd,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        dropout = module.lora_dropout[active_adapter]
@ -380,7 +401,13 @@ class QALoraLinearVariant(LoraVariant):
        raise NotImplementedError("QALoRA for GPTQ layers does not support 'unmerge'.")

    @staticmethod
    def forward(module: Linear, active_adapter: str, x: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    def forward(
        module: Linear,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        lora_A_weight = module.lora_A[active_adapter].weight
        lora_B_weight = module.lora_B[active_adapter].weight
        dropout = module.lora_dropout[active_adapter]
@ -411,3 +438,227 @@ class QALoraLinearVariant(LoraVariant):
        delta = delta.view(orig_shape[:-1] + (delta.size(-1),))

        return result + delta


class ALoraLinearVariant(LoraVariant):
    @staticmethod
    def init(module: Linear, adapter_name: str, **kwargs: Any) -> None:
        pass

    @staticmethod
    def merge_safe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("aLoRA does not support safe merging.")

    @staticmethod
    def merge_unsafe(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> None:
        raise NotImplementedError("aLoRA does not support merging.")

    @staticmethod
    def unmerge(module: Linear, active_adapter: str, orig_weight: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError("aLoRA does not support unmerging.")

    @staticmethod
    def forward(
        module: Linear,
        active_adapter: str,
        x: torch.Tensor,
        result: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        alora_offsets = kwargs.get("alora_offsets", None)
        lora_A = module.lora_A[active_adapter]
        lora_B = module.lora_B[active_adapter]
        dropout = module.lora_dropout[active_adapter]
        scaling = module.scaling[active_adapter]
        x = x.to(lora_A.weight.dtype)
        result_shape = result.shape
        B = result_shape[0]  # batch
        if len(result_shape) == 3:
            T = result_shape[1]  # tokens
        else:
            T = 1
        D = result_shape[-1]  # dimensions
        Dx = x.shape[-1]
        device = result.device
        if alora_offsets is None:  # use base model only, but ensure 0 gradient
            mask = torch.zeros((B, T), dtype=torch.bool, device=device)
        else:
            # If alora_offsets[i] is None, this means that the invocation sequence was not found in the
            # input. As a result, the weights should not be activated anywhere (equivalent to base model).
            # Convert None -> 0 and clip to T
            offsets = torch.tensor(
                [0 if o is None else min(int(o), T) for o in alora_offsets],
                device=device,
                dtype=torch.long,
            )
            # Mask True on the last `offsets[i]` positions for each row i
            pos = torch.arange(T, device=device).unsqueeze(0)  # [1, T]
            mask = pos >= (T - offsets).unsqueeze(1)

        # Flatten for vectorization
        x_flat = x.view(-1, Dx)
        res_flat = result.view(-1, D)
        mask_flat = mask.view(-1)

        # Compute adapter on the selected tokens only
        res_flat[mask_flat] += lora_B(lora_A(dropout(x_flat[mask_flat]))) * scaling
        return result
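
# Worked illustration of the masking rule above (the numbers are chosen for illustration
# only): with T = 5 positions and alora_offsets = [3, None], row 0 keeps offset 3 and row 1
# becomes 0 after the None -> 0 conversion, so
#
#     pos  = [[0, 1, 2, 3, 4]]
#     mask = [[False, False, True, True, True],     # last 3 tokens get the low-rank update
#             [False, False, False, False, False]]  # row falls back to the base model
#
# Only the masked positions of x_flat receive lora_B(lora_A(...)) * scaling; the un-masked
# prefix is exactly the base `result`, which is what keeps its KV cache reusable.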


def calculate_alora_offsets(
    peft_config: PeftConfig, active_adapter: str, input_ids: torch.Tensor, adapter_names: Optional[list[str]] = None
) -> list[int]:
    """
    This is a helper function for Activated LoRA (aLoRA) that searches each input token sequence for the last
    occurrence of the appropriate "alora_invocation_tokens" invocation sequence. The calculated alora_offset is the
    location of the *start* of the invocation tokens, counting backward from the end (it will therefore always be >=
    len(alora_invocation_tokens)). If adapter_names is passed, then each input uses the appropriate invocation
    sequence for the specified adapter for that row. Logic is provided to handle mixed collections of adapters for
    which not all are aLoRAs (e.g. some base model, some LoRA).
    """
    if input_ids is None:
        return []

    batch_size = input_ids.shape[0]
    alora_offsets = [None] * batch_size

    cached_invocation_tensors = {}
    adapters_to_process_indices = collections.defaultdict(list)

    for i in range(batch_size):
        current_adapter_name = adapter_names[i] if adapter_names and i < len(adapter_names) else active_adapter

        if current_adapter_name == "__base__":
            alora_offsets[i] = None
            continue

        if current_adapter_name not in peft_config:
            warnings.warn(f"Adapter '{current_adapter_name}' not found in peft_config. Using base model for row {i}.")
            alora_offsets[i] = None
            continue

        current_peft_config = peft_config[current_adapter_name]

        invocation_tokens = getattr(current_peft_config, "alora_invocation_tokens", None)
        if invocation_tokens is None:
            alora_offsets[i] = None  # Not an aLoRA adapter or wrong type
            continue

        if current_adapter_name not in cached_invocation_tensors:
            cached_invocation_tensors[current_adapter_name] = torch.tensor(
                invocation_tokens, dtype=torch.long, device=input_ids.device
            )

        adapters_to_process_indices[current_adapter_name].append(i)

    for adapter_name_to_process, indices in adapters_to_process_indices.items():
        current_invocation_ids_tensor = cached_invocation_tensors[adapter_name_to_process]
        invocation_len = len(current_invocation_ids_tensor)

        for i in indices:
            sequence = input_ids[i]
            seq_len = len(sequence)
            best_match_start_idx = -1

            possible_starts = (sequence == current_invocation_ids_tensor[0]).nonzero(as_tuple=True)[0]

            for start_idx_tensor in possible_starts:
                idx = start_idx_tensor.item()
                if idx + invocation_len <= seq_len:
                    if torch.equal(sequence[idx : idx + invocation_len], current_invocation_ids_tensor):
                        if idx > best_match_start_idx:
                            best_match_start_idx = idx

            if best_match_start_idx != -1:
                offset_val = seq_len - best_match_start_idx
                alora_offsets[i] = offset_val if offset_val > 0 else None
            else:  # Invocation sequence not found in input
                alora_offsets[i] = None
    return alora_offsets
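
# A minimal usage sketch of the offset convention (the token ids and the single-adapter
# config below are made up for illustration):
#
#     peft_config = {"default": LoraConfig(alora_invocation_tokens=[7, 8])}
#     input_ids = torch.tensor(
#         [[5, 9, 7, 8, 3],   # invocation starts at index 2 -> offset 5 - 2 = 3
#          [5, 9, 3, 4, 6]]   # invocation never occurs      -> offset None
#     )
#     calculate_alora_offsets(peft_config, "default", input_ids)  # -> [3, None]
#
# An offset of 3 means the adapter weights are active on the last three tokens of that row
# (the invocation tokens plus everything after them); None means the row behaves exactly
# like the base model.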


def is_alora_relevant_in_batch(model: nn.Module, adapter_names: Optional[list[str]] = None):
    """
    Helper function to determine if the current batch has any aLoRA adapters.
    """
    is_alora_relevant = False
    if getattr(model.active_peft_config, "alora_invocation_tokens", None):
        is_alora_relevant = True
    elif adapter_names:
        for name in adapter_names:
            if name == "__base__":
                continue
            config_ = model.peft_config.get(name)
            if config_ and getattr(config_, "alora_invocation_tokens", None):
                is_alora_relevant = True
                break

    return is_alora_relevant


def get_alora_offsets_for_forward(
    model: nn.Module, input_ids: torch.Tensor = None, inputs_embeds: torch.Tensor = None, **kwargs
):
    """
    Wrapper around calculate_alora_offsets, for the .forward of the model. It only calculates alora_offsets if the
    batch contains aLoRA adapters.
    """
    adapter_names_for_offset_calc = kwargs.get("adapter_names", None)
    if not is_alora_relevant_in_batch(model, adapter_names_for_offset_calc):
        # Nothing to compute
        return kwargs
    alora_offsets = kwargs.get("alora_offsets")
    if alora_offsets is None:
        if input_ids is None and inputs_embeds is not None:
            warnings.warn(
                "Cannot calculate aLoRA offsets when only inputs_embeds are provided. Disabling aLoRA for this forward pass."
            )
            kwargs["alora_offsets"] = None
        elif input_ids is not None:
            kwargs["alora_offsets"] = calculate_alora_offsets(
                model.peft_config,
                model.active_adapter,
                input_ids,
                adapter_names=adapter_names_for_offset_calc,
            )
        else:
            kwargs["alora_offsets"] = None
    return kwargs
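
# Sketch of the intended call pattern from a CausalLM forward path (mirrors the unit tests
# in this PR; `peft_model` and `input_ids` are placeholder names):
#
#     extra = get_alora_offsets_for_forward(peft_model, input_ids)
#     output = peft_model(input_ids, **extra)   # extra carries "alora_offsets"
#
# When the batch contains no aLoRA adapter the kwargs come back unchanged, so the call is a
# no-op for plain LoRA or base-model batches.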


def get_alora_offsets_for_generate(model: nn.Module, *args, **kwargs):
    """
    Wrapper around calculate_alora_offsets, for the .generate of the model. It only calculates alora_offsets if the
    batch contains aLoRA adapters.
    """
    adapter_names_for_offset_calc = kwargs.get("adapter_names")
    if not is_alora_relevant_in_batch(model, adapter_names_for_offset_calc):
        # Nothing to compute
        return kwargs
    alora_offsets_from_kwargs = kwargs.get("alora_offsets")
    if alora_offsets_from_kwargs is None:
        current_input_ids = kwargs.get("input_ids")
        if current_input_ids is None:  # args[0] is usually input_ids
            if args and isinstance(args[0], torch.Tensor):
                current_input_ids = args[0]
            else:
                current_input_ids = None

        if current_input_ids is not None:
            if current_input_ids.ndim == 1:
                current_input_ids = current_input_ids.unsqueeze(0)
            calculated_offsets = calculate_alora_offsets(
                model.peft_config,
                model.active_adapter,
                current_input_ids,
                adapter_names=adapter_names_for_offset_calc,
            )
            kwargs["alora_offsets"] = calculated_offsets

        else:
            warnings.warn(
                "Cannot calculate aLoRA offsets during generate as input_ids are not available. Disabling aLoRA."
            )

            kwargs["alora_offsets"] = None
    return kwargs
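
# Analogous sketch for generation (names are placeholders; see the sketch after
# get_alora_offsets_for_forward above):
#
#     gen_kwargs = get_alora_offsets_for_generate(peft_model, input_ids=input_ids)
#     peft_model.generate(input_ids=input_ids, **gen_kwargs)
#
# Passing inputs_embeds without input_ids falls back to plain base-model behaviour and
# emits the warning exercised in the tests.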
@ -2099,7 +2099,6 @@ class TestPeftCustomModel(PeftCommonTester):
        if config_cls != LoraConfig or config_cls != BOFTConfig:
            # skip this test for other configs as bias is specific to Lora
            pytest.skip("Testing bias warnings only for LoraConfig or BOFTConfig")

        if not issubclass(config_cls, (LoraConfig, BOFTConfig)):
            pytest.skip("Bias argument is only supported for LoRA or BOFT models")

@ -152,6 +152,32 @@ ALL_CONFIGS = [
            "bias": "none",
        },
    ),
    # Activated LoRA (aLoRA)
    (
        LoraConfig,
        {
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 32,
            "target_modules": None,
            "lora_dropout": 0.05,
            "bias": "none",
            "alora_invocation_tokens": [1],
        },
    ),
    (
        LoraConfig,
        {
            "task_type": "CAUSAL_LM",
            "r": 8,
            "lora_alpha": 32,
            "target_modules": None,
            "lora_dropout": 0.05,
            "bias": "none",
            # no test input sequence will ever contain this token, so this adapter should never activate
            "alora_invocation_tokens": [1000],
        },
    ),
    # LoRA + trainable tokens
    (
        LoraConfig,
@ -273,6 +299,11 @@ def _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls):
    pytest.skip("Skipping AdaLora/BOFT/HRA/OFT/Bone/MiSS for GPT2LMHeadModel")


def _skip_alora_no_activation(config_cls, config_kwargs):
    if config_cls is LoraConfig and config_kwargs.get("alora_invocation_tokens") == [1000]:
        pytest.skip("Skipping the aLoRA no-activation case because the test expects the output to change, which it won't.")


class TestDecoderModels(PeftCommonTester):
    transformers_class = AutoModelForCausalLM

@ -411,6 +442,7 @@ class TestDecoderModels(PeftCommonTester):
    def test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):
        if config_cls != LoraConfig:
            pytest.skip("Mixed adapter batches not supported for this config.")
        _skip_alora_no_activation(config_cls, config_kwargs)
        config_kwargs = set_init_weights_false(config_cls, config_kwargs)
        self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs.copy())

@ -500,6 +532,7 @@ class TestDecoderModels(PeftCommonTester):
    def test_unload_adapter(self, model_id, config_cls, config_kwargs):
        _skip_adalora_oft_hra_bone_for_gpt2(model_id, config_cls)
        _skip_if_not_conv1d_supported(model_id, config_cls)
        _skip_alora_no_activation(config_cls, config_kwargs)
        config_kwargs = set_init_weights_false(config_cls, config_kwargs)
        self._test_unload_adapter(model_id, config_cls, config_kwargs.copy())

@ -518,6 +551,7 @@ class TestDecoderModels(PeftCommonTester):
    @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
    def test_disable_adapter(self, model_id, config_cls, config_kwargs):
        _skip_if_not_conv1d_supported(model_id, config_cls)
        _skip_alora_no_activation(config_cls, config_kwargs)
        config_kwargs = set_init_weights_false(config_cls, config_kwargs)
        self._test_disable_adapter(model_id, config_cls, config_kwargs.copy())

@ -4853,6 +4853,79 @@ class TestEvaInitializationGPU:
        )


class TestALoRAInferenceGPU:
    """GPU inference tests for Activated LoRA."""

    # Constants for test configuration
    NUM_SEEDS = 3
    LORA_DIM = 8
    LORA_ALPHA = 1
    DEVICE = infer_device()

    @pytest.fixture
    def tokenizer(self):
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    @pytest.fixture
    def model(self):
        model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        model.model.decoder.layers = model.model.decoder.layers[:2]  # truncate to 2 layers
        return model.to(self.DEVICE)

    @pytest.fixture
    def model_bnb(self):
        bnb_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            "facebook/opt-125m",
            quantization_config=bnb_config,
        )
        model.model.decoder.layers = model.model.decoder.layers[:2]  # truncate to 2 layers
        model = prepare_model_for_kbit_training(model)
        return model

    @pytest.fixture
    def peft_config(self):
        return LoraConfig(
            r=self.LORA_DIM,
            task_type="CAUSAL_LM",
            lora_alpha=self.LORA_ALPHA,
            target_modules=["q_proj"],
            alora_invocation_tokens=[2],  # id for </s>
            init_lora_weights=False,
        )

    @require_non_cpu
    @require_bitsandbytes
    @pytest.mark.single_gpu_tests
    def test_alora_forward_consistency(self, model, model_bnb, peft_config):
        """Test that the forward outputs of the adapted model are similar across quantizations."""
        for seed in range(self.NUM_SEEDS):
            torch.manual_seed(seed)
            # random.seed(seed)
            np.random.seed(seed)
            peft_model = get_peft_model(deepcopy(model), peft_config)
            torch.manual_seed(seed)
            # random.seed(seed)
            np.random.seed(seed)
            peft_model_bnb = get_peft_model(deepcopy(model_bnb), peft_config)
            peft_model.eval()
            peft_model_bnb.eval()
            input_ids = torch.tensor([[0, 1, 2, 3]]).to(self.DEVICE)
            with torch.no_grad():
                peft_out = peft_model(input_ids=input_ids, return_dict=True, output_hidden_states=True)
                peft_out_bnb = peft_model_bnb(input_ids=input_ids, return_dict=True, output_hidden_states=True)
            h_fp = peft_out.hidden_states[-1]
            h_4b = peft_out_bnb.hidden_states[-1]
            a = h_fp.detach().to(torch.float32).cpu()
            b = h_4b.detach().to(torch.float32).cpu()
            import torch.nn.functional as F

            cos = F.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()
            assert cos > 0.9


@pytest.mark.multi_gpu_tests
class TestPrefixTuning:
    device = infer_device()

@ -22,13 +22,18 @@ from peft.tuners.lora.layer import Conv2d as LoraConv2d
from peft.tuners.lora.layer import Embedding as LoraEmbedding
from peft.tuners.lora.layer import Linear as LoraLinear
from peft.tuners.lora.variants import (
    ALoraLinearVariant,
    DoraConv1dVariant,
    DoraConv2dVariant,
    DoraEmbeddingVariant,
    DoraLinearVariant,
    calculate_alora_offsets,
    get_alora_offsets_for_forward,
    get_alora_offsets_for_generate,
)


# Custom model featuring embeddings and a 'visual stack'
class CustomModel(nn.Module):
    """pytorch module that contains common targetable layers (linear, embedding, conv, ...)"""

@ -61,13 +66,46 @@ class CustomModel(nn.Module):
        return output


# Used for testing alora_offsets for aLoRA
class DummyLM(nn.Module):
    def __init__(self, vocab_size: int = 10, hidden_dim: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, X=None, embeds=None, num_beams=None, alora_offsets=None):
        if X is not None:
            embeds = self.embed(X)
        return self.linear(embeds)


class MockTransformerWrapper:
    """Mock class to behave like a transformers model.

    This is needed because the tests initialize the model by calling transformers_class.from_pretrained.

    """

    @classmethod
    def from_pretrained(cls):
        # set the seed so that from_pretrained always returns the same model
        torch.manual_seed(0)

        torch_dtype = torch.float32

        return DummyLM().to(torch_dtype)


VARIANT_MAP = {
    "dora": {
        LoraLinear: DoraLinearVariant,
        LoraEmbedding: DoraEmbeddingVariant,
        LoraConv1d: DoraConv1dVariant,
        LoraConv2d: DoraConv2dVariant,
    }
    },
    "alora": {
        LoraLinear: ALoraLinearVariant,
    },
}

@ -77,6 +115,11 @@ TEST_CASES = [
        LoraConfig,
        {"target_modules": ["linear1", "linear2", "conv1d", "conv2d", "embedding"], "use_dora": True},
    ),
    (
        "alora",
        LoraConfig,
        {"target_modules": ["linear1", "linear2"], "alora_invocation_tokens": [1]},
    ),
]

@ -124,3 +167,101 @@ class TestLoraVariants:

        for layer in layer_names:
            assert getattr(peft_model.base_model.model, layer).lora_magnitude_vector["default"].weight.grad is not None


class TestActivatedLora:
    @pytest.mark.parametrize(
        "input_ids, alora_invocation_tokens, expected_offsets",
        [
            ([[0, 1, 2, 3], [0, 4, 5, 6]], [1, 2], [3, None]),
            ([[1, 2, 1, 2], [0, 4, 1, 2]], [1, 2], [2, 2]),
            ([[1, 2, 3, 4], [0, 4, 1, 4]], [1, 2], [4, None]),
            ([[1, 2, 3, 4]], None, [None]),
        ],
    )
    # Verify alora_offsets are calculated correctly
    def test_calculate_alora_offsets(self, input_ids, alora_invocation_tokens, expected_offsets):
        config = LoraConfig(alora_invocation_tokens=alora_invocation_tokens)
        peft_config = {"default": config}

        # compute offsets
        offsets = calculate_alora_offsets(peft_config, "default", torch.tensor(input_ids))

        assert offsets == expected_offsets

    @pytest.mark.parametrize(
        "input_ids, alora_invocations, expected_offsets",
        [
            ([[0, 1, 1], [0, 2, 2]], {"a1": [1], "a2": [2]}, [1, 1]),
            ([[0, 1, 1], [0, 2, 2]], {"a1": [1], "a2": None}, [1, None]),
        ],
    )
    # Verify alora_offsets are correct with adapter names
    def test_calculate_alora_offsets_with_adapter_names(self, input_ids, alora_invocations, expected_offsets):
        peft_config = {}
        for alora_name in alora_invocations.keys():
            peft_config[alora_name] = LoraConfig(alora_invocation_tokens=alora_invocations[alora_name])

        adapter_names = list(alora_invocations.keys())
        offsets = calculate_alora_offsets(
            peft_config, adapter_names[0], torch.tensor(input_ids), adapter_names=adapter_names
        )

        assert offsets == expected_offsets

    # Verify that the adapter does not modify outputs prior to invocation point
    def test_alora_activation_matches_base_until_invocation(self):
        transformers_class = MockTransformerWrapper
        base_model = transformers_class.from_pretrained()
        cfg = LoraConfig(target_modules=["linear"], alora_invocation_tokens=[2], init_lora_weights=False)
        lora_model = get_peft_model(base_model, cfg)
        lora_model.eval()

        input_ids = torch.tensor([[0, 1, 2, 3]])
        start = 2
        with lora_model.disable_adapter():
            with torch.no_grad():
                base_out = lora_model(X=input_ids)

        kwargs = get_alora_offsets_for_forward(lora_model, input_ids)
        with torch.no_grad():
            lora_out = lora_model(X=input_ids, **kwargs)
        assert torch.allclose(lora_out[:, :start], base_out[:, :start])
        assert not torch.allclose(lora_out[:, start:], base_out[:, start:])

    # Verify that warning is given for alora when providing embeddings only
    def test_input_embeds_warning(self):
        transformers_class = MockTransformerWrapper
        base_model = transformers_class.from_pretrained()
        cfg = LoraConfig(target_modules=["linear"], alora_invocation_tokens=[2], init_lora_weights=False)
        lora_model = get_peft_model(base_model, cfg)
        lora_model.eval()

        input_ids = torch.tensor([[0, 1, 2, 3]])
        input_embeds = base_model.embed(input_ids)
        with pytest.warns(
            UserWarning,
            match="Cannot calculate aLoRA offsets when only inputs_embeds are provided. Disabling aLoRA for this forward pass.",
        ):
            kwargs = get_alora_offsets_for_forward(lora_model, inputs_embeds=input_embeds)
        assert kwargs.get("alora_offsets") is None
        with pytest.warns(
            UserWarning,
            match="Cannot calculate aLoRA offsets during generate as input_ids are not available. Disabling aLoRA.",
        ):
            kwargs = get_alora_offsets_for_generate(lora_model, inputs_embeds=input_embeds)
        assert kwargs.get("alora_offsets") is None

    # Verify that error is raised when requesting num_beams > 1 for alora
    def test_num_beams_error(self):
        transformers_class = MockTransformerWrapper
        base_model = transformers_class.from_pretrained()
        cfg = LoraConfig(target_modules=["linear"], alora_invocation_tokens=[2], init_lora_weights=False)
        lora_model = get_peft_model(base_model, cfg)
        lora_model.eval()

        input_ids = torch.tensor([[0, 1, 2, 3]])
        with pytest.raises(ValueError) as e:
            with torch.no_grad():
                lora_out = lora_model(X=input_ids, num_beams=2, alora_offsets=[3])
        assert "Beam search not yet supported for aLoRA." in str(e.value)

@ -586,10 +586,15 @@ class PeftCommonTester:
        assert load_result2.missing_keys == []

    def _test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
        if config_cls not in (LoraConfig, IA3Config, AdaLoraConfig, LoHaConfig, LoKrConfig, VBLoRAConfig):
            # Merge layers only supported for LoRA and IA³
            return pytest.skip(f"Test not applicable for {config_cls}")

        if (
            config_cls not in (LoraConfig, IA3Config, AdaLoraConfig, LoHaConfig, LoKrConfig, VBLoRAConfig)
            or config_kwargs.get("alora_invocation_tokens") is not None
        ):
            # Merge layers only supported for LoRA and IA³, and not for Activated LoRA (aLoRA)
            if config_kwargs.get("alora_invocation_tokens") is None:
                return pytest.skip(f"Test not applicable for {config_cls}")
            else:
                return pytest.skip("Test not applicable for Activated LoRA")
        if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
            self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

@ -611,16 +616,20 @@ class PeftCommonTester:
        _ = model.merge_and_unload()

    def _test_merge_layers_nan(self, model_id, config_cls, config_kwargs):
        if config_cls not in (
            LoraConfig,
            IA3Config,
            AdaLoraConfig,
            LoHaConfig,
            LoKrConfig,
            VeraConfig,
            FourierFTConfig,
        if (
            config_cls
            not in (
                LoraConfig,
                IA3Config,
                AdaLoraConfig,
                LoHaConfig,
                LoKrConfig,
                VeraConfig,
                FourierFTConfig,
            )
            or config_kwargs.get("alora_invocation_tokens") is not None
        ):
            # Merge layers only supported for LoRA and IA³
            # Merge layers only supported for LoRA and IA³, and not for Activated LoRA (aLoRA)
            return
        if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
            self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")
@ -700,6 +709,9 @@ class PeftCommonTester:
        if issubclass(config_cls, (OFTConfig, BOFTConfig)):
            return pytest.skip(f"Test not applicable for {config_cls}")

        if config_kwargs.get("alora_invocation_tokens") is not None:
            return pytest.skip("Merging not applicable to aLoRA")

        if ("gpt2" in model_id.lower()) and (config_cls != LoraConfig):
            self.skipTest("Merging GPT2 adapters not supported for IA³ (yet)")

@ -805,7 +817,7 @@ class PeftCommonTester:
            **config_kwargs,
        )

        if config.peft_type not in supported_peft_types:
        if config.peft_type not in supported_peft_types or config_kwargs.get("alora_invocation_tokens") is not None:
            return

        with hub_online_once(model_id):
@ -867,6 +879,9 @@ class PeftCommonTester:
        assert torch.allclose(logits_merged_adapter_default, logits_adapter_1, atol=1e-3, rtol=1e-3)

    def _test_merge_layers_is_idempotent(self, model_id, config_cls, config_kwargs):
        if config_kwargs.get("alora_invocation_tokens") is not None:
            # Merging not supported for Activated LoRA (aLoRA)
            return pytest.skip("Test not applicable for Activated LoRA (aLoRA)")
        with hub_online_once(model_id):
            model = self.transformers_class.from_pretrained(model_id)
            config = config_cls(
@ -889,6 +904,9 @@ class PeftCommonTester:
        assert torch.allclose(logits_0, logits_1, atol=1e-6, rtol=1e-6)

    def _test_safe_merge(self, model_id, config_cls, config_kwargs):
        if config_kwargs.get("alora_invocation_tokens") is not None:
            # Merging not supported for Activated LoRA (aLoRA)
            return pytest.skip("Test not applicable for Activated LoRA (aLoRA)")
        torch.manual_seed(0)
        with hub_online_once(model_id):
            model = self.transformers_class.from_pretrained(model_id)
@ -954,7 +972,6 @@ class PeftCommonTester:
        dummy_input = self.prepare_inputs_for_testing()
        # ensure that we have at least 3 samples for this test
        dummy_input = {k: torch.cat([v for _ in range(3)]) for k, v in dummy_input.items()}

        with torch.inference_mode():
            with model.disable_adapter():
                output_base = model(**dummy_input)[0]
@ -984,7 +1001,6 @@ class PeftCommonTester:
        # alternate between base model, adapter0, and adapter1
        adapters = ["__base__", "adapter0", "adapter1"]
        dummy_input["adapter_names"] = [adapters[i % 3] for i in (range(len(dummy_input["input_ids"])))]

        with torch.inference_mode():
            output_mixed = model(**dummy_input)[0]
            logits_mixed = model.generate(**dummy_input, return_dict_in_generate=True, output_scores=True).scores[0]
@ -1001,7 +1017,8 @@ class PeftCommonTester:
        # adapter_names argument. See #2283.
        if config_cls not in (LoraConfig,):
            return pytest.skip(f"Mixed adapter batches not supported for {config_cls}")

        if config_kwargs.get("alora_invocation_tokens") is not None:
            return pytest.skip("Beam search not yet supported for aLoRA")  # beam search not yet fully supported
        if config_kwargs.get("trainable_token_indices", None) is not None:
            # for some configurations this test will fail since the adapter values don't differ.
            # this is probably a problem with the test setup and not with the implementation.
@ -1030,7 +1047,6 @@ class PeftCommonTester:
        dummy_input = self.prepare_inputs_for_testing()
        # ensure that we have at least 3 samples for this test
        dummy_input = {k: torch.cat([v for _ in range(3)]) for k, v in dummy_input.items()}

        gen_kwargs = {**dummy_input, "max_length": 20, "num_beams": 10, "early_stopping": True}
        with torch.inference_mode():
            with model.disable_adapter():