Benjamin Bossan 13fa0aea7e FIX: Wrong coupling between requires_grad and the active adapter (#2765)
Description

At the moment, we strongly couple the active adapter with
requires_grad=True. Concretely, when we call model.set_adapter(name), we
automatically assume that this adapter should not only be made active,
but should also have its requires_grad set to True.

For the purpose of training PEFT models, this is fair. However, when
loading PEFT models for inference, this is not desired. For inference,
we generally don't need requires_grad=True, but currently it is enabled.

This is usually not a severe bug: inference code performs no updates,
so no weight is inadvertently updated just because it wrongly has
requires_grad=True. This is probably why it went unnoticed so far.
However, it could lead to worse runtime performance and memory overhead
when PyTorch tracks gradients for those parameters (which it doesn't do
under torch.inference_mode, but some users may forget to use it).
Therefore, this bug is still worth fixing.

Example

With `modules_to_save`

A minimal example where the current PEFT version fails:

import os
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

model_id = "facebook/opt-125m"
path = "/tmp/peft/2759"
if not os.path.exists(path + "/adapter_model.safetensors"):
    model = AutoModelForCausalLM.from_pretrained(model_id)
    config = LoraConfig(target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"], r=8)
    model = get_peft_model(model, config)
    model.save_pretrained(path)
    del model

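model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, path)
# loaded for inference: no parameter (including the lm_head copy in modules_to_save) should require grad
assert not any(p.requires_grad for p in model.parameters())  # fails with the current PEFT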

`modules_to_save` should not have grads enabled, but currently it does.

With multiple adapters

There is also an issue when loading more than one adapter:

model = PeftModel.from_pretrained(...)
assert not any(p.requires_grad for p in model.parameters())  # works

So far, so good: the parameters of the first adapter do not have `requires_grad=True`.

model.load_adapter(...)
assert not any(p.requires_grad for p in model.parameters())  # fails

The load_adapter call inadvertently sets requires_grad=True for the
weights of the _first_ adapter. This happens because, when the second
adapter is loaded, we call set_adapter with the first adapter to ensure
that it remains the active adapter. However, due to the coupling of
active adapter and requires_grad, this also sets requires_grad=True for
the first adapter.
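The mechanism can be reproduced with a small pure-Python toy. `ToyParam`, `ToyLayer`, and their methods are hypothetical stand-ins written for illustration, not PEFT classes; the sketch only assumes, as described above, that activating an adapter also sets `requires_grad=True` on it and clears it on the others.

```python
class ToyParam:
    """Hypothetical stand-in for an adapter weight."""

    def __init__(self):
        self.requires_grad = False  # loaded for inference, so frozen


class ToyLayer:
    """Hypothetical layer with the coupled set_adapter behavior."""

    def __init__(self):
        self.adapters = {}
        self.active_adapter = None

    def add_adapter(self, name):
        self.adapters[name] = ToyParam()
        if self.active_adapter is None:
            self.active_adapter = name

    def set_adapter(self, name):
        # Coupled: activating an adapter also makes it trainable.
        self.active_adapter = name
        for adapter_name, param in self.adapters.items():
            param.requires_grad = adapter_name == name

    def load_adapter(self, name):
        previous = self.active_adapter
        self.add_adapter(name)
        # Keep the first adapter active; the coupling unfreezes it as a side effect.
        self.set_adapter(previous)


layer = ToyLayer()
layer.add_adapter("first")
print(any(p.requires_grad for p in layer.adapters.values()))  # False
layer.load_adapter("second")
print(layer.adapters["first"].requires_grad)  # True
```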

This PR relaxes the coupling by allowing set_adapter to be called with
an additional argument, inference_mode. If it is set to True,
requires_grad will not be enabled, even though the adapter is activated.
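A hypothetical toy sketch of the relaxed behavior: `set_adapter` gains an `inference_mode` flag and only unfreezes the activated adapter when the flag is False. The class names are illustrative, and letting `load_adapter` take the flag as a plain argument defaulting to True is an assumption for this sketch, not PEFT's code.

```python
class ToyParam:
    def __init__(self):
        self.requires_grad = False


class ToyLayer:
    def __init__(self):
        self.adapters = {}
        self.active_adapter = None

    def add_adapter(self, name):
        self.adapters[name] = ToyParam()
        if self.active_adapter is None:
            self.active_adapter = name

    def set_adapter(self, name, inference_mode=False):
        self.active_adapter = name
        for adapter_name, param in self.adapters.items():
            # Only unfreeze the activated adapter when not in inference mode.
            param.requires_grad = adapter_name == name and not inference_mode

    def load_adapter(self, name, inference_mode=True):
        previous = self.active_adapter
        self.add_adapter(name)
        # Keep the previous adapter active without unfreezing it.
        self.set_adapter(previous, inference_mode=inference_mode)


layer = ToyLayer()
layer.add_adapter("first")
layer.load_adapter("second")
print(any(p.requires_grad for p in layer.adapters.values()))  # False
layer.set_adapter("second")  # training: unfreeze "second" only
print(layer.adapters["second"].requires_grad)  # True
```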

The example above would also fail for modules_to_save and trainable
tokens, not only for the LoRA/LoHa/... weights.

Still open bugs

The proposed solution is unfortunately not perfect. Right now, we pass
inference_mode based on the PEFT config of the adapter being added,
which fixes the original issue described above. However, even this is
not entirely correct, because inference_mode of the second adapter
does not necessarily have the same value as inference_mode of the first
adapter. To illustrate how this can go wrong, I added an xfailing test:

test_loading_model_requires_grad_set_correctly_switch_inference_mode
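A toy sketch of why this is still wrong, using hypothetical names: each adapter has its own `inference_mode` flag (standing in for its PEFT config), but when a second adapter is added, the flag of the *new* adapter is used to re-activate the *first* one. The scenario follows the pattern of a first adapter meant for training and a second one added with `inference_mode=True`.

```python
class ToyLayer:
    """Hypothetical layer; each adapter remembers its own inference_mode."""

    def __init__(self):
        self.inference_mode = {}  # adapter name -> flag from its own config
        self.requires_grad = {}   # adapter name -> current requires_grad
        self.active_adapter = None

    def set_adapter(self, name, inference_mode=False):
        self.active_adapter = name
        for adapter_name in self.requires_grad:
            self.requires_grad[adapter_name] = adapter_name == name and not inference_mode

    def add_adapter(self, name, inference_mode=False):
        self.inference_mode[name] = inference_mode
        self.requires_grad[name] = False
        previous = self.active_adapter or name
        # The new adapter's flag is applied to the previously active adapter.
        self.set_adapter(previous, inference_mode=inference_mode)


layer = ToyLayer()
layer.add_adapter("adapter0", inference_mode=False)  # meant for training
print(layer.requires_grad["adapter0"])  # True
layer.add_adapter("adapter1", inference_mode=True)
print(layer.requires_grad["adapter0"])  # False, although adapter0's own flag says training
```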

I believe that this use case is rarer than the ones described at the
beginning, so IMO it is acceptable to leave this bug in place, given
that we fix more common ones. However, LMK if you disagree.

Related to this, I noticed that many tests in
test_custom_models.TestRequiresGrad had code like this:

config0 = FooConfig(...)
peft_model = get_peft_model(MLP(), config0)
config1 = FooConfig(..., inference_mode=True)  # <==
peft_model.add_adapter("adapter1", config1)

This now fails for the reason just given. I removed inference_mode=True
here and the tests pass again.

Note that the only reason inference_mode=True was passed here is that
AdaLoRA cannot load 2 adapters in training mode and thus requires it.
Later PEFT methods without this restriction blindly copied the AdaLoRA
test. For those PEFT methods, I removed inference_mode=True.

However, this also means that the AdaLoRA tests now fail. I thus marked
them as xfail.

To properly fix this bug, I think we would have to refactor the code to
isolate set_adapter (i.e. determining the active adapter) and setting
requires_grad into separate code paths, as they're orthogonal. Moreover,
these attributes are being set all over the place, which makes it hard
to reason about where these attributes are being changed. This should be
streamlined.
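A minimal sketch of what such a separation could look like. The method names `set_active_adapter` and `set_requires_grad` are hypothetical and not part of PEFT's API; the point is only that each code path changes exactly one of the two attributes.

```python
class ToyLayer:
    """Hypothetical layer where activation and trainability are orthogonal."""

    def __init__(self, adapter_names):
        self.requires_grad = {name: False for name in adapter_names}
        self.active_adapter = adapter_names[0]

    def set_active_adapter(self, name):
        # Chooses the adapter used in forward; never touches requires_grad.
        if name not in self.requires_grad:
            raise ValueError(f"Unknown adapter: {name!r}")
        self.active_adapter = name

    def set_requires_grad(self, names, requires_grad=True):
        # Toggles trainability; never changes the active adapter.
        for name in names:
            if name not in self.requires_grad:
                raise ValueError(f"Unknown adapter: {name!r}")
            self.requires_grad[name] = requires_grad


layer = ToyLayer(["first", "second"])
layer.set_requires_grad(["second"])  # train "second"...
layer.set_active_adapter("first")    # ...while "first" stays active
print(layer.active_adapter, layer.requires_grad)
# first {'first': False, 'second': True}
```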

Making these changes while not breaking any existing code is not
trivial (or maybe even impossible). Therefore, I went the easier way for
the time being with this PR. Maybe a bigger refactor could be envisioned
for a version 1.0 release of PEFT.

Related changes

While working on this, I noticed that LNTuning was completely buggy when
calling set_adapter. This is now fixed.

Moreover, since I had to touch update_layer everywhere, I ensured that
all update_layer methods accept **kwargs for consistency.
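A hypothetical illustration of why uniform `**kwargs` help: a generic caller can forward the same keyword arguments, such as `inference_mode`, to every layer type, and each `update_layer` ignores what it does not use. The classes below are toy stand-ins, not PEFT's LoRA or LNTuning layers.

```python
class ToyLoraLayer:
    def update_layer(self, adapter_name, r=8, inference_mode=False, **kwargs):
        # Uses r; any extra keyword arguments are accepted and ignored.
        self.last_call = {"adapter": adapter_name, "r": r, "inference_mode": inference_mode}


class ToyLNTuningLayer:
    def update_layer(self, adapter_name, inference_mode=False, **kwargs):
        # Has no rank; r and other extras are absorbed by **kwargs.
        self.last_call = {"adapter": adapter_name, "inference_mode": inference_mode}


def update_all(layers, adapter_name, **kwargs):
    # A single call site works for every layer type.
    for layer in layers:
        layer.update_layer(adapter_name, **kwargs)


layers = [ToyLoraLayer(), ToyLNTuningLayer()]
update_all(layers, "default", r=16, inference_mode=True)
print([layer.last_call for layer in layers])
```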
2025-09-08 19:49:29 +02:00
2022-11-25 09:21:10 +05:30

🤗 PEFT

State-of-the-art Parameter-Efficient Fine-Tuning (PEFT) methods

Fine-tuning large pretrained models is often prohibitively costly due to their scale. Parameter-Efficient Fine-Tuning (PEFT) methods enable efficient adaptation of large pretrained models to various downstream applications by only fine-tuning a small number of (extra) model parameters instead of all the model's parameters. This significantly decreases the computational and storage costs. Recent state-of-the-art PEFT techniques achieve performance comparable to fully fine-tuned models.
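As a concrete instance of fine-tuning a small number of extra parameters, here is a minimal pure-Python sketch of the standard LoRA formulation (the method used in the quickstart): a frozen weight `W` is augmented by a trainable low-rank product `B A`, scaled by `alpha / r`. It assumes the textbook setup (zero-initialized `B`, scaling `alpha / r`) and is a toy for intuition, not PEFT's implementation.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    base = matvec(W, x)              # frozen pretrained path
    delta = matvec(B, matvec(A, x))  # trainable low-rank path: B @ (A @ x)
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]


d, r = 4, 1
W = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]  # d x d, frozen
A = [[0.5] * d]                # r x d, trainable
B = [[0.0] for _ in range(d)]  # d x r, trainable, zero-initialized
x = [1.0, 2.0, 3.0, 4.0]

# With B = 0, the adapted layer reproduces the base layer exactly.
print(lora_forward(W, A, B, x, alpha=2, r=r))  # [1.0, 2.0, 3.0, 4.0]

# Trainable parameters: r * (d + d) for A and B versus d * d for W.
print(r * (d + d), d * d)  # 8 16
```

For realistic layer sizes (d in the thousands, r around 8 to 16), `r * 2d` is a tiny fraction of `d * d`.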

PEFT is integrated with Transformers for easy model training and inference, Diffusers for conveniently managing different adapters, and Accelerate for distributed training and inference for really big models.

Tip

Visit the PEFT organization to read about the PEFT methods implemented in the library and to see notebooks demonstrating how to apply these methods to a variety of downstream tasks. Click the "Watch repos" button on the organization page to be notified of newly implemented methods and notebooks!

Check the PEFT Adapters API Reference section for a list of supported PEFT methods, and read the Adapters, Soft prompts, and IA3 conceptual guides to learn more about how these methods work.

Quickstart

Install PEFT from pip:

pip install peft

Prepare a model for training with a PEFT method such as LoRA by wrapping the base model and PEFT configuration with get_peft_model. For the Qwen/Qwen2.5-3B-Instruct model, you're only training about 0.12% of the parameters!

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model_id = "Qwen/Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    task_type=TaskType.CAUSAL_LM,
    # target_modules=["q_proj", "v_proj", ...]  # optionally indicate target modules
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# prints: trainable params: 3,686,400 || all params: 3,089,625,088 || trainable%: 0.1193

# now perform training on your dataset, e.g. using transformers Trainer, then save the model
model.save_pretrained("qwen2.5-3b-lora")
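The printed trainable-parameter count can be reproduced by hand. This sketch assumes that, with `target_modules` left unset, LoRA targets the `q_proj` and `v_proj` layers of this architecture, and that Qwen2.5-3B has 36 layers, a hidden size of 2048, and 2 key/value heads of dimension 128 (so `v_proj` maps 2048 to 256 features). These shapes are assumptions taken from the model's config, not from this README.

```python
# Assumed Qwen2.5-3B-Instruct shapes (from its model config, not from this README).
HIDDEN_SIZE = 2048
NUM_LAYERS = 36
NUM_KV_HEADS = 2
HEAD_DIM = 128
R = 16


def lora_param_count(d_in, d_out, r):
    """LoRA adds A (r x d_in) and B (d_out x r) to a d_in -> d_out linear layer."""
    return r * d_in + d_out * r


q_proj = lora_param_count(HIDDEN_SIZE, HIDDEN_SIZE, R)              # 65,536
v_proj = lora_param_count(HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, R)  # 36,864
trainable = NUM_LAYERS * (q_proj + v_proj)

all_params = 3_089_625_088  # "all params" from the printed output above
print(f"trainable params: {trainable:,} || trainable%: {100 * trainable / all_params:.4f}")
# trainable params: 3,686,400 || trainable%: 0.1193
```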

To load a PEFT model for inference:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model_id = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
model = PeftModel.from_pretrained(model, "qwen2.5-3b-lora")

inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")
outputs = model.generate(**inputs.to(device), max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

# prints something like: Preheat the oven to 350 degrees and place the cookie dough in a baking dish [...]

Why you should use PEFT

There are many benefits of using PEFT, but the main one is the huge savings in compute and storage, making PEFT applicable to many different use cases.

High performance on consumer hardware

Consider the memory requirements for training the following models on the ought/raft/twitter_complaints dataset with an A100 80GB GPU with more than 64GB of CPU RAM.

| Model | Full Finetuning | PEFT-LoRA PyTorch | PEFT-LoRA DeepSpeed with CPU Offloading |
| --- | --- | --- | --- |
| bigscience/T0_3B (3B params) | 47.14GB GPU / 2.96GB CPU | 14.4GB GPU / 2.96GB CPU | 9.8GB GPU / 17.8GB CPU |
| bigscience/mt0-xxl (12B params) | OOM GPU | 56GB GPU / 3GB CPU | 22GB GPU / 52GB CPU |
| bigscience/bloomz-7b1 (7B params) | OOM GPU | 32GB GPU / 3.8GB CPU | 18.1GB GPU / 35GB CPU |

With LoRA you can finetune a 12B parameter model that would otherwise run out of memory on the 80GB GPU, and comfortably fit and train a 3B parameter model. The 3B parameter model's performance is comparable to that of a fully finetuned model at a fraction of the GPU memory.

| Submission Name | Accuracy |
| --- | --- |
| Human baseline (crowdsourced) | 0.897 |
| Flan-T5 | 0.892 |
| lora-t0-3b | 0.863 |

Tip

The bigscience/T0_3B model performance isn't optimized in the table above. You can squeeze even more performance out of it by playing around with the input instruction templates, LoRA hyperparameters, and other training-related hyperparameters. The final checkpoint size of this model is just 19MB compared to 11GB for the full bigscience/T0_3B model. Learn more about the advantages of finetuning with PEFT in this blog post.

Quantization

Quantization is another method for reducing the memory requirements of a model by representing the data in a lower precision. It can be combined with PEFT methods to make it even easier to train and load LLMs for inference.
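To make representing data in a lower precision concrete, here is a minimal sketch of symmetric absmax quantization to 8-bit integers in pure Python. This is a generic textbook scheme chosen for illustration; the quantization backends used with PEFT in practice have their own, more elaborate formats.

```python
def quantize_absmax(values, bits=8):
    """Symmetric absmax quantization to signed integers of the given width."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    max_abs = max(abs(v) for v in values)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    return [round(v / scale) for v in values], scale


def dequantize(quantized, scale):
    return [q * scale for q in quantized]


weights = [0.5, -1.27, 0.01, 1.0]
quantized, scale = quantize_absmax(weights)
restored = dequantize(quantized, scale)
print(quantized)  # [50, -127, 1, 100]
# Rounding error per value is at most half a quantization step.
print(max(abs(w - q) for w, q in zip(weights, restored)) <= scale / 2)
```

Each weight now needs 1 byte plus one shared float scale instead of 4 bytes in float32, at the cost of a small rounding error.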

Save compute and storage

PEFT can help you save storage by avoiding full finetuning of models on each downstream task or dataset. In many cases, you're only finetuning a very small fraction of a model's parameters and each checkpoint is only a few MBs in size (instead of GBs). These smaller PEFT adapters demonstrate performance comparable to a fully finetuned model. If you have many datasets, you can save a lot of storage with a PEFT model and not have to worry about catastrophic forgetting or overfitting the backbone or base model.

PEFT integrations

PEFT is widely supported across the Hugging Face ecosystem because of the massive efficiency it brings to training and inference.

Diffusers

The iterative diffusion process consumes a lot of memory which can make it difficult to train. PEFT can help reduce the memory requirements and reduce the storage size of the final model checkpoint. For example, consider the memory required for training a Stable Diffusion model with LoRA on an A100 80GB GPU with more than 64GB of CPU RAM. The final model checkpoint size is only 8.8MB!

| Model | Full Finetuning | PEFT-LoRA | PEFT-LoRA with Gradient Checkpointing |
| --- | --- | --- | --- |
| CompVis/stable-diffusion-v1-4 | 27.5GB GPU / 3.97GB CPU | 15.5GB GPU / 3.84GB CPU | 8.12GB GPU / 3.77GB CPU |

Tip

Take a look at the examples/lora_dreambooth/train_dreambooth.py training script to try training your own Stable Diffusion model with LoRA, and play around with the smangrul/peft-lora-sd-dreambooth Space which is running on a T4 instance. Learn more about the PEFT integration in Diffusers in this tutorial.

Transformers

PEFT is directly integrated with Transformers. After loading a model, call add_adapter to add a new PEFT adapter to the model:

from peft import LoraConfig
model = ...  # transformers model
peft_config = LoraConfig(...)
model.add_adapter(peft_config, adapter_name="lora_1")

To load a trained PEFT adapter, call load_adapter:

model = ...  # transformers model
model.load_adapter(<path-to-adapter>, adapter_name="lora_1")

And to switch between different adapters, call set_adapter:

model.set_adapter("lora_2")

The Transformers integration doesn't include all the functionalities offered in PEFT, such as methods for merging the adapter into the base model.

Accelerate

Accelerate is a library for distributed training and inference on various training setups and hardware (GPUs, TPUs, Apple Silicon, etc.). PEFT models work with Accelerate out of the box, making it convenient to train very large models or use them for inference on consumer hardware with limited resources.

TRL

PEFT can also be applied to training LLMs with RLHF components such as the ranker and policy. Get started by reading:

Model support

Use this Space or check out the docs to find which models officially support a PEFT method out of the box. Even if you don't see a model listed there, you can manually configure the model config to enable PEFT for a model. Read the New transformers architecture guide to learn how.

Contribute

If you would like to contribute to PEFT, please check out our contribution guide.

Citing 🤗 PEFT

To use 🤗 PEFT in your publication, please cite it by using the following BibTeX entry.

@Misc{peft,
  title =        {{PEFT}: State-of-the-art Parameter-Efficient Fine-Tuning methods},
  author =       {Sourab Mangrulkar and Sylvain Gugger and Lysandre Debut and Younes Belkada and Sayak Paul and Benjamin Bossan},
  howpublished = {\url{https://github.com/huggingface/peft}},
  year =         {2022}
}
Description
🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning.
Readme Apache-2.0 214 MiB
Languages
Python 99.5%
Dockerfile 0.2%
Makefile 0.2%