DOC: Explain how to use multiple adapters at the same time (#2763)

Explain how to use multiple adapters (e.g. 2 LoRA adapters) at the same
time, as the API is not quite intuitive and there are some footguns
around trainable parameters.

This question has come up multiple times in the past (for recent
examples, check #2749 and #2756). Thus it's a good idea to properly
document this.

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Benjamin Bossan
2025-09-25 17:58:57 +02:00
committed by GitHub
parent 530d7bbf1e
commit 7b2a5b1f02

@@ -401,3 +401,67 @@ If it is not possible for you to upgrade PEFT, there is a workaround you can try
Assume the error message says that the unknown keyword argument is named `foobar`. Search inside the `adapter_config.json` of this PEFT adapter for the `foobar` entry and delete it from the file. Then save the file and try loading the model again.
This solution works most of the time. As long as `foobar` was set to its default value, it can safely be removed. However, if it was set to some other value, you will get incorrect results. Upgrading PEFT is the recommended solution.
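As a rough sketch of what this workaround looks like in Python (the directory `path/to/adapter` is only a placeholder for wherever your adapter files are stored):

```python
import json

# placeholder path; point this at the directory containing the PEFT adapter
config_path = "path/to/adapter/adapter_config.json"

with open(config_path) as f:
    config = json.load(f)

# drop the keyword argument that your PEFT version does not recognize
config.pop("foobar", None)

with open(config_path, "w") as f:
    json.dump(config, f, indent=2)
```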

## Adapter handling

### Using multiple adapters at the same time

PEFT allows you to create more than one adapter on the same model. This can be useful in many situations. For example, for inference, you may want to serve two fine-tuned models from the same base model instead of loading the base model once for each fine-tuned model, which would cost more memory. Moreover, multiple adapters can be activated at the same time. This way, the model can leverage what all of those adapters have learned at once. As an example, if you have a diffusion model, you may want to use one LoRA adapter to change the style and a different one to change the subject.

Activating multiple adapters at the same time is generally possible with all PEFT methods (LoRA, LoHa, IA³, etc.) except for prompt learning methods (p-tuning, prefix tuning, etc.). The following example illustrates how to achieve this:

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

model_id = ...
base_model = AutoModelForCausalLM.from_pretrained(model_id)

# the first adapter is loaded under the default adapter name, 'default'
model = PeftModel.from_pretrained(base_model, lora_path_0)
model.load_adapter(lora_path_1, adapter_name="other")

# the 'other' adapter was loaded but is not active yet, so activate both adapters:
model.base_model.set_adapter(["default", "other"])
```
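With both adapters active, inference works the same as with a single adapter. Continuing the example above, here is a small usage sketch; the tokenizer, prompt, and generation settings are made up for illustration:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, world", return_tensors="pt")

# both the 'default' and the 'other' adapter contribute to this forward pass
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```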
> [!TIP]
> In the example above, you can see that we need to call `model.base_model.set_adapter(["default", "other"])`. Why can we not call `model.set_adapter(["default", "other"])`? This is unfortunately not possible because, as explained earlier, some PEFT methods don't support activating more than one adapter at a time.
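Calling `set_adapter` with a single adapter name still works on the `PeftModel` itself, so you can switch back to using only one adapter at any time. A short sketch:

```python
# activate only the 'default' adapter again
model.set_adapter("default")

# ... run inference with a single adapter ...

# re-activate both adapters
model.base_model.set_adapter(["default", "other"])
```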
It is also possible to train two adapters at the same time, but you should be careful to ensure that the weights of both adapters are known to the optimizer. Otherwise, only one adapter will receive updates.
```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_id = ...
base_model = AutoModelForCausalLM.from_pretrained(model_id)

lora_config_0 = LoraConfig(...)
lora_config_1 = LoraConfig(...)

# the first adapter is created under the default adapter name, 'default'
model = get_peft_model(base_model, lora_config_0)
model.add_adapter(adapter_name="other", peft_config=lora_config_1)
```
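At this point, only the `"default"` adapter is active, so only its parameters are marked as trainable. A quick check, relying on the fact that LoRA parameter names contain `lora_` (a sketch, not part of the original example):

```python
# print the trainable status of all LoRA parameters;
# the parameters of the inactive 'other' adapter show requires_grad=False
for name, param in model.named_parameters():
    if "lora_" in name:
        print(name, param.requires_grad)
```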
If we now call:
```python
from transformers import Trainer
trainer = Trainer(model=model, ...)
trainer.train()
```
or
```python
import torch

optimizer = torch.optim.AdamW([param for param in model.parameters() if param.requires_grad], ...)
```
then the second LoRA adapter (`"other"`) would not be trained. This is because it is inactive at this moment, which means the `requires_grad` attribute on its parameters is set to `False` and the optimizer will ignore it. Therefore, make sure to activate all adapters that should be trained _before_ initializing the optimizer:
```python
# activate all adapters
model.base_model.set_adapter(["default", "other"])
trainer = Trainer(model=model, ...)
trainer.train()
```
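The same applies when you build the optimizer yourself instead of relying on `Trainer`: activate both adapters first, then collect the trainable parameters. A minimal sketch (the learning rate is just an example value):

```python
import torch

# activate all adapters, then hand their now-trainable parameters to the optimizer
model.base_model.set_adapter(["default", "other"])
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad], lr=1e-4
)
```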
> [!TIP]
> This section deals with using multiple adapters _of the same type_ on the same model, for example, using multiple LoRA adapters at the same time. It does not apply to using _different types_ of adapters on the same model, for example, one LoRA adapter and one LoHa adapter. For this, please check [`PeftMixedModel`](https://huggingface.co/docs/peft/developer_guides/mixed_models).