diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py index 815b5c2e..ddd49c5a 100644 --- a/src/peft/mixed_model.py +++ b/src/peft/mixed_model.py @@ -251,9 +251,14 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module): self.modules_to_save = set(modules_to_save) else: self.modules_to_save.update(modules_to_save) - _set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None)) + _set_trainable( + self, + adapter_name, + module_names=getattr(peft_config, "modules_to_save", None), + inference_mode=peft_config.inference_mode, + ) - def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: + def set_adapter(self, adapter_name: Union[str, list[str]], inference_mode: bool = False) -> None: """ Sets the active adapter(s) for the model. @@ -262,18 +267,14 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module): order in which the adapters were loaded into the model. The active adapters only determine which adapters are active during the forward pass, but not the order in which they are applied. - Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is - not desired, use the following code. - - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` + Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True) unless + inference_mode is True. Args: - adapter_name (`str` or `List[str]`): - The name of the adapter(s) to be activated. + adapter_name (str, list[str]): + The name(s) of the adapter(s) to set as active + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. 
""" if isinstance(adapter_name, str): adapter_name = [adapter_name] @@ -284,8 +285,8 @@ class PeftMixedModel(PushToHubMixin, torch.nn.Module): f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" ) - self.base_model.set_adapter(adapter_name) - _set_adapter(self, adapter_name) + self.base_model.set_adapter(adapter_name, inference_mode=inference_mode) + _set_adapter(self, adapter_name, inference_mode=inference_mode) def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None: if isinstance(adapter_name, str): diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py index 2eb169eb..f2dfe7ff 100644 --- a/src/peft/peft_model.py +++ b/src/peft/peft_model.py @@ -1624,7 +1624,12 @@ class PeftModelForSequenceClassification(PeftModel): break # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper - _set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None)) + _set_trainable( + self, + adapter_name, + module_names=getattr(peft_config, "modules_to_save", None), + inference_mode=peft_config.inference_mode, + ) def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None: """ @@ -2475,7 +2480,12 @@ class PeftModelForTokenClassification(PeftModel): break # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper - _set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None)) + _set_trainable( + self, + adapter_name, + module_names=getattr(peft_config, "modules_to_save", None), + inference_mode=peft_config.inference_mode, + ) def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None: """ @@ -2691,7 +2701,12 @@ class PeftModelForQuestionAnswering(PeftModel): break # to make sure classifier layer is trainable; this may add a new ModulesToSaveWrapper - _set_trainable(self, adapter_name, module_names=getattr(peft_config, "modules_to_save", None)) + _set_trainable( + self, + adapter_name, + module_names=getattr(peft_config, "modules_to_save", None), + inference_mode=peft_config.inference_mode, + ) def add_adapter(self, adapter_name: str, peft_config: PeftConfig, low_cpu_mem_usage: bool = False) -> None: """ diff --git a/src/peft/tuners/adalora/layer.py b/src/peft/tuners/adalora/layer.py index 9079b42d..635e5105 100644 --- a/src/peft/tuners/adalora/layer.py +++ b/src/peft/tuners/adalora/layer.py @@ -45,7 +45,9 @@ class AdaLoraLayer(LoraLayer): self.lora_B = nn.ParameterDict({}) self.ranknum = nn.ParameterDict({}) - def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights): + def update_layer( + self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, inference_mode: bool = False, **kwargs + ): if r < 0: # note: r == 0 is allowed for AdaLora, see #1539 raise ValueError(f"`r` should be a positive integer or 0, but the value passed is {r}") @@ -74,7 +76,7 @@ class AdaLoraLayer(LoraLayer): self.reset_lora_parameters(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_lora_parameters(self, adapter_name): if adapter_name in self.lora_A.keys(): diff --git a/src/peft/tuners/boft/layer.py b/src/peft/tuners/boft/layer.py index 15406b1b..7232f39d 100644 --- a/src/peft/tuners/boft/layer.py +++ b/src/peft/tuners/boft/layer.py @@ -261,7 +261,15 @@ class 
BOFTLayer(BaseTunerLayer): warnings.warn("Unscaling operation for BOFT not supported! Keeping scale to 1.") def update_layer( - self, adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + self, + adapter_name, + boft_block_size, + boft_block_num, + boft_n_butterfly_factor, + boft_dropout, + init_weights, + inference_mode: bool = False, + **kwargs, ): """ Update the linear layer with trainable BOFT weights. Override for other layer types. @@ -360,7 +368,7 @@ class BOFTLayer(BaseTunerLayer): self.boft_block_num[adapter_name] = boft_block_num self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_boft_parameters(self, adapter_name, init_weights): """ @@ -682,7 +690,15 @@ class Conv2d(nn.Module, BOFTLayer): ) def update_layer( - self, adapter_name, boft_block_size, boft_block_num, boft_n_butterfly_factor, boft_dropout, init_weights + self, + adapter_name, + boft_block_size, + boft_block_num, + boft_n_butterfly_factor, + boft_dropout, + init_weights, + inference_mode: bool = False, + **kwargs, ): """ Update the conv2d layer with trainable BOFT weights. @@ -787,7 +803,7 @@ class Conv2d(nn.Module, BOFTLayer): self.boft_block_num[adapter_name] = boft_block_num self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ diff --git a/src/peft/tuners/boft/model.py b/src/peft/tuners/boft/model.py index a9a20e59..8e757a9c 100644 --- a/src/peft/tuners/boft/model.py +++ b/src/peft/tuners/boft/model.py @@ -245,14 +245,14 @@ class BOFTModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): - self.set_auxiliary_adapters(adapter_name) + def set_adapter(self, adapter_name, inference_mode: bool = False): + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, BOFTLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/bone/layer.py b/src/peft/tuners/bone/layer.py index fc07cc49..2cd04c57 100644 --- a/src/peft/tuners/bone/layer.py +++ b/src/peft/tuners/bone/layer.py @@ -50,6 +50,7 @@ class BoneLayer(BaseTunerLayer): adapter_name: str, r: int, init_weights: bool, + inference_mode: bool = False, **kwargs, ) -> None: """Internal function to create bone adapter @@ -83,7 +84,7 @@ class BoneLayer(BaseTunerLayer): self.reset_bone_parameters_random(adapter_name) # Move new weights to device self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_bone_parameters(self, adapter_name: str, r): self.bone_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True) diff --git a/src/peft/tuners/bone/model.py b/src/peft/tuners/bone/model.py index c46043f5..8e456f53 100644 --- a/src/peft/tuners/bone/model.py +++ b/src/peft/tuners/bone/model.py @@ -239,14 +239,14 @@ class BoneModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): - self.set_auxiliary_adapters(adapter_name) + def set_adapter(self, adapter_name, inference_mode: bool = False): + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, BoneLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/c3a/layer.py b/src/peft/tuners/c3a/layer.py index b6b3643a..0890ba0e 100644 --- a/src/peft/tuners/c3a/layer.py +++ b/src/peft/tuners/c3a/layer.py @@ -56,7 +56,7 @@ class C3ALayer(BaseTunerLayer): delta_weight = get_circulant_fast(c3a_kernel.to(torch.float32)).to(base_layer_weight_dtype) return delta_weight / base_layer_weight.size(-1) - def update_layer(self, adapter_name, block_size, init_weights): + def update_layer(self, adapter_name, block_size, init_weights, inference_mode: bool = False, **kwargs): if block_size <= 0: raise ValueError(f"`block_size` should be a positive integer value but the value passed is {block_size}") if self.in_features % block_size != 0: @@ -85,7 +85,7 @@ class C3ALayer(BaseTunerLayer): self.reset_c3a_parameters(adapter_name, init_weights) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) @torch.no_grad() def reset_c3a_parameters(self, adapter_name, init_weights): diff --git a/src/peft/tuners/c3a/model.py b/src/peft/tuners/c3a/model.py index 45876593..2f16060f 100644 --- a/src/peft/tuners/c3a/model.py +++ b/src/peft/tuners/c3a/model.py @@ -213,19 +213,22 @@ class C3AModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name: str | list[str]) -> None: + def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None: """Set the active adapter(s). Args: - adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. 
+ adapter_name (`str` or `list[str]`): + Name(s) of the adapter(s) to be activated. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ - self.set_auxiliary_adapters(adapter_name) + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, C3ALayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/fourierft/layer.py b/src/peft/tuners/fourierft/layer.py index 548c132e..a03a57f1 100644 --- a/src/peft/tuners/fourierft/layer.py +++ b/src/peft/tuners/fourierft/layer.py @@ -51,7 +51,9 @@ class FourierFTLayer(BaseTunerLayer): else: raise ValueError(f"Unsupported layer type {type(base_layer)}") - def update_layer(self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed): + def update_layer( + self, adapter_name, n_frequency, scaling, init_weights, random_loc_seed, inference_mode: bool = False, **kwargs + ): if n_frequency <= 0: raise ValueError(f"`n_frequency` should be a positive integer value but the value passed is {n_frequency}") if n_frequency > self.in_features * self.out_features: @@ -76,7 +78,7 @@ class FourierFTLayer(BaseTunerLayer): self.reset_fourier_parameters(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) @torch.no_grad() def reset_fourier_parameters(self, adapter_name): diff --git a/src/peft/tuners/fourierft/model.py b/src/peft/tuners/fourierft/model.py index 0bf8ecde..8145474b 100644 --- a/src/peft/tuners/fourierft/model.py +++ b/src/peft/tuners/fourierft/model.py @@ -245,19 +245,22 @@ class FourierFTModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name: str | list[str]) -> None: + def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None: """Set the active adapter(s). Args: - adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + adapter_name (`str` or `list[str]`): + Name(s) of the adapter(s) to be activated. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ - self.set_auxiliary_adapters(adapter_name) + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, FourierFTLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/hra/layer.py b/src/peft/tuners/hra/layer.py index 40800604..55ab6db6 100644 --- a/src/peft/tuners/hra/layer.py +++ b/src/peft/tuners/hra/layer.py @@ -55,6 +55,7 @@ class HRALayer(BaseTunerLayer): r: int, apply_GS: bool, init_weights: bool, + inference_mode: bool = False, **kwargs, ) -> None: """Internal function to create hra adapter @@ -91,7 +92,7 @@ class HRALayer(BaseTunerLayer): # Move new weights to device self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_hra_parameters(self, adapter_name: str): if self.hra_r[adapter_name] % 2 != 0: diff --git a/src/peft/tuners/hra/model.py b/src/peft/tuners/hra/model.py index c2079671..2b8846dd 100644 --- a/src/peft/tuners/hra/model.py +++ b/src/peft/tuners/hra/model.py @@ -244,14 +244,14 @@ class HRAModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): - self.set_auxiliary_adapters(adapter_name) + def set_adapter(self, adapter_name, inference_mode: bool = False): + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, HRALayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py index 45ed6ae6..48cb08ba 100644 --- a/src/peft/tuners/ia3/layer.py +++ b/src/peft/tuners/ia3/layer.py @@ -51,7 +51,7 @@ class IA3Layer(BaseTunerLayer): self.in_features = in_features self.out_features = out_features - def update_layer(self, adapter_name, init_ia3_weights): + def update_layer(self, adapter_name, init_ia3_weights, inference_mode: bool = False, **kwargs): # This code works for linear layers, override for other layer types # Actual trainable parameters if self.is_feedforward: @@ -62,7 +62,7 @@ class IA3Layer(BaseTunerLayer): if init_ia3_weights: self.reset_ia3_parameters(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_ia3_parameters(self, adapter_name): if adapter_name in self.ia3_l.keys(): @@ -202,7 +202,7 @@ class _ConvNd(nn.Module, IA3Layer): self.update_layer(adapter_name, init_ia3_weights) - def update_layer(self, adapter_name, init_ia3_weights): + def update_layer(self, adapter_name, init_ia3_weights, inference_mode: bool = False, **kwargs): # Actual trainable parameters num_features = self.in_features if self.is_feedforward else self.out_features weights_size = (1, num_features) + (1,) * (self._kernel_dim - 2) @@ -211,7 +211,7 @@ class _ConvNd(nn.Module, IA3Layer): if init_ia3_weights: self.reset_ia3_parameters(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ diff --git a/src/peft/tuners/ia3/model.py 
b/src/peft/tuners/ia3/model.py index 80c94067..cd727bab 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -260,28 +260,22 @@ class IA3Model(BaseTuner): """ self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name: str | list[str]) -> None: + def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None: """Set the active adapter(s). - Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is - not desired, use the following code. - - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` - Args: - adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + adapter_name (`str` or `list[str]`): + Name(s) of the adapter(s) to be activated. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ - self.set_auxiliary_adapters(adapter_name) + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, IA3Layer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/ln_tuning/layer.py b/src/peft/tuners/ln_tuning/layer.py index ca6950dc..f78ca99e 100644 --- a/src/peft/tuners/ln_tuning/layer.py +++ b/src/peft/tuners/ln_tuning/layer.py @@ -37,8 +37,9 @@ class LNTuningLayer(nn.Module, BaseTunerLayer): self._active_adapter = adapter_name self.merged_adapters = [] - def update_layer(self, layer: nn.Module, adapter_name: str): + def update_layer(self, layer: nn.Module, adapter_name: str, inference_mode: bool = False, **kwargs): self.ln_tuning_layers[adapter_name] = deepcopy(layer) + self.set_adapter(adapter_name, inference_mode=inference_mode) def enable_adapters(self, enabled: bool) -> None: """Toggle the enabling and disabling of adapters diff --git a/src/peft/tuners/ln_tuning/model.py b/src/peft/tuners/ln_tuning/model.py index fe844287..d95d720c 100644 --- a/src/peft/tuners/ln_tuning/model.py +++ b/src/peft/tuners/ln_tuning/model.py @@ -134,8 +134,6 @@ class LNTuningModel(BaseTuner): for n, p in model.named_parameters(): if self.prefix not in n: p.requires_grad = False - else: - p.requires_grad = True def _check_target_module_exists(self, peft_config: PeftConfig, key: str) -> bool: return check_target_module_exists(peft_config, key) @@ -159,14 +157,14 @@ class LNTuningModel(BaseTuner): """ self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name: str) -> None: - self.set_auxiliary_adapters(adapter_name) + def set_adapter(self, adapter_name: str, inference_mode: bool = False) -> None: + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, LNTuningLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name def _unload_and_optionally_merge( diff --git a/src/peft/tuners/loha/layer.py b/src/peft/tuners/loha/layer.py index 19582bec..96f9b1e0 100644 --- a/src/peft/tuners/loha/layer.py +++ b/src/peft/tuners/loha/layer.py @@ -107,6 +107,7 @@ class LoHaLayer(nn.Module, LycorisLayer): module_dropout: float, init_weights: bool, use_effective_conv2d: bool = False, + inference_mode: bool = False, **kwargs, ) -> None: """Internal function to create loha adapter @@ -175,7 +176,7 @@ class LoHaLayer(nn.Module, LycorisLayer): # Move new weights to device self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def get_delta_weight(self, adapter_name: str) -> torch.Tensor: # https://github.com/KohakuBlueleaf/LyCORIS/blob/eb460098187f752a5d66406d3affade6f0a07ece/lycoris/modules/loha.py#L178 diff --git a/src/peft/tuners/lokr/layer.py b/src/peft/tuners/lokr/layer.py index c898065c..295193bf 100644 --- a/src/peft/tuners/lokr/layer.py +++ b/src/peft/tuners/lokr/layer.py @@ -166,6 +166,7 @@ class LoKrLayer(nn.Module, LycorisLayer): use_effective_conv2d: bool, decompose_both: bool, decompose_factor: int, + inference_mode: bool = False, **kwargs, ) -> None: """Internal function to create lokr adapter @@ -251,7 +252,7 @@ class LoKrLayer(nn.Module, LycorisLayer): # Move new weights to device self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def get_delta_weight(self, adapter_name: str) -> torch.Tensor: # https://github.com/KohakuBlueleaf/LyCORIS/blob/e4259b870d3354a9615a96be61cb5d07455c58ea/lycoris/modules/lokr.py#L224 diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 84fd955b..87576f1d 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -208,6 +208,7 @@ class LoraLayer(BaseTunerLayer): lora_bias: bool = False, arrow_config: ArrowConfig = None, qalora_group_size: int = 32, + inference_mode: bool = False, **kwargs, ): # collect the kwargs @@ -282,7 +283,7 @@ class LoraLayer(BaseTunerLayer): if adapter_name in self.lora_variant: self.lora_variant[adapter_name].init(self, **kwargs) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) # Check for adapters that were added or removed from the arrow_model. 
# The arrow model may be modified after creation by adding new experts @@ -925,6 +926,7 @@ class Embedding(nn.Module, LoraLayer): use_dora, lora_bias, arrow_config: ArrowConfig = None, + inference_mode: bool = False, **kwargs, ): # collect the kwargs @@ -971,7 +973,7 @@ class Embedding(nn.Module, LoraLayer): if adapter_name in self.lora_variant: self.lora_variant[adapter_name].init(self, **kwargs) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ @@ -1211,6 +1213,7 @@ class _ConvNd(nn.Module, LoraLayer): use_dora, lora_bias, arrow_config: ArrowConfig = None, + inference_mode: bool = False, **kwargs, ): # collect the kwargs @@ -1270,7 +1273,7 @@ class _ConvNd(nn.Module, LoraLayer): if adapter_name in self.lora_variant: self.lora_variant[adapter_name].init(self, **kwargs) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def _get_dora_factor_view(self): return (-1,) + (1,) * (self._kernel_dim - 1) @@ -1990,6 +1993,7 @@ class ParamWrapper(nn.Module, LoraLayer): use_qalora: bool = False, lora_bias: bool = False, qalora_group_size: int = 32, + inference_mode: bool = False, **kwargs, ): # same method as in lora.Linear but taking into account that there can be multiple experts (3d parameter) @@ -2057,7 +2061,7 @@ class ParamWrapper(nn.Module, LoraLayer): if adapter_name in self.lora_variant: self.lora_variant[adapter_name].init(self, **kwargs) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None: """ diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 8f75f97c..eeb3da26 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -257,6 +257,7 @@ class LoraModel(BaseTuner): use_dora=lora_config.use_dora, lora_bias=lora_config.lora_bias, arrow_config=lora_config.arrow_config, + inference_mode=lora_config.inference_mode, ) else: if isinstance(target, ParamWrapper) and (parameter_name == target.parameter_name): @@ -430,28 +431,22 @@ class LoraModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name: str | list[str]) -> None: - """Set the active adapter(s). - - Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is - not desired, use the following code. - - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` + def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None: + """Set the active adapter(s) Args: - adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + adapter_name (str, list[str]): + The name(s) of the adapter(s) to set as active + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ - self.set_auxiliary_adapters(adapter_name) + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, LoraLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @contextmanager diff --git a/src/peft/tuners/lora/tp_layer.py b/src/peft/tuners/lora/tp_layer.py index 2c2d66cf..7edd4d3a 100644 --- a/src/peft/tuners/lora/tp_layer.py +++ b/src/peft/tuners/lora/tp_layer.py @@ -110,6 +110,7 @@ class LoraParallelLinear(nn.Module, LoraLayer): init_method=init.xavier_normal_, input_is_parallel=True, gather_output=False, + inference_mode: bool = False, **parallel_linear_kwargs, ): # collect the kwargs @@ -182,7 +183,7 @@ class LoraParallelLinear(nn.Module, LoraLayer): if adapter_name in self.lora_variant: self.lora_variant[adapter_name].init(self, **kwargs) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any): self._check_forward_args(x, *args, **kwargs) diff --git a/src/peft/tuners/lycoris_utils.py b/src/peft/tuners/lycoris_utils.py index 517aab1f..1513738e 100644 --- a/src/peft/tuners/lycoris_utils.py +++ b/src/peft/tuners/lycoris_utils.py @@ -389,28 +389,22 @@ class LycorisTuner(BaseTuner): """ return self._unload_and_optionally_merge(merge=False) - def set_adapter(self, adapter_name: str | list[str]) -> None: + def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None: """Set the active adapter(s). - Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is - not desired, use the following code. - - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` - Args: - adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + adapter_name (`str` or `list[str]`): + Name(s) of the adapter(s) to be activated. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ - self.set_auxiliary_adapters(adapter_name) + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, LycorisLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name def delete_adapter(self, adapter_name: str) -> None: diff --git a/src/peft/tuners/miss/layer.py b/src/peft/tuners/miss/layer.py index c9186d06..dc238a03 100644 --- a/src/peft/tuners/miss/layer.py +++ b/src/peft/tuners/miss/layer.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations + import math import warnings from typing import Any, Optional, Union @@ -53,7 +55,8 @@ class MissLayer(BaseTunerLayer): r: int, mini_r: int, miss_dropout, - init_weights: bool, + init_weights: bool | str, + inference_mode: bool = False, **kwargs, ) -> None: """Internal function to create miss adapter @@ -101,7 +104,7 @@ class MissLayer(BaseTunerLayer): self.reset_miss_parameters_random(adapter_name) # Move new weights to device self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_miss_parameters(self, adapter_name: str, r): self.miss_block[adapter_name] = nn.Parameter(torch.zeros(r, self.out_features), requires_grad=True) diff --git a/src/peft/tuners/miss/model.py b/src/peft/tuners/miss/model.py index 76d8cf97..cfe95a87 100644 --- a/src/peft/tuners/miss/model.py +++ b/src/peft/tuners/miss/model.py @@ -243,14 +243,14 @@ class MissModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): - self.set_auxiliary_adapters(adapter_name) + def set_adapter(self, adapter_name, inference_mode: bool = False): + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, MissLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/mixed/model.py b/src/peft/tuners/mixed/model.py index b5123e20..8b5669eb 100644 --- a/src/peft/tuners/mixed/model.py +++ b/src/peft/tuners/mixed/model.py @@ -225,13 +225,14 @@ class MixedModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: + def set_adapter(self, adapter_name: Union[str, list[str]], inference_mode: bool = False) -> None: + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, Layers): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/oft/layer.py b/src/peft/tuners/oft/layer.py index 561eb3ae..6b14d015 100644 --- a/src/peft/tuners/oft/layer.py +++ b/src/peft/tuners/oft/layer.py @@ -412,6 +412,8 @@ class OFTLayer(BaseTunerLayer): init_weights, use_cayley_neumann, num_cayley_neumann_terms, + inference_mode: bool = False, + **kwargs, ): """ Update the linear layer with trainable OFT weights. Override for other layer types. @@ -479,7 +481,7 @@ class OFTLayer(BaseTunerLayer): # Move new weights to device self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_oft_parameters(self, adapter_name, init_weights): """ @@ -716,6 +718,8 @@ class Conv2d(nn.Module, OFTLayer): init_weights, use_cayley_neumann, num_cayley_neumann_terms, + inference_mode: bool = False, + **kwargs, ): """ Update the conv2d layer with trainable OFT weights. 
@@ -777,7 +781,7 @@ class Conv2d(nn.Module, OFTLayer): # Move new weights to device self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None: """ diff --git a/src/peft/tuners/oft/model.py b/src/peft/tuners/oft/model.py index 24be5dc4..a21063d3 100644 --- a/src/peft/tuners/oft/model.py +++ b/src/peft/tuners/oft/model.py @@ -306,28 +306,24 @@ class OFTModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): + def set_adapter(self, adapter_name, inference_mode: bool = False): """Set the active adapter(s). Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is not desired, use the following code. - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` - - Args: - adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + adapter_name (`str` or `list[str]`): + Name(s) of the adapter(s) to be activated. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ - self.set_auxiliary_adapters(adapter_name) + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, OFTLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name def _check_merge_allowed(self): diff --git a/src/peft/tuners/poly/layer.py b/src/peft/tuners/poly/layer.py index 5169d9ba..2f700997 100644 --- a/src/peft/tuners/poly/layer.py +++ b/src/peft/tuners/poly/layer.py @@ -51,7 +51,7 @@ class PolyLayer(BaseTunerLayer): self.in_features = in_features self.out_features = out_features - def update_layer(self, adapter_name, poly_config): + def update_layer(self, adapter_name, poly_config, inference_mode: bool = False, **kwargs): if poly_config.r <= 0: raise ValueError(f"`r` should be a positive integer value but the value passed is {poly_config.r}") @@ -82,7 +82,7 @@ class PolyLayer(BaseTunerLayer): self.reset_poly_parameters(adapter_name, init_weights=poly_config.init_weights) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_poly_parameters(self, adapter_name, init_weights): if adapter_name in self.poly_lora_A.keys(): diff --git a/src/peft/tuners/poly/model.py b/src/peft/tuners/poly/model.py index 85c73254..ef84dc2e 100644 --- a/src/peft/tuners/poly/model.py +++ b/src/peft/tuners/poly/model.py @@ -135,11 +135,11 @@ class PolyModel(BaseTuner): def disable_adapter_layers(self): self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): - self.set_auxiliary_adapters(adapter_name) + def set_adapter(self, adapter_name, inference_mode: bool = False): + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, PolyLayer): - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) def 
_prepare_adapter_config(self, peft_config, model_config): if peft_config.target_modules is None: diff --git a/src/peft/tuners/randlora/layer.py b/src/peft/tuners/randlora/layer.py index 9940f19f..77ecbdaf 100644 --- a/src/peft/tuners/randlora/layer.py +++ b/src/peft/tuners/randlora/layer.py @@ -99,6 +99,8 @@ class RandLoraLayer(BaseTunerLayer): randlora_alpha, randlora_dropout, init_weights, + inference_mode: bool = False, + **kwargs, ): if r <= 0: raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") @@ -166,7 +168,7 @@ class RandLoraLayer(BaseTunerLayer): self.reset_randlora_parameters(adapter_name) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_randlora_parameters(self, adapter_name): if adapter_name in self.randlora_lambda.keys(): diff --git a/src/peft/tuners/randlora/model.py b/src/peft/tuners/randlora/model.py index b494d9a6..67f9177b 100644 --- a/src/peft/tuners/randlora/model.py +++ b/src/peft/tuners/randlora/model.py @@ -456,14 +456,14 @@ class RandLoraModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): - self.set_auxiliary_adapters(adapter_name) + def set_adapter(self, adapter_name, inference_mode: bool = False): + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, RandLoraLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/road/layer.py b/src/peft/tuners/road/layer.py index d30aed30..d59dc056 100644 --- a/src/peft/tuners/road/layer.py +++ b/src/peft/tuners/road/layer.py @@ -81,6 +81,7 @@ class RoadLayer(BaseTunerLayer): variant, group_size, init_weights, + inference_mode: bool = False, ): self.variant[adapter_name] = variant self.group_size[adapter_name] = group_size @@ -107,7 +108,7 @@ class RoadLayer(BaseTunerLayer): self.reset_parameters(adapter_name, init_weights) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_parameters(self, adapter_name, init_weights): if init_weights is False: diff --git a/src/peft/tuners/road/model.py b/src/peft/tuners/road/model.py index a522c70a..96f1f9f9 100644 --- a/src/peft/tuners/road/model.py +++ b/src/peft/tuners/road/model.py @@ -188,25 +188,19 @@ class RoadModel(BaseTuner): def enable_adapter_layers(self) -> None: self._set_adapter_layers(enabled=True) - def set_adapter(self, adapter_name: str | list[str]) -> None: + def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None: """Set the active adapter(s). - Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is - not desired, use the following code. - - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` - Args: - adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + adapter_name (`str` or `list[str]`): + Name(s) of the adapter(s) to be activated. 
+ inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ - self.set_auxiliary_adapters(adapter_name) + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, RoadLayer): - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name def __getattr__(self, name: str): diff --git a/src/peft/tuners/shira/layer.py b/src/peft/tuners/shira/layer.py index 15007640..15fb4f3c 100644 --- a/src/peft/tuners/shira/layer.py +++ b/src/peft/tuners/shira/layer.py @@ -57,6 +57,8 @@ class ShiraLayer(BaseTunerLayer): mask, r, init_weights: bool = True, + inference_mode: bool = False, + **kwargs, ): if r <= 0: raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") @@ -95,7 +97,7 @@ class ShiraLayer(BaseTunerLayer): ) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_shira_parameters(self, adapter_name): nn.init.zeros_(self.shira_weight[adapter_name]) diff --git a/src/peft/tuners/shira/model.py b/src/peft/tuners/shira/model.py index f06aa6ba..3110d9db 100644 --- a/src/peft/tuners/shira/model.py +++ b/src/peft/tuners/shira/model.py @@ -226,14 +226,14 @@ class ShiraModel(BaseTuner): def disable_adapter_layers(self): self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): - self.set_auxiliary_adapters(adapter_name) + def set_adapter(self, adapter_name, inference_mode: bool = False): + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, ShiraLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/trainable_tokens/model.py b/src/peft/tuners/trainable_tokens/model.py index 582a31f6..ff359370 100644 --- a/src/peft/tuners/trainable_tokens/model.py +++ b/src/peft/tuners/trainable_tokens/model.py @@ -204,27 +204,21 @@ class TrainableTokensModel(BaseTuner): """ self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name: str | list[str]) -> None: + def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None: """Set the active adapter(s). - Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is - not desired, use the following code. - - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` - Args: - adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + adapter_name (`str` or `list[str]`): + Name(s) of the adapter(s) to be activated. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ for module in self.model.modules(): if isinstance(module, TrainableTokensLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name def unload(self) -> torch.nn.Module: diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 94b27285..7f0b1a77 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -710,7 +710,7 @@ class BaseTuner(nn.Module, ABC): # It's important to set the adapter here (again), because otherwise it can happen that if a 2nd adapter is # added, and it targets different layer(s) than the first adapter (which is active), then those different # layers will be activated, which we don't want. - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=peft_config.inference_mode) self._mark_only_adapters_as_trainable(model) if self.peft_config[adapter_name].inference_mode: @@ -839,7 +839,7 @@ class BaseTuner(nn.Module, ABC): with onload_layer(module): module.unmerge() - def set_auxiliary_adapters(self, adapter_name: str | list[str]) -> None: + def set_auxiliary_adapters(self, adapter_name: str | list[str], inference_mode: bool) -> None: """ Sets the active adapter(s) on auxiliary modules. @@ -848,9 +848,11 @@ class BaseTuner(nn.Module, ABC): Args: adapter_name (`str` or `list[str]`): - The name(s) of the adapter to be set as active. The adapters must be loaded first. + The name(s) of the adapter(s) to be set as active. The adapters must be loaded first. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ - _set_adapter(self, adapter_name) + _set_adapter(self, adapter_name, inference_mode=inference_mode) def _delete_auxiliary_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: for module in self.modules(): @@ -1010,29 +1012,26 @@ class BaseTunerLayer(ABC): layer.requires_grad_(False) self._disable_adapters = True - def set_adapter(self, adapter_names: str | list[str]) -> None: + def set_adapter(self, adapter_names: str | list[str], inference_mode: bool = False) -> None: """Set the active adapter(s). - Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is - not desired, use the following code. - - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` + Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True) unless + inference_mode is True. Args: - adapter_name (`str` or `List[str]`): Name of the adapter(s) to be activated. + adapter_name (`str` or `list[str]`): + The name(s) of the adapter(s) to set as active. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ if isinstance(adapter_names, str): adapter_names = [adapter_names] - # Deactivate grads on the inactive adapter and activate grads on the active adapter + # Deactivate grads on the inactive adapter and activate grads on the active adapter (if not in inference mode) for layer_name in self.adapter_layer_names: module_dict = getattr(self, layer_name) for key, layer in module_dict.items(): - if key in adapter_names: + if (key in adapter_names) and (not inference_mode): # Note: It is possible that not a single layer is called with requires_grad_(True) here. 
This may # happen if a completely different adapter layer is being activated. layer.requires_grad_(True) diff --git a/src/peft/tuners/vblora/layer.py b/src/peft/tuners/vblora/layer.py index ea34c222..ea2f0cca 100644 --- a/src/peft/tuners/vblora/layer.py +++ b/src/peft/tuners/vblora/layer.py @@ -68,6 +68,8 @@ class VBLoRALayer(BaseTunerLayer): vector_length: float, vblora_dropout: float = 0.0, init_logits_std: float = 0.01, + inference_mode: bool = False, + **kwargs, ): if r <= 0: raise ValueError(f"`r` {r} should be a positive integer value") @@ -97,7 +99,7 @@ class VBLoRALayer(BaseTunerLayer): self.vblora_vector_bank = vblora_vector_bank self.reset_vblora_logits(adapter_name, init_logits_std) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_vblora_logits(self, adapter_name, init_logits_std): if adapter_name in self.vblora_logits_A.keys(): diff --git a/src/peft/tuners/vblora/model.py b/src/peft/tuners/vblora/model.py index 4dd6ffc4..81636f28 100644 --- a/src/peft/tuners/vblora/model.py +++ b/src/peft/tuners/vblora/model.py @@ -278,28 +278,22 @@ class VBLoRAModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name: str | list[str]) -> None: + def set_adapter(self, adapter_name: str | list[str], inference_mode: bool = False) -> None: """Set the active adapter(s). - Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is - not desired, use the following code. - - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` - Args: - adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + adapter_name (`str` or `list[str]`): + Name(s) of the adapter(s) to be activated. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ - self.set_auxiliary_adapters(adapter_name) + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, VBLoRALayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. 
Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/tuners/vera/layer.py b/src/peft/tuners/vera/layer.py index 725058d4..7559eea4 100644 --- a/src/peft/tuners/vera/layer.py +++ b/src/peft/tuners/vera/layer.py @@ -74,6 +74,8 @@ class VeraLayer(BaseTunerLayer): vera_dropout, init_weights, d_initial: float = 0.1, + inference_mode: bool = False, + **kwargs, ): if r <= 0: raise ValueError(f"`r` should be a positive integer value but the value passed is {r}") @@ -129,7 +131,7 @@ class VeraLayer(BaseTunerLayer): self.reset_vera_parameters(adapter_name, d_initial=d_initial) self._move_adapter_to_device_of_base_layer(adapter_name) - self.set_adapter(self.active_adapters) + self.set_adapter(self.active_adapters, inference_mode=inference_mode) def reset_vera_parameters(self, adapter_name, d_initial: float = 0.1): if adapter_name in self.vera_lambda_d.keys(): diff --git a/src/peft/tuners/vera/model.py b/src/peft/tuners/vera/model.py index 9f4ca9db..36ef7df1 100644 --- a/src/peft/tuners/vera/model.py +++ b/src/peft/tuners/vera/model.py @@ -394,14 +394,14 @@ class VeraModel(BaseTuner): warnings.warn(msg) self._set_adapter_layers(enabled=False) - def set_adapter(self, adapter_name): - self.set_auxiliary_adapters(adapter_name) + def set_adapter(self, adapter_name, inference_mode: bool = False): + self.set_auxiliary_adapters(adapter_name, inference_mode=inference_mode) for module in self.model.modules(): if isinstance(module, VeraLayer): if module.merged: warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") module.unmerge() - module.set_adapter(adapter_name) + module.set_adapter(adapter_name, inference_mode=inference_mode) self.active_adapter = adapter_name @staticmethod diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 72e6ab8f..0d0b24b8 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -430,11 +430,14 @@ class AuxiliaryTrainingWrapper(torch.nn.Module): """ raise NotImplementedError - def set_adapter(self, adapter_names: Union[str, list[str]]): + def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None: """Set the active adapter Args: - adapter_name (str): The name of the adapter to set as active + adapter_names (str or list[str]): + The name(s) of the adapter(s) to set as active + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ if isinstance(adapter_names, str): self._active_adapter = adapter_names @@ -572,20 +575,17 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper): return adapter_name_to_set - def set_adapter(self, adapter_names: Union[str, list[str]]): + def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None: """Set the active adapter - Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True). If this is - not desired, use the following code. - - ```py - >>> for name, param in model_peft.named_parameters(): - ... if ...: # some check on name (ex. if 'lora' in name) - ... param.requires_grad = False - ``` + Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True) unless + inference_mode is True. 
Args: - adapter_names (list[str], str): The name of the adapter to set as active + adapter_names (list[str], str): + The name(s) of the adapter(s) to set as active. + inference_mode (bool, optional): + Whether the activated adapter should be frozen (i.e. `requires_grad=False`). Default is False. """ if isinstance(adapter_names, str): adapter_names = [adapter_names] @@ -605,7 +605,7 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper): for currently_active_adapter_name in self.active_adapters: self.modules_to_save[currently_active_adapter_name].requires_grad_(False) - self.modules_to_save[adapter_name].requires_grad_(True) + self.modules_to_save[adapter_name].requires_grad_(not inference_mode) self._active_adapter = adapter_name def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: @@ -810,9 +810,9 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper): return adapter_name_to_set - def set_adapter(self, adapter_names: Union[str, list[str]]): - super().set_adapter(adapter_names) - self.token_adapter.set_adapter(adapter_names) + def set_adapter(self, adapter_names: Union[str, list[str]], inference_mode: bool = False) -> None: + super().set_adapter(adapter_names, inference_mode=inference_mode) + self.token_adapter.set_adapter(adapter_names, inference_mode=inference_mode) def delete_adapter(self, adapter_name: str, new_active_adapters: Optional[list[str]]) -> None: """ @@ -887,6 +887,7 @@ def _set_trainable( model, adapter_name, module_names, + inference_mode: bool, strict_module_check: bool = False, wrapper_cls: Optional[AuxiliaryTrainingWrapper] = None, activate_adapter: bool = True, @@ -926,13 +927,13 @@ def _set_trainable( parent, target, target_name = _get_submodules(model, key) if isinstance(target, wrapper_cls): target.update(adapter_name, **wrapper_kwargs) - target.set_adapter(target.active_adapter) + target.set_adapter(target.active_adapter, inference_mode=inference_mode) else: new_module = wrapper_cls(target, adapter_name, **wrapper_kwargs) if activate_adapter: - new_module.set_adapter(adapter_name) + new_module.set_adapter(adapter_name, inference_mode=inference_mode) else: - new_module.set_adapter([]) + new_module.set_adapter([], inference_mode=inference_mode) setattr(parent, target_name, new_module) trainable_modules.append(new_module) found_modules.add(target_name) @@ -946,7 +947,7 @@ def _set_trainable( return trainable_modules -def _set_adapter(model, adapter_name): +def _set_adapter(model, adapter_name: str | list[str], inference_mode: bool = False): for module in model.modules(): if isinstance(module, AuxiliaryTrainingWrapper): # only check the adapter_name if we actually encounter a AuxiliaryTrainingWrapper, otherwise we don't care @@ -956,10 +957,10 @@ def _set_adapter(model, adapter_name): # module if adapter_name_to_set in module._adapters: module.enable_adapters(True) - module.set_adapter(adapter_name_to_set) + module.set_adapter(adapter_name_to_set, inference_mode=inference_mode) else: module.enable_adapters(False) - module.set_adapter([]) + module.set_adapter([], inference_mode=inference_mode) def _prepare_prompt_learning_config(peft_config, model_config): @@ -1326,6 +1327,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n _set_trainable( model, adapter_name, + inference_mode=peft_config.inference_mode, module_names=getattr(peft_config, "modules_to_save", None), activate_adapter=activate_adapter, ) @@ -1351,6 +1353,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, 
adapter_n _set_trainable( model, adapter_name, + inference_mode=peft_config.inference_mode, module_names=[target_layer], strict_module_check=True, wrapper_cls=TrainableTokensWrapper, @@ -1374,6 +1377,7 @@ def set_additional_trainable_modules(model, peft_config, model_config, adapter_n _set_trainable( model, adapter_name, + inference_mode=peft_config.inference_mode, module_names=module_keys, strict_module_check=True, wrapper_cls=TrainableTokensWrapper, diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 1dd531a0..a2af81c9 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -3685,11 +3685,20 @@ class TestRequiresGrad: "base_model.model.lin0.ia3_l.adapter1", ) + @pytest.mark.xfail(strict=True) def test_requires_grad_adalora_different_targets(self): # test two different AdaLora adapters that target different modules + + # Note: This test is expected to fail because first loading one adapter, then the next adapter with + # inference_mode=True incorrectly leads to the requires_grad of the first adapter being turned to False. This is + # of course not desired but has yet to be fixed. In practice, it's unlikely that a user would pass + # inference_mode=True for add_adapter, this flag is mostly being used when calling PeftModel.from_pretrained, so + # we accept this issue for now. Note that only for AdaLoRA do we even need to pass inference_mode=True here, + # other PEFT methods don't require this. config0 = AdaLoraConfig(target_modules=["lin0"], total_step=1) peft_model = get_peft_model(MLP(), config0) + # note: AdaLoRA cannot have more than 1 trainable active adapter, hence enable inference_mode config1 = AdaLoraConfig(target_modules=["lin1"], total_step=1, inference_mode=True) peft_model.add_adapter("adapter1", config1) @@ -3731,11 +3740,20 @@ class TestRequiresGrad: "base_model.model.lin1.lora_E.adapter1", ) + @pytest.mark.xfail(strict=True) def test_requires_grad_adalora_same_targets(self): # same as previous test, except that AdaLora adapters target the same layer + + # Note: This test is expected to fail because first loading one adapter, then the next adapter with + # inference_mode=True incorrectly leads to the requires_grad of the first adapter being turned to False. This is + # of course not desired but has yet to be fixed. In practice, it's unlikely that a user would pass + # inference_mode=True for add_adapter, this flag is mostly being used when calling PeftModel.from_pretrained, so + # we accept this issue for now. Note that only for AdaLoRA do we even need to pass inference_mode=True here, + # other PEFT methods don't require this. 
         config0 = AdaLoraConfig(target_modules=["lin0"], total_step=1)
         peft_model = get_peft_model(MLP(), config0)

+        # note: AdaLoRA cannot have more than 1 trainable active adapter, hence enable inference_mode
         config1 = AdaLoraConfig(target_modules=["lin0"], total_step=1, inference_mode=True)
         peft_model.add_adapter("adapter1", config1)
@@ -3944,7 +3962,7 @@ class TestRequiresGrad:
         config0 = LoHaConfig(target_modules=["lin0"])
         peft_model = get_peft_model(MLP(), config0)

-        config1 = LoHaConfig(target_modules=["lin1"], inference_mode=True)
+        config1 = LoHaConfig(target_modules=["lin1"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -3994,7 +4012,7 @@ class TestRequiresGrad:
         config0 = LoHaConfig(target_modules=["lin0"])
         peft_model = get_peft_model(MLP(), config0)

-        config1 = LoHaConfig(target_modules=["lin0"], inference_mode=True)
+        config1 = LoHaConfig(target_modules=["lin0"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4045,7 +4063,7 @@ class TestRequiresGrad:
         config0 = LoKrConfig(target_modules=["lin0"])
         peft_model = get_peft_model(MLP(), config0)

-        config1 = LoKrConfig(target_modules=["lin1"], inference_mode=True)
+        config1 = LoKrConfig(target_modules=["lin1"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4087,7 +4105,7 @@ class TestRequiresGrad:
         config0 = LoKrConfig(target_modules=["lin0"])
         peft_model = get_peft_model(MLP(), config0)

-        config1 = LoKrConfig(target_modules=["lin0"], inference_mode=True)
+        config1 = LoKrConfig(target_modules=["lin0"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4130,7 +4148,7 @@ class TestRequiresGrad:
         config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
         peft_model = get_peft_model(MLP(), config0)

-        config1 = OFTConfig(target_modules=["lin1"], r=2, oft_block_size=0, inference_mode=True)
+        config1 = OFTConfig(target_modules=["lin1"], r=2, oft_block_size=0)
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4168,7 +4186,7 @@ class TestRequiresGrad:
         config0 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
         peft_model = get_peft_model(MLP(), config0)

-        config1 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0, inference_mode=True)
+        config1 = OFTConfig(target_modules=["lin0"], r=2, oft_block_size=0)
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4207,7 +4225,7 @@ class TestRequiresGrad:
         config0 = HRAConfig(target_modules=["lin0"])
         peft_model = get_peft_model(MLP(), config0)

-        config1 = HRAConfig(target_modules=["lin1"], inference_mode=True)
+        config1 = HRAConfig(target_modules=["lin1"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4245,7 +4263,7 @@ class TestRequiresGrad:
         config0 = HRAConfig(target_modules=["lin0"])
         peft_model = get_peft_model(MLP(), config0)

-        config1 = HRAConfig(target_modules=["lin0"], inference_mode=True)
+        config1 = HRAConfig(target_modules=["lin0"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4284,7 +4302,7 @@ class TestRequiresGrad:
         config0 = BoneConfig(target_modules=["lin0"], r=2)
         peft_model = get_peft_model(MLP(), config0)

-        config1 = BoneConfig(target_modules=["lin1"], r=2, inference_mode=True)
+        config1 = BoneConfig(target_modules=["lin1"], r=2)
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4322,7 +4340,7 @@ class TestRequiresGrad:
         config0 = BoneConfig(target_modules=["lin0"], r=2)
         peft_model = get_peft_model(MLP(), config0)

-        config1 = BoneConfig(target_modules=["lin0"], r=2, inference_mode=True)
+        config1 = BoneConfig(target_modules=["lin0"], r=2)
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4361,7 +4379,7 @@ class TestRequiresGrad:
         config0 = MissConfig(target_modules=["lin0"], r=2)
         peft_model = get_peft_model(MLP(), config0)

-        config1 = MissConfig(target_modules=["lin1"], r=2, inference_mode=True)
+        config1 = MissConfig(target_modules=["lin1"], r=2)
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4399,7 +4417,7 @@ class TestRequiresGrad:
         config0 = MissConfig(target_modules=["lin0"], r=2)
         peft_model = get_peft_model(MLP(), config0)

-        config1 = MissConfig(target_modules=["lin0"], r=2, inference_mode=True)
+        config1 = MissConfig(target_modules=["lin0"], r=2)
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4438,7 +4456,7 @@ class TestRequiresGrad:
         config0 = BOFTConfig(target_modules=["lin0"], boft_block_size=2)
         peft_model = get_peft_model(MLP2(), config0)

-        config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2, inference_mode=True)
+        config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2)
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4480,7 +4498,7 @@ class TestRequiresGrad:
         config0 = BOFTConfig(target_modules=["lin1"], boft_block_size=2)
         peft_model = get_peft_model(MLP(), config0)

-        config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2, inference_mode=True)
+        config1 = BOFTConfig(target_modules=["lin1"], boft_block_size=2)
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4524,10 +4542,7 @@ class TestRequiresGrad:
         )
         peft_model = get_peft_model(MLP_LayerNorm(), config0)

-        config1 = LNTuningConfig(
-            target_modules=["layernorm1"],
-            inference_mode=True,
-        )
+        config1 = LNTuningConfig(target_modules=["layernorm1"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4571,7 +4586,7 @@ class TestRequiresGrad:
         )
         peft_model = get_peft_model(MLP_LayerNorm(), config0)

-        config1 = LNTuningConfig(target_modules=["layernorm0"], inference_mode=True)
+        config1 = LNTuningConfig(target_modules=["layernorm0"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -4990,7 +5005,7 @@ class TestRequiresGrad:
         config0 = FourierFTConfig(n_frequency=10, target_modules=["lin0"])
         peft_model = get_peft_model(MLP(), config0)

-        config1 = FourierFTConfig(n_frequency=10, target_modules=["lin1"], inference_mode=True)
+        config1 = FourierFTConfig(n_frequency=10, target_modules=["lin1"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -5028,7 +5043,7 @@ class TestRequiresGrad:
         config0 = FourierFTConfig(n_frequency=10, target_modules=["lin0"])
         peft_model = get_peft_model(MLP(), config0)

-        config1 = FourierFTConfig(n_frequency=10, target_modules=["lin0"], inference_mode=True)
+        config1 = FourierFTConfig(n_frequency=10, target_modules=["lin0"])
         peft_model.add_adapter("adapter1", config1)

         # active adapter is still "default"
@@ -5062,6 +5077,177 @@ class TestRequiresGrad:
             "base_model.model.lin0.fourierft_spectrum.adapter1",
         )

+    @pytest.mark.parametrize("config_cls", ALL_PEFT_CONFIG_CLASSES)
+    @pytest.mark.parametrize("is_trainable", [False, True])  # note: default is False
+    def test_loading_model_requires_grad_set_correctly(self, config_cls, is_trainable, tmp_path):
+        # Test that when loading PeftModel and then loading another adapter, the requires_grad is set correctly and
+        # is_trainable is respected.
+        # See #2759
+        model = DeepMLP(size=256)  # a size that works with all adapters
+        extra_kwargs = {}
+        if config_cls == IA3Config:
+            extra_kwargs["feedforward_modules"] = []
+        config = config_cls(target_modules=["layers.0.lin0"], **extra_kwargs)
+
+        if config_cls == TrainableTokensConfig:  # TrainableTokens requires a different base model and config
+            model = ModelEmbConv1D()
+            config = config_cls(target_modules=["emb"], token_indices=[0, 2, 4])
+
+        model = get_peft_model(model, config)
+        model.save_pretrained(tmp_path)
+        del model
+
+        model = DeepMLP(size=256)
+        if config_cls == TrainableTokensConfig:  # TrainableTokens requires a different base model
+            model = ModelEmbConv1D()
+        model = PeftModel.from_pretrained(model, tmp_path, is_trainable=is_trainable)
+
+        if is_trainable:
+            for name, param in model.named_parameters():
+                if ".default" in name:
+                    assert param.requires_grad
+                else:
+                    assert not param.requires_grad
+        else:
+            assert all(not p.requires_grad for p in model.parameters())
+
+        # load one more adapter; this adapter is not automatically activated
+        model.load_adapter(tmp_path, adapter_name="other", is_trainable=is_trainable)
+        if is_trainable:
+            for name, param in model.named_parameters():
+                if ".default" in name:
+                    assert param.requires_grad
+                else:
+                    assert not param.requires_grad
+        else:
+            assert all(not p.requires_grad for p in model.parameters())
+
+    @pytest.mark.parametrize("config_cls", ALL_PEFT_CONFIG_CLASSES)
+    @pytest.mark.parametrize("is_trainable", [False, True])  # note: default is False
+    def test_loading_model_with_modules_to_save_requires_grad_set_correctly(self, config_cls, is_trainable, tmp_path):
+        # Same test as above, but with modules_to_save
+        if config_cls == TrainableTokensConfig:
+            pytest.skip(reason="Trainable tokens does not support modules_to_save")
+
+        model = DeepMLP(size=256)  # a size that works with all adapters
+        extra_kwargs = {}
+        if config_cls == IA3Config:
+            extra_kwargs["feedforward_modules"] = []
+        # targeting different modules with modules_to_save:
+        config = config_cls(target_modules=["layers.0.lin0"], modules_to_save=["layers.0.lin1"], **extra_kwargs)
+        model = get_peft_model(model, config)
+        model.save_pretrained(tmp_path)
+        del model
+
+        model = DeepMLP(size=256)
+        model = PeftModel.from_pretrained(model, tmp_path, is_trainable=is_trainable)
+
+        if is_trainable:
+            for name, param in model.named_parameters():
+                if ".default" in name:
+                    assert param.requires_grad
+                else:
+                    assert not param.requires_grad
+        else:
+            assert all(not p.requires_grad for p in model.parameters())
+
+        # load one more adapter
+        model.load_adapter(tmp_path, adapter_name="other", is_trainable=is_trainable)
+        if is_trainable:
+            for name, param in model.named_parameters():
+                if ".default" in name:
+                    assert param.requires_grad
+                else:
+                    assert not param.requires_grad
+        else:
+            assert all(not p.requires_grad for p in model.parameters())
+
+    @pytest.mark.parametrize("is_trainable", [False, True])  # note: default is False
+    def test_loading_model_with_trainable_tokens_requires_grad_set_correctly(self, is_trainable, tmp_path):
+        model = ModelEmbConv1D()
+        # targeting the embedding with trainable_token_indices:
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0]})
+        model = get_peft_model(model, config)
+        model.save_pretrained(tmp_path)
+        del model
+
+        model = ModelEmbConv1D()
+        model = PeftModel.from_pretrained(model, tmp_path, is_trainable=is_trainable)
+
+        if is_trainable:
+            for name, param in model.named_parameters():
+                if ".default" in name:
+                    assert param.requires_grad
+                else:
+                    assert not param.requires_grad
+        else:
+            assert all(not p.requires_grad for p in model.parameters())
+
+        # load one more adapter
+        model.load_adapter(tmp_path, adapter_name="other", is_trainable=is_trainable)
+        if is_trainable:
+            for name, param in model.named_parameters():
+                if ".default" in name:
+                    assert param.requires_grad
+                else:
+                    assert not param.requires_grad
+        else:
+            assert all(not p.requires_grad for p in model.parameters())
+
+    @pytest.mark.xfail(strict=True)
+    @pytest.mark.parametrize("config_cls", [LoraConfig])  # no need to check each method, they all fail
+    def test_loading_model_requires_grad_set_correctly_switch_inference_mode(self, config_cls, tmp_path):
+        # Same as test_loading_model_requires_grad_set_correctly but this time we first load with is_trainable=False and
+        # then with is_trainable=True. Loading the second adapter should not affect the requires_grad of the first
+        # adapter, but it does. The reason is that is_trainable/inference_mode is taken from the current PEFT config, but
+        # that config does not necessarily belong to the active adapter, creating a mismatch.
+        # When/If this is fixed, the check can be integrated into test_loading_model_requires_grad_set_correctly and
+        # this test can be deleted.
+        model = DeepMLP(size=256)  # a size that works with all adapters
+        extra_kwargs = {}
+        config = config_cls(target_modules=["layers.0.lin0"])
+        model = get_peft_model(model, config)
+        model.save_pretrained(tmp_path)
+        del model
+
+        model = DeepMLP(size=256)
+        model = PeftModel.from_pretrained(model, tmp_path, is_trainable=False)
+        assert all(not p.requires_grad for p in model.parameters())
+
+        # load one more adapter; this adapter is not automatically activated
+        model.load_adapter(tmp_path, adapter_name="other", is_trainable=True)
+        params_with_grad = [n for n, p in model.named_parameters() if p.requires_grad]
+        expected = [
+            "base_model.model.layers.0.lin0.lora_A.other.weight",
+            "base_model.model.layers.0.lin0.lora_B.other.weight",
+        ]
+        # this fails, instead we get ...lora_A.default.weight and ...lora_B.default.weight
+        assert params_with_grad == expected
+
+    @pytest.mark.xfail(strict=True)
+    @pytest.mark.parametrize("config_cls", [LoraConfig])  # no need to check each method, they all fail
+    def test_loading_model_requires_grad_load_adapter_then_add_adapter(self, config_cls, tmp_path):
+        # When adding a new adapter with model.add_adapter, through the set_adapter call in update_layer, we activate
+        # the gradients of the first adapter, even if that's not desired. Since there is no is_trainable argument on
+        # add_adapter, there is no way to disable that at the moment.
+        # When/If this is fixed, the check can be integrated into test_loading_model_requires_grad_set_correctly and
+        # this test can be deleted.
+        model = DeepMLP(size=256)  # a size that works with all adapters
+        extra_kwargs = {}
+        config = config_cls(target_modules=["layers.0.lin0"])
+        model = get_peft_model(model, config)
+        model.save_pretrained(tmp_path)
+        del model
+
+        model = DeepMLP(size=256)
+        model = PeftModel.from_pretrained(model, tmp_path, is_trainable=False)
+        assert all(not p.requires_grad for p in model.parameters())
+
+        # add a new adapter
+        model.add_adapter(adapter_name="other", peft_config=config)
+        params_with_grad = [n for n, p in model.named_parameters() if p.requires_grad]
+        assert all(not p.requires_grad for p in model.parameters())
+

 # this is for PEFT methods that support mixed adapter batches.
 MIXED_ADAPTER_TEST_CASES = [
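Not part of the patch: a minimal usage sketch of the behavior the new tests pin down, i.e. that is_trainable is respected when loading adapters and that loading a second adapter does not unfreeze the first. The TinyMLP class and the save path below are made up for illustration; only the PEFT calls (get_peft_model, save_pretrained, PeftModel.from_pretrained, load_adapter) are real API.

# a self-contained sketch, assuming a hypothetical tiny model with a module named "lin0"
import torch
from peft import LoraConfig, PeftModel, get_peft_model

class TinyMLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = torch.nn.Linear(8, 8)
        self.lin1 = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.lin1(self.lin0(x))

path = "/tmp/lora-demo"  # hypothetical save location
peft_model = get_peft_model(TinyMLP(), LoraConfig(target_modules=["lin0"]))
peft_model.save_pretrained(path)

# with is_trainable=False (the default), every parameter should stay frozen
model = PeftModel.from_pretrained(TinyMLP(), path, is_trainable=False)
assert all(not p.requires_grad for p in model.parameters())

# loading a second adapter with is_trainable=True should only unfreeze that adapter
model.load_adapter(path, adapter_name="other", is_trainable=True)
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
# desired outcome: only the "other" LoRA weights; per the xfail test above, current behavior may still differ
print(trainable)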