Description

At the moment, we strongly couple the active adapter with requires_grad=True. Concretely, when we call model.set_adapter(name), we assume that this adapter should not only be made active, its requires_grad should also be set to True. For the purpose of training PEFT models, this is fair. However, when loading PEFT models for inference, this is not desired: for inference we don't need requires_grad=True, but as is, it is enabled.

Generally, this is not a severe bug, since the inference code doesn't perform any updates, so we don't inadvertently update a weight just because it wrongly has requires_grad=True -- this is probably why it went unnoticed so far. However, it can lead to worse runtime performance and memory overhead when PyTorch records grads for those parameters (which it shouldn't if called under torch.inference_mode, but some users may forget to use it). Therefore, this bug is still worth fixing.

Example

### With `modules_to_save`

A very basic example where the current PEFT fails:

```python
import os

from transformers import AutoModelForCausalLM

from peft import LoraConfig, PeftModel, get_peft_model

model_id = "facebook/opt-125m"
path = "/tmp/peft/2759"

if not os.path.exists(path + "/adapter_model.safetensors"):
    model = AutoModelForCausalLM.from_pretrained(model_id)
    config = LoraConfig(target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"], r=8)
    model = get_peft_model(model, config)
    model.save_pretrained(path)
    del model

model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, path)
```

`modules_to_save` should not have grads enabled, but currently it does.

### With multiple adapters

There is also an issue when loading more than one adapter:

```python
model = PeftModel.from_pretrained(...)
assert not any(p.requires_grad for p in model.parameters())  # works
```

So far, so good: the first adapter does not have `requires_grad`.

```python
model.load_adapter(...)
assert not any(p.requires_grad for p in model.parameters())  # fails
```

The load_adapter call inadvertently sets requires_grad=True for the weights of the _first_ adapter. This happens because, when the second adapter is loaded, we call set_adapter with the first adapter to ensure that it remains the active adapter. Due to the coupling of active adapter and requires_grad, this results in requires_grad=True being set for the first adapter.

This PR relaxes the coupling by allowing set_adapter to be called with an additional argument, inference_mode. If set to True, requires_grad is not enabled, even if the adapter is activated (a short sketch of the new call is included below).

The example above would also fail for modules_to_save and trainable tokens, not only for the LoRA/LoHa/... weights.

Still open bugs

The proposed solution is unfortunately not perfect. Right now, we pass inference_mode based on the PEFT config of the adapter being added, which helps with the original issue described above. However, even this is not entirely correct, because the inference_mode of the second adapter does not necessarily have the same value as the inference_mode of the first adapter. To illustrate how this can go wrong, I added an xfailing test: test_loading_model_requires_grad_set_correctly_switch_inference_mode. I believe this use case is rarer than the ones described at the beginning, so IMO it is okay to keep this bug since we fix more common ones. However, LMK if you disagree.
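For reference, here is a minimal sketch of the relaxed coupling described above. The adapter name and the exact call site are assumptions on my part; the inference_mode argument is the one this PR adds:

```python
# Activate an adapter for inference without re-enabling gradients.
# "default" is just a placeholder adapter name.
model.set_adapter("default", inference_mode=True)

# The adapter is used in the forward pass, but its parameters keep requires_grad=False.
assert not any(p.requires_grad for p in model.parameters())
```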
Related to this, I noticed that many tests in test_custom_models.TestRequiresGrad had code like this:

```python
config0 = FooConfig(...)
peft_model = get_peft_model(MLP(), config0)
config1 = FooConfig(..., inference_mode=True)  # <==
peft_model.add_adapter("adapter1", config1)
```

This now fails for the reason just given. I removed inference_mode=True here and the tests pass again. Note that the only reason inference_mode=True was passed here is that AdaLoRA cannot load two adapters in training mode and thus requires it. Later PEFT methods without this restriction blindly copied the AdaLoRA test, so for those methods I removed inference_mode=True. However, this also means that the AdaLoRA tests now fail; I thus marked them as xfail.

To properly fix this bug, I think we would have to refactor the code to separate set_adapter (i.e. determining the active adapter) from setting requires_grad, as the two are orthogonal. Moreover, these attributes are set all over the place, which makes it hard to reason about where they change; this should be streamlined. Making these changes without breaking existing code is not trivial (maybe even impossible), so I went the easier way for the time being with this PR. A bigger refactor could be envisioned for a version 1.0 release of PEFT.

Related changes

While working on this, I noticed that LNTuning was completely buggy when calling set_adapter; this is now fixed. Moreover, since I had to touch update_layer everywhere, I ensured that they all take kwargs for consistency.
🤗 PEFT
State-of-the-art Parameter-Efficient Fine-Tuning (PEFT) methods
Fine-tuning large pretrained models is often prohibitively costly due to their scale. Parameter-Efficient Fine-Tuning (PEFT) methods enable efficient adaptation of large pretrained models to various downstream applications by only fine-tuning a small number of (extra) model parameters instead of all the model's parameters. This significantly decreases the computational and storage costs. Recent state-of-the-art PEFT techniques achieve performance comparable to fully fine-tuned models.
PEFT is integrated with Transformers for easy model training and inference, Diffusers for conveniently managing different adapters, and Accelerate for distributed training and inference for really big models.
Tip
Visit the PEFT organization to read about the PEFT methods implemented in the library and to see notebooks demonstrating how to apply these methods to a variety of downstream tasks. Click the "Watch repos" button on the organization page to be notified of newly implemented methods and notebooks!
Check the PEFT Adapters API Reference section for a list of supported PEFT methods, and read the Adapters, Soft prompts, and IA3 conceptual guides to learn more about how these methods work.
Quickstart
Install PEFT from pip:
pip install peft
Prepare a model for training with a PEFT method such as LoRA by wrapping the base model and PEFT configuration with get_peft_model. For the Qwen/Qwen2.5-3B-Instruct model below, you're only training about 0.12% of the parameters!
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model_id = "Qwen/Qwen2.5-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
peft_config = LoraConfig(
r=16,
lora_alpha=32,
task_type=TaskType.CAUSAL_LM,
# target_modules=["q_proj", "v_proj", ...] # optionally indicate target modules
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# prints: trainable params: 3,686,400 || all params: 3,089,625,088 || trainable%: 0.1193
# now perform training on your dataset, e.g. using transformers Trainer, then save the model
model.save_pretrained("qwen2.5-3b-lora")
To load a PEFT model for inference:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
model_id = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
model = PeftModel.from_pretrained(model, "qwen2.5-3b-lora")
inputs = tokenizer("Preheat the oven to 350 degrees and place the cookie dough", return_tensors="pt")
outputs = model.generate(**inputs.to(device), max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# prints something like: Preheat the oven to 350 degrees and place the cookie dough in a baking dish [...]
Why you should use PEFT
There are many benefits of using PEFT but the main one is the huge savings in compute and storage, making PEFT applicable to many different use cases.
High performance on consumer hardware
Consider the memory requirements for training the following models on the ought/raft/twitter_complaints dataset with an A100 80GB GPU with more than 64GB of CPU RAM.
Model | Full Finetuning | PEFT-LoRA PyTorch | PEFT-LoRA DeepSpeed with CPU Offloading |
---|---|---|---|
bigscience/T0_3B (3B params) | 47.14GB GPU / 2.96GB CPU | 14.4GB GPU / 2.96GB CPU | 9.8GB GPU / 17.8GB CPU |
bigscience/mt0-xxl (12B params) | OOM GPU | 56GB GPU / 3GB CPU | 22GB GPU / 52GB CPU |
bigscience/bloomz-7b1 (7B params) | OOM GPU | 32GB GPU / 3.8GB CPU | 18.1GB GPU / 35GB CPU |
With LoRA you can finetune a 12B parameter model that would otherwise have run out of memory on the 80GB GPU, and comfortably fit and train a 3B parameter model. The LoRA-finetuned 3B model's performance is comparable to a fully finetuned model at a fraction of the GPU memory.
Submission Name | Accuracy |
---|---|
Human baseline (crowdsourced) | 0.897 |
Flan-T5 | 0.892 |
lora-t0-3b | 0.863 |
Tip
The bigscience/T0_3B model performance isn't optimized in the table above. You can squeeze even more performance out of it by playing around with the input instruction templates, LoRA hyperparameters, and other training-related hyperparameters. The final checkpoint of this model is just 19MB, compared to 11GB for the full bigscience/T0_3B model. Learn more about the advantages of finetuning with PEFT in this blog post.
Quantization
Quantization is another method for reducing the memory requirements of a model by representing the data in a lower precision. It can be combined with PEFT methods to make it even easier to train and load LLMs for inference.
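As a minimal sketch of combining 4-bit quantization with LoRA (QLoRA-style), assuming the bitsandbytes package is installed; the model id and hyperparameters below are placeholders rather than recommendations:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

# Load the base model in 4-bit precision.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", quantization_config=bnb_config)

# Prepare the quantized model for training, then add a LoRA adapter on top.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, LoraConfig(task_type=TaskType.CAUSAL_LM, r=16, lora_alpha=32))
model.print_trainable_parameters()
```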
- Learn how to finetune meta-llama/Llama-2-7b-hf with QLoRA and the TRL library on a 16GB GPU in the Finetune LLMs on your own consumer hardware using tools from PyTorch and Hugging Face ecosystem blog post.
- Learn how to finetune a openai/whisper-large-v2 model for multilingual automatic speech recognition with LoRA and 8-bit quantization in this notebook (see this notebook instead for an example of streaming a dataset).
Save compute and storage
PEFT can help you save storage by avoiding full finetuning of models on each downstream task or dataset. In many cases, you're only finetuning a very small fraction of a model's parameters and each checkpoint is only a few MBs in size (instead of GBs). These smaller PEFT adapters demonstrate performance comparable to a fully finetuned model. If you have many datasets, you can save a lot of storage with a PEFT model and not have to worry about catastrophic forgetting or overfitting the backbone or base model.
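For instance, here is a minimal sketch of reusing one base model with several small task adapters; the adapter paths and names are placeholders:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# The full-size base model is stored once; each adapter checkpoint is only a few MB.
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = PeftModel.from_pretrained(base_model, "path/to/summarization-adapter", adapter_name="summarization")
model.load_adapter("path/to/translation-adapter", adapter_name="translation")

model.set_adapter("translation")  # switch tasks without reloading the base weights
```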
PEFT integrations
PEFT is widely supported across the Hugging Face ecosystem because of the massive efficiency it brings to training and inference.
Diffusers
The iterative diffusion process consumes a lot of memory, which can make training difficult. PEFT can help reduce the memory requirements and reduce the storage size of the final model checkpoint. For example, consider the memory required for training a Stable Diffusion model with LoRA on an A100 80GB GPU with more than 64GB of CPU RAM. The final model checkpoint size is only 8.8MB!
Model | Full Finetuning | PEFT-LoRA | PEFT-LoRA with Gradient Checkpointing |
---|---|---|---|
CompVis/stable-diffusion-v1-4 | 27.5GB GPU / 3.97GB CPU | 15.5GB GPU / 3.84GB CPU | 8.12GB GPU / 3.77GB CPU |
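As a rough sketch of what such a LoRA setup could look like, the target module names below are an assumption based on common Stable Diffusion attention layer names, not a prescribed recipe:

```python
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Wrap only the UNet with LoRA; the text encoder and VAE stay frozen.
unet_lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
pipe.unet = get_peft_model(pipe.unet, unet_lora_config)
pipe.unet.print_trainable_parameters()
```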
Tip
Take a look at the examples/lora_dreambooth/train_dreambooth.py training script to try training your own Stable Diffusion model with LoRA, and play around with the smangrul/peft-lora-sd-dreambooth Space which is running on a T4 instance. Learn more about the PEFT integration in Diffusers in this tutorial.
Transformers
PEFT is directly integrated with Transformers. After loading a model, call add_adapter
to add a new PEFT adapter to the model:
from peft import LoraConfig
model = ... # transformers model
peft_config = LoraConfig(...)
model.add_adapter(peft_config, adapter_name="lora_1")
To load a trained PEFT adapter, call load_adapter
:
model = ... # transformers model
model.load_adapter(<path-to-adapter>, adapter_name="lora_1")
And to switch between different adapters, call set_adapter
:
model.set_adapter("lora_2")
The Transformers integration doesn't include all the functionalities offered in PEFT, such as methods for merging the adapter into the base model.
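For example, merging is done through PEFT itself rather than the Transformers integration; here is a minimal sketch, with the checkpoint path being a placeholder:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
model = PeftModel.from_pretrained(base_model, "qwen2.5-3b-lora")

# Fold the LoRA weights into the base weights and drop the adapter layers.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("qwen2.5-3b-lora-merged")
```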
Accelerate
Accelerate is a library for distributed training and inference on various training setups and hardware (GPUs, TPUs, Apple Silicon, etc.). PEFT models work with Accelerate out of the box, making it convenient to train very large models or run inference with them on consumer hardware with limited resources.
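For instance, Accelerate's big-model utilities are used under the hood when you load with device_map="auto", so a large base model plus adapter can be dispatched across the available devices. A minimal sketch (the adapter path is a placeholder):

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# device_map="auto" relies on Accelerate to place the model across GPUs/CPU.
base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B-Instruct", device_map="auto")
model = PeftModel.from_pretrained(base_model, "qwen2.5-3b-lora")
```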
TRL
PEFT can also be applied to training LLMs with RLHF components such as the ranker and policy. Get started by reading the resources below; a minimal sketch of combining PEFT with TRL follows the list:
- Fine-tune a Mistral-7b model with Direct Preference Optimization with PEFT and the TRL library to learn more about the Direct Preference Optimization (DPO) method and how to apply it to a LLM.
- Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU with PEFT and the TRL library, and then try out the gpt2-sentiment_peft.ipynb notebook to optimize GPT2 to generate positive movie reviews.
- StackLLaMA: A hands-on guide to train LLaMA with RLHF with PEFT, and then try out the stack_llama/scripts for supervised finetuning, reward modeling, and RL finetuning.
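As the sketch referenced above, here is a minimal example of passing a PEFT config to TRL's SFTTrainer; the dataset is a placeholder and the exact trainer signature can vary between TRL versions:

```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig

dataset = load_dataset("trl-lib/Capybara", split="train")  # placeholder dataset

# TRL applies the LoRA adapter to the model for you when peft_config is given.
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-3B-Instruct",
    train_dataset=dataset,
    peft_config=LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32),
)
trainer.train()
```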
Model support
Use this Space or check out the docs to find which models officially support a PEFT method out of the box. Even if you don't see a model listed there, you can manually configure the model config to enable PEFT for it. Read the New transformers architecture guide to learn how.
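For architectures without a built-in mapping, here is a minimal sketch of configuring PEFT manually; the model id and module names are hypothetical and need to be adapted to the actual architecture:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("new-org/new-architecture")  # hypothetical model id

# Inspect the module names to find the attention projections to adapt.
print([name for name, _ in model.named_modules()][:20])

config = LoraConfig(target_modules=["query_proj", "value_proj"])  # hypothetical module names
model = get_peft_model(model, config)
model.print_trainable_parameters()
```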
Contribute
If you would like to contribute to PEFT, please check out our contribution guide.
Citing 🤗 PEFT
To use 🤗 PEFT in your publication, please cite it by using the following BibTeX entry.
@Misc{peft,
title = {{PEFT}: State-of-the-art Parameter-Efficient Fine-Tuning methods},
author = {Sourab Mangrulkar and Sylvain Gugger and Lysandre Debut and Younes Belkada and Sayak Paul and Benjamin Bossan},
howpublished = {\url{https://github.com/huggingface/peft}},
year = {2022}
}