FIX: Wrong coupling between requires_grad and the active adapter (#2765)

## Description

At the moment, we strongly couple the active adapter with
`requires_grad=True`. Concretely, when we call `model.set_adapter(name)`, we
assume that this adapter should not only be made active but that its
parameters should also have `requires_grad` set to `True`.

For the purpose of training PEFT models, this is fair. However, when
loading PEFT models for inference, it is not desired. Generally, for
inference we don't need `requires_grad=True`, but as things stand, it is
enabled anyway.

Generally, this is not a severe bug: the inference code does not perform
any updates, so we never inadvertently update a weight just because it
wrongly has `requires_grad=True`; this is probably why it went unnoticed so
far. However, it can lead to worse runtime performance and higher memory
usage when PyTorch records gradients for those parameters (which it
shouldn't when inference runs under `torch.inference_mode`, but some users
may forget to use it). Therefore, the bug is still worth fixing.
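
For illustration, here is a generic PyTorch sketch (unrelated to PEFT internals, just the mechanism): with `requires_grad=True` parameters, a plain forward pass records autograd state unless it is wrapped in `torch.inference_mode()` or `torch.no_grad()`:

```python
import torch

lin = torch.nn.Linear(10, 10)  # parameters have requires_grad=True by default
x = torch.randn(1, 10)

y = lin(x)
print(y.requires_grad)  # True: autograd bookkeeping is kept around

with torch.inference_mode():
    y = lin(x)
print(y.requires_grad)  # False: no graph is recorded
```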

## Example

### With `modules_to_save`

A very basic example where current PEFT gets this wrong:

```python
import os
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

model_id = "facebook/opt-125m"
path = "/tmp/peft/2759"
if not os.path.exists(path + "/adapter_model.safetensors"):
    # create and save a LoRA adapter that also has a modules_to_save entry
    model = AutoModelForCausalLM.from_pretrained(model_id)
    config = LoraConfig(target_modules=["q_proj", "v_proj"], modules_to_save=["lm_head"], r=8)
    model = get_peft_model(model, config)
    model.save_pretrained(path)
    del model

# load the adapter again, purely for inference
model = AutoModelForCausalLM.from_pretrained(model_id)
model = PeftModel.from_pretrained(model, path)
```

The `modules_to_save` parameters should not have grads enabled, but currently they do.
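
A minimal check, as a sketch (the parameter names follow PEFT's usual `modules_to_save` wrapper naming):

```python
# For pure inference, no parameter should require grads. With the bug, the copied
# lm_head inside the ModulesToSaveWrapper still has requires_grad=True.
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # expected: [], buggy: contains 'modules_to_save' parameter names
assert not trainable, f"unexpected trainable params: {trainable}"
```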

### With multiple adapters

There is also an issue when loading more than one adapter:

```python
model = PeftModel.from_pretrained(...)
assert not any(p.requires_grad for p in model.parameters())  # works
```

So far, so good: the first adapter does not have `requires_grad` enabled.

```python
model.load_adapter(...)
assert not any(p.requires_grad for p in model.parameters())  # fails
```

The `load_adapter` call inadvertently sets `requires_grad=True` for the
weights of the _first_ adapter. This happens because, when the second
adapter is loaded, we call `set_adapter` with the first adapter to ensure
that it remains the active adapter. Due to the coupling of active adapter
and `requires_grad`, this results in `requires_grad=True` being set for the
first adapter.

This PR relaxes the coupling by allowing `set_adapter` to be called with an
additional argument, `inference_mode`. If it is set to `True`,
`requires_grad` will not be enabled, even though the adapter is activated.
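
A rough sketch of the relaxed coupling (the keyword is added to the tuner model's `set_adapter`, as in the diff below; the adapter name `"default"` is just an example):

```python
# Activate the adapter without flipping its parameters back to trainable.
model.base_model.set_adapter("default", inference_mode=True)
assert not any(p.requires_grad for p in model.parameters())
```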

The example above would also fail for `modules_to_save` and trainable
tokens, not only for the LoRA/LoHa/... weights.

## Still open bugs

The proposed solution is unfortunately not perfect. Right now, we pass
`inference_mode` based on the PEFT config of the adapter being added, which
resolves the issue described above. However, even this is not fully
correct, because the `inference_mode` of the second adapter does not
necessarily have the same value as the `inference_mode` of the first
adapter. To illustrate how this can go wrong, I added an xfailing test:

`test_loading_model_requires_grad_set_correctly_switch_inference_mode`
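
Roughly, the remaining failure mode looks like this (a hypothetical sketch using a small custom model, not the actual test code):

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class MLP(nn.Module):  # stand-in base model, similar to PEFT's custom-model tests
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(8, 8)

    def forward(self, x):
        return self.lin0(x)

config0 = LoraConfig(target_modules=["lin0"])  # adapter meant for further training
peft_model = get_peft_model(MLP(), config0)

config1 = LoraConfig(target_modules=["lin0"], inference_mode=True)  # inference only
peft_model.add_adapter("other", config1)

# Adding the second adapter re-activates the first one, but with the second
# adapter's inference_mode=True, so the first adapter's weights end up frozen.
assert any(p.requires_grad for p in peft_model.parameters())  # currently xfails
```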

I believe that this use case is rarer than the ones described at the
beginning, so IMO it is okay to keep this bug while fixing the more common
ones. However, LMK if you disagree.

Related to this, I noticed that many tests in
`test_custom_models.TestRequiresGrad` had code like this:

```python
config0 = FooConfig(...)
peft_model = get_peft_model(MLP(), config0)
config1 = FooConfig(..., inference_mode=True)  # <==
peft_model.add_adapter("adapter1", config1)
```

This pattern now fails for the reason just given. I removed
`inference_mode=True` here and the tests pass again.

Note that the only reason `inference_mode=True` was passed here is that
AdaLoRA cannot load two adapters in training mode and thus requires it.
Later PEFT methods without this restriction blindly copied the AdaLoRA
test. For those PEFT methods, I removed `inference_mode=True`.

However, this also means that the AdaLoRA tests now fail. I thus marked
them as xfail.

To properly fix this bug, I think we would have to refactor the code to
isolate `set_adapter` (i.e. determining the active adapter) from setting
`requires_grad`, as the two are orthogonal concerns that currently share a
code path. Moreover, these attributes are set all over the place, which
makes it hard to reason about where they are changed. This should be
streamlined.
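
One possible shape for such a refactor (purely an illustrative sketch of the idea, not something this PR implements) would be to separate the two concerns into distinct methods:

```python
from __future__ import annotations

# Hypothetical API: activation and trainability are handled independently.
class SomeTunerModel:
    def set_adapter(self, adapter_name: str | list[str]) -> None:
        """Only determine which adapter(s) are used in the forward pass."""
        ...

    def set_requires_grad(self, adapter_name: str | list[str], requires_grad: bool = True) -> None:
        """Only toggle requires_grad on the given adapter(s)' parameters."""
        ...
```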

Making these changes without breaking any existing code is not trivial (or
maybe even impossible). Therefore, for the time being, this PR takes the
easier route. A bigger refactor could perhaps be envisioned for a version
1.0 release of PEFT.

## Related changes

While working on this, I noticed that LNTuning was completely buggy when
calling `set_adapter`. This is now fixed.

Moreover, since I had to touch `update_layer` everywhere, I ensured that
all implementations take `**kwargs` for consistency.
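
The resulting signatures follow a common pattern, roughly (taken from the IA³ layer in the diff below):

```python
# All update_layer methods now accept inference_mode and swallow extra kwargs,
# then forward inference_mode when re-activating the current adapters.
def update_layer(self, adapter_name, init_ia3_weights, inference_mode: bool = False, **kwargs):
    ...
    self._move_adapter_to_device_of_base_layer(adapter_name)
    self.set_adapter(self.active_adapters, inference_mode=inference_mode)
```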
Benjamin Bossan
2025-09-08 19:49:29 +02:00
committed by GitHub
parent 42db980676
commit 13fa0aea7e
44 changed files with 456 additions and 239 deletions

View File

@ -251,9 +251,14 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module):
self.modules_to_save = set(modules_to_save) self.modules_to_save = set(modules_to_save)
else: else:
self.modules_to_save.update(modules_to_save) self.modules_to_save.update(modules_to_save)
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None)) _set_trainable(
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
inference_mode=peft_config.inference_mode,
)
def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: def set_adapter(self, adapter_name: Union[str, list[str]], inference_mode: bool = False) -> None:
""" """
Sets the active adapter(s) for the model. Sets the active adapter(s) for the model.
@ -262,18 +267,14 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module):
order in which the adapters were loaded into the model. The active adapters only determine which adapters are order in which the adapters were loaded into the model. The active adapters only determine which adapters are
active during the forward pass, but not the order in which they are applied. active during the forward pass, but not the order in which they are applied.
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True) unless
not desired, use the following code. inference_mode is True.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args: Args:
adapter_name (`str` or `List[str]`): adapter_name (str, list[str]):
The name of the adapter(s) to be activated. The name(s) of the adapter(s) to set as active
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
if isinstance(adapter_name, str): if isinstance(adapter_name, str):
adapter_name = [adapter_name] adapter_name = [adapter_name]
@ -284,8 +285,8 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module):
f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}"
) )
self.base_model.set_adapter(adapter_name) self.base_model.set_adapter(adapter_name, inference_mode=inference_mode)
_set_adapter(self, adapter_name) _set_adapter(self, adapter_name, inference_mode=inference_mode)
def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None: def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None:
if isinstance(adapter_name, str): if isinstance(adapter_name, str):

View File

@ -1624,7 +1624,12 @@ class PeftModelForSequenceClassification(PeftModel):
break break
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None)) _set_trainable(
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
inference_mode=peft_config.inference_mode,
)
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None: def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
""" """
@ -2475,7 +2480,12 @@ class PeftModelForTokenClassification(PeftModel):
break break
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None)) _set_trainable(
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
inference_mode=peft_config.inference_mode,
)
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None: def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
""" """
@ -2691,7 +2701,12 @@ class PeftModelForQuestionAnswering(PeftModel):
break break
# to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper
_set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None)) _set_trainable(
self,
adapter_name,
module_names=getattr(peft_config, "modules_to_save", None),
inference_mode=peft_config.inference_mode,
)
def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None: def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None:
""" """

View File

@ -45,7 +45,9 @@ class AdaLoraLayer(LoraLayer):
self.lora_B = nn.ParameterDict({}) self.lora_B = nn.ParameterDict({})
self.ranknum = nn.ParameterDict({}) self.ranknum = nn.ParameterDict({})
def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, inference_mode: bool = False, **kwargs
):
if r < 0: if r < 0:
# note: r == 0 is allowed for AdaLora, see #1539 # note: r == 0 is allowed for AdaLora, see #1539
raise ValueError(f"`r` should be a positive integer or 0, but the value passed is {r}") raise ValueError(f"`r` should be a positive integer or 0, but the value passed is {r}")
@ -74,7 +76,7 @@ class AdaLoraLayer(LoraLayer):
self.reset_lora_parameters(adapter_name) self.reset_lora_parameters(adapter_name)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_lora_parameters(self, adapter_name): def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys(): if adapter_name in self.lora_A.keys():

View File

@ -261,7 +261,15 @@ class BOFTLayer(BaseTunerLayer):
warnings.warn("Unscaling operation for BOFT not supported! Keeping scale to 1.") warnings.warn("Unscaling operation for BOFT not supported! Keeping scale to 1.")
def update_layer( def update_layer(
self, adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights self,
adapter_name,
boft_block_size,
boft_block_num,
boft_n_butterfly_factor,
boft_dropout,
init_weights,
inference_mode: bool = False,
**kwargs,
): ):
""" """
Update the linear layer with trainable BOFT weights. Override for other layer types. Update the linear layer with trainable BOFT weights. Override for other layer types.
@ -360,7 +368,7 @@ class BOFTLayer(BaseTunerLayer):
self.boft_block_num[adapter_name] = boft_block_num self.boft_block_num[adapter_name] = boft_block_num
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_boft_parameters(self, adapter_name, init_weights): def reset_boft_parameters(self, adapter_name, init_weights):
""" """
@ -682,7 +690,15 @@ class Conv2d(nn.Module, BOFTLayer):
) )
def update_layer( def update_layer(
self, adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights self,
adapter_name,
boft_block_size,
boft_block_num,
boft_n_butterfly_factor,
boft_dropout,
init_weights,
inference_mode: bool = False,
**kwargs,
): ):
""" """
Update the conv2d layer with trainable BOFT weights. Update the conv2d layer with trainable BOFT weights.
@ -787,7 +803,7 @@ class Conv2d(nn.Module, BOFTLayer):
self.boft_block_num[adapter_name] = boft_block_num self.boft_block_num[adapter_name] = boft_block_num
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
""" """

View File

@ -245,14 +245,14 @@ class BOFTModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name): def set_adapter(self, adapter_name, inference_mode: bool = False):
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, BOFTLayer): if isinstance(module, BOFTLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -50,6 +50,7 @@ class BoneLayer(BaseTunerLayer):
adapter_name: str, adapter_name: str,
r: int, r: int,
init_weights: bool, init_weights: bool,
inference_mode: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
"""Internal function to create bone adapter """Internal function to create bone adapter
@ -83,7 +84,7 @@ class BoneLayer(BaseTunerLayer):
self.reset_bone_parameters_random(adapter_name) self.reset_bone_parameters_random(adapter_name)
# Move new weights to device # Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_bone_parameters(self, adapter_name: str, r): def reset_bone_parameters(self, adapter_name: str, r):
self.bone_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True) self.bone_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True)

View File

@ -239,14 +239,14 @@ class BoneModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name): def set_adapter(self, adapter_name, inference_mode: bool = False):
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, BoneLayer): if isinstance(module, BoneLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -56,7 +56,7 @@ class C3ALayer(BaseTunerLayer):
delta_weight = get_circulant_fast(c3a_kernel.to(torch.float32)).to(base_layer_weight_dtype) delta_weight = get_circulant_fast(c3a_kernel.to(torch.float32)).to(base_layer_weight_dtype)
return delta_weight / base_layer_weight.size(-1) return delta_weight / base_layer_weight.size(-1)
def update_layer(self, adapter_name, block_size, init_weights): def update_layer(self, adapter_name, block_size, init_weights, inference_mode: bool = False, **kwargs):
if block_size <= 0: if block_size <= 0:
raise ValueError(f"`block_size` should be a positive integer value but the value passed is {block_size}") raise ValueError(f"`block_size` should be a positive integer value but the value passed is {block_size}")
if self.in_features % block_size != 0: if self.in_features % block_size != 0:
@ -85,7 +85,7 @@ class C3ALayer(BaseTunerLayer):
self.reset_c3a_parameters(adapter_name, init_weights) self.reset_c3a_parameters(adapter_name, init_weights)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
@torch.no_grad() @torch.no_grad()
def reset_c3a_parameters(self, adapter_name, init_weights): def reset_c3a_parameters(self, adapter_name, init_weights):

View File

@ -213,19 +213,22 @@ class C3AModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: str | list[str]) -> None: def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None:
"""Set the active adapter(s). """Set the active adapter(s).
Args: Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. adapter_name (`str` or `list[str]`):
Name(s) of the adapter(s) to be activated.
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, C3ALayer): if isinstance(module, C3ALayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -51,7 +51,9 @@ class FourierFTLayer(BaseTunerLayer):
else: else:
raise ValueError(f"Unsupported layer type {type(base_layer)}") raise ValueError(f"Unsupported layer type {type(base_layer)}")
def update_layer(self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed): def update_layer(
self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, inference_mode: bool = False, **kwargs
):
if n_frequency <= 0: if n_frequency <= 0:
raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}") raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}")
if n_frequency > self.in_features * self.out_features: if n_frequency > self.in_features * self.out_features:
@ -76,7 +78,7 @@ class FourierFTLayer(BaseTunerLayer):
self.reset_fourier_parameters(adapter_name) self.reset_fourier_parameters(adapter_name)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
@torch.no_grad() @torch.no_grad()
def reset_fourier_parameters(self, adapter_name): def reset_fourier_parameters(self, adapter_name):

View File

@ -245,19 +245,22 @@ class FourierFTModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: str | list[str]) -> None: def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None:
"""Set the active adapter(s). """Set the active adapter(s).
Args: Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. adapter_name (`str` or `list[str]`):
Name(s) of the adapter(s) to be activated.
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, FourierFTLayer): if isinstance(module, FourierFTLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -55,6 +55,7 @@ class HRALayer(BaseTunerLayer):
r: int, r: int,
apply_GS: bool, apply_GS: bool,
init_weights: bool, init_weights: bool,
inference_mode: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
"""Internal function to create hra adapter """Internal function to create hra adapter
@ -91,7 +92,7 @@ class HRALayer(BaseTunerLayer):
# Move new weights to device # Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_hra_parameters(self, adapter_name: str): def reset_hra_parameters(self, adapter_name: str):
if self.hra_r[adapter_name] % 2 != 0: if self.hra_r[adapter_name] % 2 != 0:

View File

@ -244,14 +244,14 @@ class HRAModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name): def set_adapter(self, adapter_name, inference_mode: bool = False):
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, HRALayer): if isinstance(module, HRALayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -51,7 +51,7 @@ class IA3Layer(BaseTunerLayer):
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
def update_layer(self, adapter_name, init_ia3_weights): def update_layer(self, adapter_name, init_ia3_weights, inference_mode: bool = False, **kwargs):
# This code works for linear layers, override for other layer types # This code works for linear layers, override for other layer types
# Actual trainable parameters # Actual trainable parameters
if self.is_feedforward: if self.is_feedforward:
@ -62,7 +62,7 @@ class IA3Layer(BaseTunerLayer):
if init_ia3_weights: if init_ia3_weights:
self.reset_ia3_parameters(adapter_name) self.reset_ia3_parameters(adapter_name)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_ia3_parameters(self, adapter_name): def reset_ia3_parameters(self, adapter_name):
if adapter_name in self.ia3_l.keys(): if adapter_name in self.ia3_l.keys():
@ -202,7 +202,7 @@ class _ConvNd(nn.Module, IA3Layer):
self.update_layer(adapter_name, init_ia3_weights) self.update_layer(adapter_name, init_ia3_weights)
def update_layer(self, adapter_name, init_ia3_weights): def update_layer(self, adapter_name, init_ia3_weights, inference_mode: bool = False, **kwargs):
# Actual trainable parameters # Actual trainable parameters
num_features = self.in_features if self.is_feedforward else self.out_features num_features = self.in_features if self.is_feedforward else self.out_features
weights_size = (1, num_features) + (1,) * (self._kernel_dim - 2) weights_size = (1, num_features) + (1,) * (self._kernel_dim - 2)
@ -211,7 +211,7 @@ class _ConvNd(nn.Module, IA3Layer):
if init_ia3_weights: if init_ia3_weights:
self.reset_ia3_parameters(adapter_name) self.reset_ia3_parameters(adapter_name)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
""" """

View File

@ -260,28 +260,22 @@ class IA3Model(BaseTuner):
""" """
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: str | list[str]) -> None: def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None:
"""Set the active adapter(s). """Set the active adapter(s).
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
not desired, use the following code.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args: Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. adapter_name (`str` or `list[str]`):
Name(s) of the adapter(s) to be activated.
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, IA3Layer): if isinstance(module, IA3Layer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -37,8 +37,9 @@ class LNTuningLayer(nn.Module, BaseTunerLayer):
self._active_adapter = adapter_name self._active_adapter = adapter_name
self.merged_adapters = [] self.merged_adapters = []
def update_layer(self, layer: nn.Module, adapter_name: str): def update_layer(self, layer: nn.Module, adapter_name: str, inference_mode: bool = False, **kwargs):
self.ln_tuning_layers[adapter_name] = deepcopy(layer) self.ln_tuning_layers[adapter_name] = deepcopy(layer)
self.set_adapter(adapter_name, inference_mode=inference_mode)
def enable_adapters(self, enabled: bool) -> None: def enable_adapters(self, enabled: bool) -> None:
"""Toggle the enabling and disabling of adapters """Toggle the enabling and disabling of adapters

View File

@ -134,8 +134,6 @@ class LNTuningModel(BaseTuner):
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if self.prefix not in n: if self.prefix not in n:
p.requires_grad = False p.requires_grad = False
else:
p.requires_grad = True
def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool: def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool:
return check_target_module_exists(peft_config, key) return check_target_module_exists(peft_config, key)
@ -159,14 +157,14 @@ class LNTuningModel(BaseTuner):
""" """
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: str) -> None: def set_adapter(self, adapter_name: str, inference_mode: bool = False) -> None:
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, LNTuningLayer): if isinstance(module, LNTuningLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
def _unload_and_optionally_merge( def _unload_and_optionally_merge(

View File

@ -107,6 +107,7 @@ class LoHaLayer(nn.Module, LycorisLayer):
module_dropout: float, module_dropout: float,
init_weights: bool, init_weights: bool,
use_effective_conv2d: bool = False, use_effective_conv2d: bool = False,
inference_mode: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
"""Internal function to create loha adapter """Internal function to create loha adapter
@ -175,7 +176,7 @@ class LoHaLayer(nn.Module, LycorisLayer):
# Move new weights to device # Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def get_delta_weight(self, adapter_name: str) -> torch.Tensor: def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
# https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L178 # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L178

View File

@ -166,6 +166,7 @@ class LoKrLayer(nn.Module, LycorisLayer):
use_effective_conv2d: bool, use_effective_conv2d: bool,
decompose_both: bool, decompose_both: bool,
decompose_factor: int, decompose_factor: int,
inference_mode: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
"""Internal function to create lokr adapter """Internal function to create lokr adapter
@ -251,7 +252,7 @@ class LoKrLayer(nn.Module, LycorisLayer):
# Move new weights to device # Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def get_delta_weight(self, adapter_name: str) -> torch.Tensor: def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
# https://github.com/KohakuBlueleaf/LyCORIS/blob/e4259b870d3354a9615a96be61cb5d07455c58ea/lycoris/modules/lokr.py#L224 # https://github.com/KohakuBlueleaf/LyCORIS/blob/e4259b870d3354a9615a96be61cb5d07455c58ea/lycoris/modules/lokr.py#L224

View File

@ -208,6 +208,7 @@ class LoraLayer(BaseTunerLayer):
lora_bias: bool = False, lora_bias: bool = False,
arrow_config: ArrowConfig = None, arrow_config: ArrowConfig = None,
qalora_group_size: int = 32, qalora_group_size: int = 32,
inference_mode: bool = False,
**kwargs, **kwargs,
): ):
# collect the kwargs # collect the kwargs
@ -282,7 +283,7 @@ class LoraLayer(BaseTunerLayer):
if adapter_name in self.lora_variant: if adapter_name in self.lora_variant:
self.lora_variant[adapter_name].init(self, **kwargs) self.lora_variant[adapter_name].init(self, **kwargs)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
# Check for adapters that were added or removed from the arrow_model. # Check for adapters that were added or removed from the arrow_model.
# The arrow model may be modified after creation by adding new experts # The arrow model may be modified after creation by adding new experts
@ -925,6 +926,7 @@ class Embedding(nn.Module, LoraLayer):
use_dora, use_dora,
lora_bias, lora_bias,
arrow_config: ArrowConfig = None, arrow_config: ArrowConfig = None,
inference_mode: bool = False,
**kwargs, **kwargs,
): ):
# collect the kwargs # collect the kwargs
@ -971,7 +973,7 @@ class Embedding(nn.Module, LoraLayer):
if adapter_name in self.lora_variant: if adapter_name in self.lora_variant:
self.lora_variant[adapter_name].init(self, **kwargs) self.lora_variant[adapter_name].init(self, **kwargs)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
""" """
@ -1211,6 +1213,7 @@ class _ConvNd(nn.Module, LoraLayer):
use_dora, use_dora,
lora_bias, lora_bias,
arrow_config: ArrowConfig = None, arrow_config: ArrowConfig = None,
inference_mode: bool = False,
**kwargs, **kwargs,
): ):
# collect the kwargs # collect the kwargs
@ -1270,7 +1273,7 @@ class _ConvNd(nn.Module, LoraLayer):
if adapter_name in self.lora_variant: if adapter_name in self.lora_variant:
self.lora_variant[adapter_name].init(self, **kwargs) self.lora_variant[adapter_name].init(self, **kwargs)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def _get_dora_factor_view(self): def _get_dora_factor_view(self):
return (-1,) + (1,) * (self._kernel_dim - 1) return (-1,) + (1,) * (self._kernel_dim - 1)
@ -1990,6 +1993,7 @@ class ParamWrapper(nn.Module, LoraLayer):
use_qalora: bool = False, use_qalora: bool = False,
lora_bias: bool = False, lora_bias: bool = False,
qalora_group_size: int = 32, qalora_group_size: int = 32,
inference_mode: bool = False,
**kwargs, **kwargs,
): ):
# same method as in lora.Linear but taking into account that there can be multiple experts (3d parameter) # same method as in lora.Linear but taking into account that there can be multiple experts (3d parameter)
@ -2057,7 +2061,7 @@ class ParamWrapper(nn.Module, LoraLayer):
if adapter_name in self.lora_variant: if adapter_name in self.lora_variant:
self.lora_variant[adapter_name].init(self, **kwargs) self.lora_variant[adapter_name].init(self, **kwargs)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None: def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
""" """

View File

@ -257,6 +257,7 @@ class LoraModel(BaseTuner):
use_dora=lora_config.use_dora, use_dora=lora_config.use_dora,
lora_bias=lora_config.lora_bias, lora_bias=lora_config.lora_bias,
arrow_config=lora_config.arrow_config, arrow_config=lora_config.arrow_config,
inference_mode=lora_config.inference_mode,
) )
else: else:
if isinstance(target, ParamWrapper) and (parameter_name == target.parameter_name): if isinstance(target, ParamWrapper) and (parameter_name == target.parameter_name):
@ -430,28 +431,22 @@ class LoraModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: str | list[str]) -> None: def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None:
"""Set the active adapter(s). """Set the active adapter(s)
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
not desired, use the following code.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args: Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. adapter_name (str, list[str]):
The name(s) of the adapter(s) to set as active
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, LoraLayer): if isinstance(module, LoraLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@contextmanager @contextmanager

View File

@ -110,6 +110,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
init_method=init.xavier_normal_, init_method=init.xavier_normal_,
input_is_parallel=True, input_is_parallel=True,
gather_output=False, gather_output=False,
inference_mode: bool = False,
**parallel_linear_kwargs, **parallel_linear_kwargs,
): ):
# collect the kwargs # collect the kwargs
@ -182,7 +183,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
if adapter_name in self.lora_variant: if adapter_name in self.lora_variant:
self.lora_variant[adapter_name].init(self, **kwargs) self.lora_variant[adapter_name].init(self, **kwargs)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
self._check_forward_args(x, *args, **kwargs) self._check_forward_args(x, *args, **kwargs)

View File

@ -389,28 +389,22 @@ class LycorisTuner(BaseTuner):
""" """
return self._unload_and_optionally_merge(merge=False) return self._unload_and_optionally_merge(merge=False)
def set_adapter(self, adapter_name: str | list[str]) -> None: def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None:
"""Set the active adapter(s). """Set the active adapter(s).
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
not desired, use the following code.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args: Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. adapter_name (`str` or `list[str]`):
Name(s) of the adapter(s) to be activated.
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, LycorisLayer): if isinstance(module, LycorisLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
def delete_adapter(self, adapter_name: str) -> None: def delete_adapter(self, adapter_name: str) -> None:

View File

@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
import math import math
import warnings import warnings
from typing import Any, Optional, Union from typing import Any, Optional, Union
@ -53,7 +55,8 @@ class MissLayer(BaseTunerLayer):
r: int, r: int,
mini_r: int, mini_r: int,
miss_dropout, miss_dropout,
init_weights: bool, init_weights: bool | str,
inference_mode: bool = False,
**kwargs, **kwargs,
) -> None: ) -> None:
"""Internal function to create miss adapter """Internal function to create miss adapter
@ -101,7 +104,7 @@ class MissLayer(BaseTunerLayer):
self.reset_miss_parameters_random(adapter_name) self.reset_miss_parameters_random(adapter_name)
# Move new weights to device # Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_miss_parameters(self, adapter_name: str, r): def reset_miss_parameters(self, adapter_name: str, r):
self.miss_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True) self.miss_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True)

View File

@ -243,14 +243,14 @@ class MissModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name): def set_adapter(self, adapter_name, inference_mode: bool = False):
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, MissLayer): if isinstance(module, MissLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -225,13 +225,14 @@ class MixedModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: def set_adapter(self, adapter_name: Union[str, list[str]], inference_mode: bool = False) -> None:
self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, Layers): if isinstance(module, Layers):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -412,6 +412,8 @@ class OFTLayer(BaseTunerLayer):
init_weights, init_weights,
use_cayley_neumann, use_cayley_neumann,
num_cayley_neumann_terms, num_cayley_neumann_terms,
inference_mode: bool = False,
**kwargs,
): ):
""" """
Update the linear layer with trainable OFT weights. Override for other layer types. Update the linear layer with trainable OFT weights. Override for other layer types.
@ -479,7 +481,7 @@ class OFTLayer(BaseTunerLayer):
# Move new weights to device # Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_oft_parameters(self, adapter_name, init_weights): def reset_oft_parameters(self, adapter_name, init_weights):
""" """
@ -716,6 +718,8 @@ class Conv2d(nn.Module, OFTLayer):
init_weights, init_weights,
use_cayley_neumann, use_cayley_neumann,
num_cayley_neumann_terms, num_cayley_neumann_terms,
inference_mode: bool = False,
**kwargs,
): ):
""" """
Update the conv2d layer with trainable OFT weights. Update the conv2d layer with trainable OFT weights.
@ -777,7 +781,7 @@ class Conv2d(nn.Module, OFTLayer):
# Move new weights to device # Move new weights to device
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
""" """

View File

@ -306,28 +306,24 @@ class OFTModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name): def set_adapter(self, adapter_name, inference_mode: bool = False):
"""Set the active adapter(s). """Set the active adapter(s).
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
not desired, use the following code. not desired, use the following code.
```py adapter_name (`str` or `list[str]`):
>>> for name, param in model_peft.named_parameters(): Name(s) of the adapter(s) to be activated.
... if ...: # some check on name (ex. if 'lora' in name) inference_mode (bool, optional):
... param.requires_grad = False Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
```
Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
""" """
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, OFTLayer): if isinstance(module, OFTLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
def _check_merge_allowed(self): def _check_merge_allowed(self):

View File

@ -51,7 +51,7 @@ class PolyLayer(BaseTunerLayer):
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
def update_layer(self, adapter_name, poly_config): def update_layer(self, adapter_name, poly_config, inference_mode: bool = False, **kwargs):
if poly_config.r <= 0: if poly_config.r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {poly_config.r}") raise ValueError(f"`r` should be a positive integer value but the value passed is {poly_config.r}")
@ -82,7 +82,7 @@ class PolyLayer(BaseTunerLayer):
self.reset_poly_parameters(adapter_name, init_weights=poly_config.init_weights) self.reset_poly_parameters(adapter_name, init_weights=poly_config.init_weights)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_poly_parameters(self, adapter_name, init_weights): def reset_poly_parameters(self, adapter_name, init_weights):
if adapter_name in self.poly_lora_A.keys(): if adapter_name in self.poly_lora_A.keys():

View File

@ -135,11 +135,11 @@ class PolyModel(BaseTuner):
def disable_adapter_layers(self): def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name): def set_adapter(self, adapter_name, inference_mode: bool = False):
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, PolyLayer): if isinstance(module, PolyLayer):
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
def _prepare_adapter_config(self, peft_config, model_config): def _prepare_adapter_config(self, peft_config, model_config):
if peft_config.target_modules is None: if peft_config.target_modules is None:

View File

@ -99,6 +99,8 @@ class RandLoraLayer(BaseTunerLayer):
randlora_alpha, randlora_alpha,
randlora_dropout, randlora_dropout,
init_weights, init_weights,
inference_mode: bool = False,
**kwargs,
): ):
if r <= 0: if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
@ -166,7 +168,7 @@ class RandLoraLayer(BaseTunerLayer):
self.reset_randlora_parameters(adapter_name) self.reset_randlora_parameters(adapter_name)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_randlora_parameters(self, adapter_name): def reset_randlora_parameters(self, adapter_name):
if adapter_name in self.randlora_lambda.keys(): if adapter_name in self.randlora_lambda.keys():

View File

@ -456,14 +456,14 @@ class RandLoraModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name): def set_adapter(self, adapter_name, inference_mode: bool = False):
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, RandLoraLayer): if isinstance(module, RandLoraLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -81,6 +81,7 @@ class RoadLayer(BaseTunerLayer):
variant, variant,
group_size, group_size,
init_weights, init_weights,
inference_mode: bool = False,
): ):
self.variant[adapter_name] = variant self.variant[adapter_name] = variant
self.group_size[adapter_name] = group_size self.group_size[adapter_name] = group_size
@ -107,7 +108,7 @@ class RoadLayer(BaseTunerLayer):
self.reset_parameters(adapter_name, init_weights) self.reset_parameters(adapter_name, init_weights)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_parameters(self, adapter_name, init_weights): def reset_parameters(self, adapter_name, init_weights):
if init_weights is False: if init_weights is False:

View File

@ -188,25 +188,19 @@ class RoadModel(BaseTuner):
def enable_adapter_layers(self) -> None: def enable_adapter_layers(self) -> None:
self._set_adapter_layers(enabled=True) self._set_adapter_layers(enabled=True)
def set_adapter(self, adapter_name: str | list[str]) -> None: def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None:
"""Set the active adapter(s). """Set the active adapter(s).
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
not desired, use the following code.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args: Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. adapter_name (`str` or `list[str]`):
Name(s) of the adapter(s) to be activated.
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, RoadLayer): if isinstance(module, RoadLayer):
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
def __getattr__(self, name: str): def __getattr__(self, name: str):

View File

@ -57,6 +57,8 @@ class ShiraLayer(BaseTunerLayer):
mask, mask,
r, r,
init_weights: bool = True, init_weights: bool = True,
inference_mode: bool = False,
**kwargs,
): ):
if r <= 0: if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
@ -95,7 +97,7 @@ class ShiraLayer(BaseTunerLayer):
) )
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_shira_parameters(self, adapter_name): def reset_shira_parameters(self, adapter_name):
nn.init.zeros_(self.shira_weight[adapter_name]) nn.init.zeros_(self.shira_weight[adapter_name])

View File

@ -226,14 +226,14 @@ class ShiraModel(BaseTuner):
def disable_adapter_layers(self): def disable_adapter_layers(self):
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name): def set_adapter(self, adapter_name, inference_mode: bool = False):
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, ShiraLayer): if isinstance(module, ShiraLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -204,27 +204,21 @@ class TrainableTokensModel(BaseTuner):
""" """
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: str | list[str]) -> None: def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None:
"""Set the active adapter(s). """Set the active adapter(s).
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
not desired, use the following code.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args: Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. adapter_name (`str` or `list[str]`):
Name(s) of the adapter(s) to be activated.
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, TrainableTokensLayer): if isinstance(module, TrainableTokensLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
def unload(self) -> torch.nn.Module: def unload(self) -> torch.nn.Module:

View File

@ -710,7 +710,7 @@ class BaseTuner(nn.Module, ABC):
# It's important to set the adapter here (again), because otherwise it can happen that if a 2nd adapter is # It's important to set the adapter here (again), because otherwise it can happen that if a 2nd adapter is
# added, and it targets different layer(s) than the first adapter (which is active), then those different # added, and it targets different layer(s) than the first adapter (which is active), then those different
# layers will be activated, which we don't want. # layers will be activated, which we don't want.
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=peft_config.inference_mode)
self._mark_only_adapters_as_trainable(model) self._mark_only_adapters_as_trainable(model)
if self.peft_config[adapter_name].inference_mode: if self.peft_config[adapter_name].inference_mode:
@ -839,7 +839,7 @@ class BaseTuner(nn.Module, ABC):
with onload_layer(module): with onload_layer(module):
module.unmerge() module.unmerge()
def set_auxiliary_adapters(self, adapter_name: str | list[str]) -> None: def set_auxiliary_adapters(self, adapter_name: str | list[str], inference_mode: bool) -> None:
""" """
Sets the active adapter(s) on auxiliary modules. Sets the active adapter(s) on auxiliary modules.
@ -848,9 +848,11 @@ class BaseTuner(nn.Module, ABC):
Args: Args:
adapter_name (`str` or `list[str]`): adapter_name (`str` or `list[str]`):
The name(s) of the adapter to be set as active. The adapters must be loaded first. The name(s) of the adapter(s) to be set as active. The adapters must be loaded first.
inference_mode (bool):
Whether the activated adapter(s) should be frozen (i.e. `requires_grad=False`).
""" """
_set_adapter(self, adapter_name) _set_adapter(self, adapter_name, inference_mode=inference_mode)
def _delete_auxiliary_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: def _delete_auxiliary_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None:
for module in self.modules(): for module in self.modules():
@ -1010,29 +1012,26 @@ class BaseTunerLayer(ABC):
layer.requires_grad_(False) layer.requires_grad_(False)
self._disable_adapters = True self._disable_adapters = True
def set_adapter(self, adapter_names: str | list[str]) -> None: def set_adapter(self, adapter_names: str | list[str], inference_mode: bool = False) -> None:
"""Set the active adapter(s). """Set the active adapter(s).
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True) unless
not desired, use the following code. inference_mode is True.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args: Args:
adapter_name (`str` or `List[str]`): Name of the adapter(s) to be activated. adapter_name (`str` or `list[str]`):
The name(s) of the adapter(s) to set as active.
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
if isinstance(adapter_names, str): if isinstance(adapter_names, str):
adapter_names = [adapter_names] adapter_names = [adapter_names]
# Deactivate grads on the inactive adapter and activate grads on the active adapter # Deactivate grads on the inactive adapter and activate grads on the active adapter (if not in inference mode)
for layer_name in self.adapter_layer_names: for layer_name in self.adapter_layer_names:
module_dict = getattr(self, layer_name) module_dict = getattr(self, layer_name)
for key, layer in module_dict.items(): for key, layer in module_dict.items():
if key in adapter_names: if (key in adapter_names) and (not inference_mode):
# Note: It is possible that not a single layer is called with requires_grad_(True) here. This may # Note: It is possible that not a single layer is called with requires_grad_(True) here. This may
# happen if a completely different adapter layer is being activated. # happen if a completely different adapter layer is being activated.
layer.requires_grad_(True) layer.requires_grad_(True)
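
The condition above is the heart of the fix. Conceptually, the adapter-layer loop now behaves like the following standalone sketch (plain PyTorch with hypothetical names; the `else` branch that freezes non-selected adapters is unchanged and therefore not part of the hunk):

```py
from torch import nn


def set_adapter_requires_grad(
    module_dict: nn.ModuleDict, adapter_names: list[str], inference_mode: bool
) -> None:
    # Grads are enabled only for selected adapters and only outside of inference mode;
    # every other adapter copy stays (or becomes) frozen.
    for key, layer in module_dict.items():
        if (key in adapter_names) and (not inference_mode):
            layer.requires_grad_(True)
        else:
            layer.requires_grad_(False)


adapters = nn.ModuleDict({"default": nn.Linear(4, 4), "other": nn.Linear(4, 4)})

set_adapter_requires_grad(adapters, ["default"], inference_mode=False)
assert all(p.requires_grad for p in adapters["default"].parameters())

set_adapter_requires_grad(adapters, ["default"], inference_mode=True)
assert not any(p.requires_grad for p in adapters.parameters())
```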

View File

@ -68,6 +68,8 @@ class VBLoRALayer(BaseTunerLayer):
vector_length: float, vector_length: float,
vblora_dropout: float = 0.0, vblora_dropout: float = 0.0,
init_logits_std: float = 0.01, init_logits_std: float = 0.01,
inference_mode: bool = False,
**kwargs,
): ):
if r <= 0: if r <= 0:
raise ValueError(f"`r` {r} should be a positive integer value") raise ValueError(f"`r` {r} should be a positive integer value")
@ -97,7 +99,7 @@ class VBLoRALayer(BaseTunerLayer):
self.vblora_vector_bank = vblora_vector_bank self.vblora_vector_bank = vblora_vector_bank
self.reset_vblora_logits(adapter_name, init_logits_std) self.reset_vblora_logits(adapter_name, init_logits_std)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_vblora_logits(self, adapter_name, init_logits_std): def reset_vblora_logits(self, adapter_name, init_logits_std):
if adapter_name in self.vblora_logits_A.keys(): if adapter_name in self.vblora_logits_A.keys():

View File

@ -278,28 +278,22 @@ class VBLoRAModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name: str | list[str]) -> None: def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None:
"""Set the active adapter(s). """Set the active adapter(s).
Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
not desired, use the following code.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args: Args:
adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. adapter_name (`str` or `list[str]`):
Name(s) of the adapter(s) to be activated.
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, VBLoRALayer): if isinstance(module, VBLoRALayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -74,6 +74,8 @@ class VeraLayer(BaseTunerLayer):
vera_dropout, vera_dropout,
init_weights, init_weights,
d_initial: float = 0.1, d_initial: float = 0.1,
inference_mode: bool = False,
**kwargs,
): ):
if r <= 0: if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
@ -129,7 +131,7 @@ class VeraLayer(BaseTunerLayer):
self.reset_vera_parameters(adapter_name, d_initial=d_initial) self.reset_vera_parameters(adapter_name, d_initial=d_initial)
self._move_adapter_to_device_of_base_layer(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters) self.set_adapter(self.active_adapters, inference_mode=inference_mode)
def reset_vera_parameters(self, adapter_name, d_initial: float = 0.1): def reset_vera_parameters(self, adapter_name, d_initial: float = 0.1):
if adapter_name in self.vera_lambda_d.keys(): if adapter_name in self.vera_lambda_d.keys():

View File

@ -394,14 +394,14 @@ class VeraModel(BaseTuner):
warnings.warn(msg) warnings.warn(msg)
self._set_adapter_layers(enabled=False) self._set_adapter_layers(enabled=False)
def set_adapter(self, adapter_name): def set_adapter(self, adapter_name, inference_mode: bool = False):
self.set_auxiliary_adapters(adapter_name) self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode)
for module in self.model.modules(): for module in self.model.modules():
if isinstance(module, VeraLayer): if isinstance(module, VeraLayer):
if module.merged: if module.merged:
warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
module.unmerge() module.unmerge()
module.set_adapter(adapter_name) module.set_adapter(adapter_name, inference_mode=inference_mode)
self.active_adapter = adapter_name self.active_adapter = adapter_name
@staticmethod @staticmethod

View File

@ -430,11 +430,14 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
""" """
raise NotImplementedError raise NotImplementedError
def set_adapter(self, adapter_names: Union[str, list[str]]): def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None:
"""Set the active adapter """Set the active adapter
Args: Args:
adapter_name (str): The name of the adapter to set as active adapter_names (str or list[str]):
The name(s) of the adapter(s) to set as active
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
if isinstance(adapter_names, str): if isinstance(adapter_names, str):
self._active_adapter = adapter_names self._active_adapter = adapter_names
@ -572,20 +575,17 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
return adapter_name_to_set return adapter_name_to_set
def set_adapter(self, adapter_names: Union[str, list[str]]): def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None:
"""Set the active adapter """Set the active adapter
Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True). If this is Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True) unless
not desired, use the following code. inference_mode is True.
```py
>>> for name, param in model_peft.named_parameters():
... if ...: # some check on name (ex. if 'lora' in name)
... param.requires_grad = False
```
Args: Args:
adapter_names (list[str], str): The name of the adapter to set as active adapter_names (list[str], str):
The name(s) of the adapter(s) to set as active.
inference_mode (bool, optional):
Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False.
""" """
if isinstance(adapter_names, str): if isinstance(adapter_names, str):
adapter_names = [adapter_names] adapter_names = [adapter_names]
@ -605,7 +605,7 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
for currently_active_adapter_name in self.active_adapters: for currently_active_adapter_name in self.active_adapters:
self.modules_to_save[currently_active_adapter_name].requires_grad_(False) self.modules_to_save[currently_active_adapter_name].requires_grad_(False)
self.modules_to_save[adapter_name].requires_grad_(True) self.modules_to_save[adapter_name].requires_grad_(not inference_mode)
self._active_adapter = adapter_name self._active_adapter = adapter_name
def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None:
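
The same pattern is applied to the per-adapter copies kept by the auxiliary wrappers: the currently active copies are frozen first, and the newly activated copy is only unfrozen when `inference_mode` is `False`. A reduced, hypothetical miniature of this behavior (not the actual wrapper class):

```py
import copy

from torch import nn


class TinyModulesToSaveWrapper(nn.Module):
    # hypothetical miniature, only to illustrate requires_grad_(not inference_mode)
    def __init__(self, original: nn.Module):
        super().__init__()
        self.original = original
        self.modules_to_save = nn.ModuleDict()
        self._active_adapters: list[str] = []

    def update(self, adapter_name: str) -> None:
        self.modules_to_save[adapter_name] = copy.deepcopy(self.original)

    def set_adapter(self, adapter_name: str, inference_mode: bool = False) -> None:
        for active_name in self._active_adapters:
            self.modules_to_save[active_name].requires_grad_(False)
        # the newly activated copy is only made trainable outside of inference mode
        self.modules_to_save[adapter_name].requires_grad_(not inference_mode)
        self._active_adapters = [adapter_name]


wrapper = TinyModulesToSaveWrapper(nn.Linear(4, 4))
wrapper.update("default")
wrapper.set_adapter("default", inference_mode=True)
assert not any(p.requires_grad for p in wrapper.modules_to_save.parameters())
```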
@ -810,9 +810,9 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
return adapter_name_to_set return adapter_name_to_set
def set_adapter(self, adapter_names: Union[str, list[str]]): def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None:
super().set_adapter(adapter_names) super().set_adapter(adapter_names, inference_mode=inference_mode)
self.token_adapter.set_adapter(adapter_names) self.token_adapter.set_adapter(adapter_names, inference_mode=inference_mode)
def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None:
""" """
@ -887,6 +887,7 @@ def _set_trainable(
model, model,
adapter_name, adapter_name,
module_names, module_names,
inference_mode: bool,
strict_module_check: bool = False, strict_module_check: bool = False,
wrapper_cls: Optional[AuxiliaryTrainingWrapper] = None, wrapper_cls: Optional[AuxiliaryTrainingWrapper] = None,
activate_adapter: bool = True, activate_adapter: bool = True,
@ -926,13 +927,13 @@ def _set_trainable(
parent, target, target_name = _get_submodules(model, key) parent, target, target_name = _get_submodules(model, key)
if isinstance(target, wrapper_cls): if isinstance(target, wrapper_cls):
target.update(adapter_name, **wrapper_kwargs) target.update(adapter_name, **wrapper_kwargs)
target.set_adapter(target.active_adapter) target.set_adapter(target.active_adapter, inference_mode=inference_mode)
else: else:
new_module = wrapper_cls(target, adapter_name, **wrapper_kwargs) new_module = wrapper_cls(target, adapter_name, **wrapper_kwargs)
if activate_adapter: if activate_adapter:
new_module.set_adapter(adapter_name) new_module.set_adapter(adapter_name, inference_mode=inference_mode)
else: else:
new_module.set_adapter([]) new_module.set_adapter([], inference_mode=inference_mode)
setattr(parent, target_name, new_module) setattr(parent, target_name, new_module)
trainable_modules.append(new_module) trainable_modules.append(new_module)
found_modules.add(target_name) found_modules.add(target_name)
@ -946,7 +947,7 @@ def _set_trainable(
return trainable_modules return trainable_modules
def _set_adapter(model, adapter_name): def _set_adapter(model, adapter_name: str | list[str], inference_mode: bool = False):
for module in model.modules(): for module in model.modules():
if isinstance(module, AuxiliaryTrainingWrapper): if isinstance(module, AuxiliaryTrainingWrapper):
# only check the adapter_name if we actually encounter a AuxiliaryTrainingWrapper, otherwise we don't care # only check the adapter_name if we actually encounter a AuxiliaryTrainingWrapper, otherwise we don't care
@ -956,10 +957,10 @@ def _set_adapter(model, adapter_name):
# module # module
if adapter_name_to_set in module._adapters: if adapter_name_to_set in module._adapters:
module.enable_adapters(True) module.enable_adapters(True)
module.set_adapter(adapter_name_to_set) module.set_adapter(adapter_name_to_set, inference_mode=inference_mode)
else: else:
module.enable_adapters(False) module.enable_adapters(False)
module.set_adapter([]) module.set_adapter([], inference_mode=inference_mode)
def _prepare_prompt_learning_config(peft_config, model_config): def _prepare_prompt_learning_config(peft_config, model_config):
@ -1326,6 +1327,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
_set_trainable( _set_trainable(
model, model,
adapter_name, adapter_name,
inference_mode=peft_config.inference_mode,
module_names=getattr(peft_config, "modules_to_save", None), module_names=getattr(peft_config, "modules_to_save", None),
activate_adapter=activate_adapter, activate_adapter=activate_adapter,
) )
@ -1351,6 +1353,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
_set_trainable( _set_trainable(
model, model,
adapter_name, adapter_name,
inference_mode=peft_config.inference_mode,
module_names=[target_layer], module_names=[target_layer],
strict_module_check=True, strict_module_check=True,
wrapper_cls=TrainableTokensWrapper, wrapper_cls=TrainableTokensWrapper,
@ -1374,6 +1377,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n
_set_trainable( _set_trainable(
model, model,
adapter_name, adapter_name,
inference_mode=peft_config.inference_mode,
module_names=module_keys, module_names=module_keys,
strict_module_check=True, strict_module_check=True,
wrapper_cls=TrainableTokensWrapper, wrapper_cls=TrainableTokensWrapper,

View File

@ -3685,11 +3685,20 @@ class TestRequiresGrad:
"base_model.model.lin0.ia3_l.adapter1", "base_model.model.lin0.ia3_l.adapter1",
) )
@pytest.mark.xfail(strict=True)
def test_requires_grad_adalora_different_targets(self): def test_requires_grad_adalora_different_targets(self):
# test two different AdaLora adapters that target different modules # test two different AdaLora adapters that target different modules
# Note: This test is expected to fail because first loading one adapter, then the next adapter with
# inference_mode=True incorrectly leads to the requires_grad of the first adapter being turned to False. This is
# of course not desired but has yet to be fixed. In practice, it's unlikely that a user would pass
# inference_mode=True to add_adapter; this flag is mostly used when calling PeftModel.from_pretrained, so
# we accept this issue for now. Note that only for AdaLoRA do we even need to pass inference_mode=True here,
# other PEFT methods don't require this.
config0 = AdaLoraConfig(target_modules=["lin0"], total_step=1) config0 = AdaLoraConfig(target_modules=["lin0"], total_step=1)
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
# note: AdaLoRA cannot have more than 1 trainable active adapter, hence enable inference_mode
config1 = AdaLoraConfig(target_modules=["lin1"], total_step=1, inference_mode=True) config1 = AdaLoraConfig(target_modules=["lin1"], total_step=1, inference_mode=True)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
@ -3731,11 +3740,20 @@ class TestRequiresGrad:
"base_model.model.lin1.lora_E.adapter1", "base_model.model.lin1.lora_E.adapter1",
) )
@pytest.mark.xfail(strict=True)
def test_requires_grad_adalora_same_targets(self): def test_requires_grad_adalora_same_targets(self):
# same as previous test, except that AdaLora adapters target the same layer # same as previous test, except that AdaLora adapters target the same layer
# Note: This test is expected to fail because first loading one adapter, then the next adapter with
# inference_mode=True incorrectly leads to the requires_grad of the first adapter being turned to False. This is
# of course not desired but has yet to be fixed. In practice, it's unlikely that a user would pass
# inference_mode=True to add_adapter; this flag is mostly used when calling PeftModel.from_pretrained, so
# we accept this issue for now. Note that only for AdaLoRA do we even need to pass inference_mode=True here,
# other PEFT methods don't require this.
config0 = AdaLoraConfig(target_modules=["lin0"], total_step=1) config0 = AdaLoraConfig(target_modules=["lin0"], total_step=1)
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
# note: AdaLoRA cannot have more than 1 trainable active adapter, hence enable inference_mode
config1 = AdaLoraConfig(target_modules=["lin0"], total_step=1, inference_mode=True) config1 = AdaLoraConfig(target_modules=["lin0"], total_step=1, inference_mode=True)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
@ -3944,7 +3962,7 @@ class TestRequiresGrad:
config0 = LoHaConfig(target_modules=["lin0"]) config0 = LoHaConfig(target_modules=["lin0"])
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = LoHaConfig(target_modules=["lin1"], inference_mode=True) config1 = LoHaConfig(target_modules=["lin1"])
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -3994,7 +4012,7 @@ class TestRequiresGrad:
config0 = LoHaConfig(target_modules=["lin0"]) config0 = LoHaConfig(target_modules=["lin0"])
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = LoHaConfig(target_modules=["lin0"], inference_mode=True) config1 = LoHaConfig(target_modules=["lin0"])
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4045,7 +4063,7 @@ class TestRequiresGrad:
config0 = LoKrConfig(target_modules=["lin0"]) config0 = LoKrConfig(target_modules=["lin0"])
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = LoKrConfig(target_modules=["lin1"], inference_mode=True) config1 = LoKrConfig(target_modules=["lin1"])
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4087,7 +4105,7 @@ class TestRequiresGrad:
config0 = LoKrConfig(target_modules=["lin0"]) config0 = LoKrConfig(target_modules=["lin0"])
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = LoKrConfig(target_modules=["lin0"], inference_mode=True) config1 = LoKrConfig(target_modules=["lin0"])
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4130,7 +4148,7 @@ class TestRequiresGrad:
config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0) config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = OFTConfig(target_modules=["lin1"], r=2, oft_block_size=0, inference_mode=True) config1 = OFTConfig(target_modules=["lin1"], r=2, oft_block_size=0)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4168,7 +4186,7 @@ class TestRequiresGrad:
config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0) config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0, inference_mode=True) config1 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4207,7 +4225,7 @@ class TestRequiresGrad:
config0 = HRAConfig(target_modules=["lin0"]) config0 = HRAConfig(target_modules=["lin0"])
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = HRAConfig(target_modules=["lin1"], inference_mode=True) config1 = HRAConfig(target_modules=["lin1"])
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4245,7 +4263,7 @@ class TestRequiresGrad:
config0 = HRAConfig(target_modules=["lin0"]) config0 = HRAConfig(target_modules=["lin0"])
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = HRAConfig(target_modules=["lin0"], inference_mode=True) config1 = HRAConfig(target_modules=["lin0"])
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4284,7 +4302,7 @@ class TestRequiresGrad:
config0 = BoneConfig(target_modules=["lin0"], r=2) config0 = BoneConfig(target_modules=["lin0"], r=2)
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = BoneConfig(target_modules=["lin1"], r=2, inference_mode=True) config1 = BoneConfig(target_modules=["lin1"], r=2)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4322,7 +4340,7 @@ class TestRequiresGrad:
config0 = BoneConfig(target_modules=["lin0"], r=2) config0 = BoneConfig(target_modules=["lin0"], r=2)
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = BoneConfig(target_modules=["lin0"], r=2, inference_mode=True) config1 = BoneConfig(target_modules=["lin0"], r=2)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4361,7 +4379,7 @@ class TestRequiresGrad:
config0 = MissConfig(target_modules=["lin0"], r=2) config0 = MissConfig(target_modules=["lin0"], r=2)
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = MissConfig(target_modules=["lin1"], r=2, inference_mode=True) config1 = MissConfig(target_modules=["lin1"], r=2)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4399,7 +4417,7 @@ class TestRequiresGrad:
config0 = MissConfig(target_modules=["lin0"], r=2) config0 = MissConfig(target_modules=["lin0"], r=2)
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = MissConfig(target_modules=["lin0"], r=2, inference_mode=True) config1 = MissConfig(target_modules=["lin0"], r=2)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4438,7 +4456,7 @@ class TestRequiresGrad:
config0 = BOFTConfig(target_modules=["lin0"], boft_block_size=2) config0 = BOFTConfig(target_modules=["lin0"], boft_block_size=2)
peft_model = get_peft_model(MLP2(), config0) peft_model = get_peft_model(MLP2(), config0)
config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2, inference_mode=True) config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4480,7 +4498,7 @@ class TestRequiresGrad:
config0 = BOFTConfig(target_modules=["lin1"], boft_block_size=2) config0 = BOFTConfig(target_modules=["lin1"], boft_block_size=2)
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2, inference_mode=True) config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4524,10 +4542,7 @@ class TestRequiresGrad:
) )
peft_model = get_peft_model(MLP_LayerNorm(), config0) peft_model = get_peft_model(MLP_LayerNorm(), config0)
config1 = LNTuningConfig( config1 = LNTuningConfig(target_modules=["layernorm1"])
target_modules=["layernorm1"],
inference_mode=True,
)
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4571,7 +4586,7 @@ class TestRequiresGrad:
) )
peft_model = get_peft_model(MLP_LayerNorm(), config0) peft_model = get_peft_model(MLP_LayerNorm(), config0)
config1 = LNTuningConfig(target_modules=["layernorm0"], inference_mode=True) config1 = LNTuningConfig(target_modules=["layernorm0"])
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -4990,7 +5005,7 @@ class TestRequiresGrad:
config0 = FourierFTConfig(n_frequency=10, target_modules=["lin0"]) config0 = FourierFTConfig(n_frequency=10, target_modules=["lin0"])
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = FourierFTConfig(n_frequency=10, target_modules=["lin1"], inference_mode=True) config1 = FourierFTConfig(n_frequency=10, target_modules=["lin1"])
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -5028,7 +5043,7 @@ class TestRequiresGrad:
config0 = FourierFTConfig(n_frequency=10, target_modules=["lin0"]) config0 = FourierFTConfig(n_frequency=10, target_modules=["lin0"])
peft_model = get_peft_model(MLP(), config0) peft_model = get_peft_model(MLP(), config0)
config1 = FourierFTConfig(n_frequency=10, target_modules=["lin0"], inference_mode=True) config1 = FourierFTConfig(n_frequency=10, target_modules=["lin0"])
peft_model.add_adapter("adapter1", config1) peft_model.add_adapter("adapter1", config1)
# active adapter is still "default" # active adapter is still "default"
@ -5062,6 +5077,177 @@ class TestRequiresGrad:
"base_model.model.lin0.fourierft_spectrum.adapter1", "base_model.model.lin0.fourierft_spectrum.adapter1",
) )
@pytest.mark.parametrize("config_cls", ALL_PEFT_CONFIG_CLASSES)
@pytest.mark.parametrize("is_trainable", [False, True]) # note: default is False
def test_loading_model_requires_grad_set_correctly(self, config_cls, is_trainable, tmp_path):
# Test that when loading PeftModel and then loading another adapter, the requires_grad is set correctly and
# is_trainable is respected.
# See #2759
model = DeepMLP(size=256) # a size that works with all adapters
extra_kwargs = {}
if config_cls == IA3Config:
extra_kwargs["feedforward_modules"] = []
config = config_cls(target_modules=["layers.0.lin0"], **extra_kwargs)
if config_cls == TrainableTokensConfig: # TrainableTokens requires a different base model and config
model = ModelEmbConv1D()
config = config_cls(target_modules=["emb"], token_indices=[0, 2, 4])
model = get_peft_model(model, config)
model.save_pretrained(tmp_path)
del model
model = DeepMLP(size=256)
if config_cls == TrainableTokensConfig: # TrainableTokens requires a different base model
model = ModelEmbConv1D()
model = PeftModel.from_pretrained(model, tmp_path, is_trainable=is_trainable)
if is_trainable:
for name, param in model.named_parameters():
if ".default" in name:
assert param.requires_grad
else:
assert not param.requires_grad
else:
assert all(not p.requires_grad for p in model.parameters())
# load one more adapter; this adapter is not automatically activated
model.load_adapter(tmp_path, adapter_name="other", is_trainable=is_trainable)
if is_trainable:
for name, param in model.named_parameters():
if ".default" in name:
assert param.requires_grad
else:
assert not param.requires_grad
else:
assert all(not p.requires_grad for p in model.parameters())
@pytest.mark.parametrize("config_cls", ALL_PEFT_CONFIG_CLASSES)
@pytest.mark.parametrize("is_trainable", [False, True]) # note: default is False
def test_loading_model_with_modules_to_save_requires_grad_set_correctly(self, config_cls, is_trainable, tmp_path):
# Same test as above, but with modules_to_save
if config_cls == TrainableTokensConfig:
pytest.skip(reason="Trainable tokens does not support modules_to_save")
model = DeepMLP(size=256) # a size that works with all adapters
extra_kwargs = {}
if config_cls == IA3Config:
extra_kwargs["feedforward_modules"] = []
# modules_to_save targets a different module than the adapter:
config = config_cls(target_modules=["layers.0.lin0"], modules_to_save=["layers.0.lin1"], **extra_kwargs)
model = get_peft_model(model, config)
model.save_pretrained(tmp_path)
del model
model = DeepMLP(size=256)
model = PeftModel.from_pretrained(model, tmp_path, is_trainable=is_trainable)
if is_trainable:
for name, param in model.named_parameters():
if ".default" in name:
assert param.requires_grad
else:
assert not param.requires_grad
else:
assert all(not p.requires_grad for p in model.parameters())
# load one more adapter
model.load_adapter(tmp_path, adapter_name="other", is_trainable=is_trainable)
if is_trainable:
for name, param in model.named_parameters():
if ".default" in name:
assert param.requires_grad
else:
assert not param.requires_grad
else:
assert all(not p.requires_grad for p in model.parameters())
@pytest.mark.parametrize("is_trainable", [False, True]) # note: default is False
def test_loading_model_with_trainable_tokens_requires_grad_set_correctly(self, is_trainable, tmp_path):
model = ModelEmbConv1D()
# targeting the embedding with trainable_token_indices:
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0]})
model = get_peft_model(model, config)
model.save_pretrained(tmp_path)
del model
model = ModelEmbConv1D()
model = PeftModel.from_pretrained(model, tmp_path, is_trainable=is_trainable)
if is_trainable:
for name, param in model.named_parameters():
if ".default" in name:
assert param.requires_grad
else:
assert not param.requires_grad
else:
assert all(not p.requires_grad for p in model.parameters())
# load one more adapter
model.load_adapter(tmp_path, adapter_name="other", is_trainable=is_trainable)
if is_trainable:
for name, param in model.named_parameters():
if ".default" in name:
assert param.requires_grad
else:
assert not param.requires_grad
else:
assert all(not p.requires_grad for p in model.parameters())
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("config_cls", [LoraConfig]) # no need to check each method, they all fail
def test_loading_model_requires_grad_set_correctly_switch_inference_mode(self, config_cls, tmp_path):
# Same as test_loading_model_requires_grad_set_correctly but this time we first load with is_trainable=False and
# then with is_trainable=True. Loading the second adapter should not affect the requires_grad of the first
# adapter, but it does. The reason is that is_trainable/inference_mode is taken from the current PEFT config, but
# that config does not necessarily belong to the active adapter, creating a mismatch.
# When/If this is fixed, the check can be integrated into test_loading_model_requires_grad_set_correctly and
# this test can be deleted.
model = DeepMLP(size=256) # a size that works with all adapters
extra_kwargs = {}
config = config_cls(target_modules=["layers.0.lin0"])
model = get_peft_model(model, config)
model.save_pretrained(tmp_path)
del model
model = DeepMLP(size=256)
model = PeftModel.from_pretrained(model, tmp_path, is_trainable=False)
assert all(not p.requires_grad for p in model.parameters())
# load one more adapter; this adapter is not automatically activated
model.load_adapter(tmp_path, adapter_name="other", is_trainable=True)
params_with_grad = [n for n, p in model.named_parameters() if p.requires_grad]
expected = [
"base_model.model.layers.0.lin0.lora_A.other.weight",
"base_model.model.layers.0.lin0.lora_B.other.weight",
]
# this fails; instead we get ...lora_A.default.weight and ...lora_B.default.weight
assert params_with_grad == expected
@pytest.mark.xfail(strict=True)
@pytest.mark.parametrize("config_cls", [LoraConfig]) # no need to check each method, they all fail
def test_loading_model_requires_grad_load_adapter_then_add_adapter(self, config_cls, tmp_path):
# When adding a new adapter with model.add_adapter, through the set_adapter call in update_layer, we activate
# the gradients of the first adapter, even if it's not desired. Since there is no is_trainable argument on
# add_adapter, there is no way to disable that at the moment.
# When/If this is fixed, the check can be integrated into test_loading_model_requires_grad_set_correctly and
# this test can be deleted.
model = DeepMLP(size=256) # a size that works with all adapters
extra_kwargs = {}
config = config_cls(target_modules=["layers.0.lin0"])
model = get_peft_model(model, config)
model.save_pretrained(tmp_path)
del model
model = DeepMLP(size=256)
model = PeftModel.from_pretrained(model, tmp_path, is_trainable=False)
assert all(not p.requires_grad for p in model.parameters())
# add a new adapter
model.add_adapter(adapter_name="other", peft_config=config)
params_with_grad = [n for n, p in model.named_parameters() if p.requires_grad]
assert all(not p.requires_grad for p in model.parameters())
# this is for PEFT methods that support mixed adapter batches. # this is for PEFT methods that support mixed adapter batches.
MIXED_ADAPTER_TEST_CASES = [ MIXED_ADAPTER_TEST_CASES = [