FIX Multiple issues with target_parameters (#2710)

There are a few issues with target_parameters that are fixed in this PR.

Existing parametrizations

When using target_parameters with LoRA, the LoRA parametrization is removed
again after the forward call finishes. Previously, this cleanup also removed
all other parametrizations registered on the same parameter, silently breaking
them. With this PR, only the LoRA parametrization is removed and any
pre-existing parametrizations are left intact.
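
For context, PyTorch's `remove_parametrizations` is all-or-nothing, which is why the fix has to delete only PEFT's own entry from the parametrization list. Below is a minimal, PEFT-independent sketch of that idea; the `AddOne` and `Scale` modules are made-up stand-ins for a pre-existing parametrization and the temporary LoRA proxy:

```python
from torch import nn
from torch.nn.utils import parametrize

class AddOne(nn.Module):  # stand-in for a pre-existing, non-LoRA parametrization
    def forward(self, x):
        return x + 1

class Scale(nn.Module):  # stand-in for the temporary LoRA proxy
    def forward(self, x):
        return 2 * x

lin = nn.Linear(2, 2, bias=False)
parametrize.register_parametrization(lin, "weight", AddOne())
parametrize.register_parametrization(lin, "weight", Scale())

# parametrize.remove_parametrizations(lin, "weight") would drop AddOne as well.
# Instead, delete only the entry we added, keeping AddOne intact:
param_list = lin.parametrizations["weight"]
for i in reversed(range(len(param_list))):
    if isinstance(param_list[i], Scale):
        del param_list[i]
        break

print([type(m).__name__ for m in param_list])  # ['AddOne']
```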

Module repr

This PR also extends the __repr__ of lora.ParamWrapper to contain the
parameter name, which makes it more useful.
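
For illustration, the snippet below (a condensed version of the new repr test in this PR; the toy MLP is not part of PEFT) shows where the parameter name now appears:

```python
from torch import nn
from peft import LoraConfig, get_peft_model

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)

    def forward(self, x):
        return self.lin0(x)

config = LoraConfig(target_parameters=["lin0.weight"])
model = get_peft_model(MLP(), config)

# The repr of the wrapping layer now starts with something like:
#   lora.ParamWrapper(
#     parameter_name='weight',
#     ...
print(repr(model.model.lin0))
```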

Extend testing

Added a tiny gpt-oss model to the target_parameters test suite.

Multiple LoRA adapters with target_parameters

There is an issue when adding a second LoRA adapter with target_parameters:
the second adapter would not actually be applied correctly. The corresponding
unit test was too lax to catch the bug. Since this is not easy to fix, for now
we forbid adding a second adapter that uses target_parameters. This is very
strict, but it is better than silently producing wrong results.
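
Concretely, the new guard behaves as sketched below (mirroring the added unit test; the toy model and the `linear.weight` target are only illustrative):

```python
from torch import nn
from peft import LoraConfig, get_peft_model

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

config = LoraConfig(target_modules=[], target_parameters=["linear.weight"])
model = get_peft_model(MyModel(), config)

# A second adapter that also uses target_parameters is rejected:
try:
    model.add_adapter(adapter_name="other", peft_config=config)
except ValueError as exc:
    print(exc)  # "... only one LoRA adapter per model with `target_parameters` is allowed."
```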

Although it was possible to fix that specific issue, the solution resulted in
ever more deeply nested adapters (i.e. with multiple .base_layer levels). This
in turn means those infixes become part of the state_dict, and the individual
adapters can then only be loaded correctly if the model is restored in the
exact same order in which it was originally created. This is not normally a
requirement in PEFT (e.g. I can create a model with two adapters and later
decide to load only one of them).

In the long run, we need to think about solutions that would allow this. It
may require some form of normalization of the layers to prevent ever deeper
nesting. Another wart is that, because the LoRA layer lives on a module but
actually targets one of possibly several parameters, the LoRA weights do not
reference the targeted parameter anywhere in their names. Purely from the
state_dict, it is therefore unclear which parameter a given LoRA weight
belongs to. Ideally, this should be encoded in the LoRA weight key.
Benjamin Bossan
2025-08-12 13:59:29 +02:00
committed by GitHub
parent 95df499d87
commit a2c6612b12
8 changed files with 346 additions and 108 deletions

View File

@ -276,7 +276,10 @@ The same logic applies to `alpha_pattern`. If you're in doubt, don't try to get
Generally, you should use `target_modules` to target the module (e.g. `nn.Linear`). However, in some circumstances, this is not possible. E.g., in many mixture of experts (MoE) layers in HF Transformers, instead of using `nn.Linear`, an `nn.Parameter` is used. PEFT normally overwrites the `forward` method for LoRA, but for `nn.Parameter`, there is none. Therefore, to apply LoRA to that parameter, it needs to be targeted with `target_parameters`. As an example, for [Llama4](https://huggingface.co/collections/meta-llama/llama-4-67f0c30d9fe03840bc9d0164), you can pass: `target_parameters=['feed_forward.experts.gate_up_proj', 'feed_forward.experts.down_proj']`.
#### Caveats
- At the moment, this argument allows targeting 2-dim or 3-dim `nn.Parameter`s. It is assumed that in the case of a 3-dim parameter, the 0th dimension is the expert dimension.
- It is currently not possible to add multiple LoRA adapters (via `model.add_adapter` or `model.load_adapter`) that use `target_parameters` at the same time.
## Optimizers

View File

@ -2046,14 +2046,28 @@ class ParamWrapper(nn.Module, LoraLayer):
"Something went wrong, please report this issue on PEFT: https://github.com/huggingface/peft/issues"
)
param_list = base_layer.parametrizations[parameter_name]
if len(param_list) == 1:
    # last parametrization, we can safely remove it completely
    nn.utils.parametrize.remove_parametrizations(base_layer, parameter_name, leave_parametrized=False)
    return

# If there are multiple parametrizations for the same parameter_name, we only want to remove the LoRA proxy.
# Unfortunately, PyTorch does not support this directly, so we need to take care of it manually. To achieve
# this, we check the ParameterList from the back until we find the _LoraParameterProxy instance and then remove
# it.
reversed_indices = reversed(range(len(param_list)))
for i in reversed_indices:
    module = param_list[i]
    if isinstance(module, _LoraParameterProxy):
        del param_list[i]
        break
else:  # no break encountered
    # this should not happen, but raising an error is probably not necessary
    warnings.warn(
        f"Could not find any LoRA parametrization on {self}, please open an issue on "
        "https://github.com/huggingface/peft/issues and report this warning."
    )
def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
# same as lora.Linear.merge but not hard-coding base_layer.weight and without special cases like variants removed
@ -2137,6 +2151,10 @@ class ParamWrapper(nn.Module, LoraLayer):
def __repr__(self) -> str:
rep = super().__repr__()
idx = rep.find("(") + 1
# insert the name of the parameter so that the repr is unambiguous when multiple parameters on the same
# module are being targeted
rep = f"{rep[:idx]}\n parameter_name='{self.parameter_name}',{rep[idx:]}"
return "lora." + rep

View File

@ -185,6 +185,18 @@ class LoraModel(BaseTuner):
if current_key is None:
raise ValueError("Current Key shouldn't be `None`")
if lora_config.target_parameters:
# Right now, unfortunately, we don't support multiple adapters with target_parameters on the same model.
other_configs_use_target_params = any(
conf.target_parameters for key, conf in self.peft_config.items() if key != adapter_name
)
if other_configs_use_target_params:
raise ValueError(
f"Adding a LoRA config with `target_parameters={lora_config.target_parameters}` but there are "
"already other LoRA adapters on this model that use `target_parameters`. At the moment, only "
"one LoRA adapter per model with `target_parameters` is allowed."
)
# Regexp matching - Find key which matches current target_name in patterns provided
r_key = get_pattern_key(lora_config.rank_pattern.keys(), current_key)
alpha_key = get_pattern_key(lora_config.alpha_pattern.keys(), current_key)

View File

@ -722,43 +722,77 @@ class BaseTuner(nn.Module, ABC):
def _inject_parameters(
self, peft_config: PeftConfig, model: nn.Module, adapter_name: str, low_cpu_mem_usage: bool
) -> None:
    """Inject layers based on peft_config.target_modules"""

    def strip_base_layer_from_name(module_name):
        # It is possible that the layer is already a PEFT layer and needs updating with a new adapter. In this case,
        # the name of parameter would be something like `model.layers.0.experts.base_layer.weight`, i.e. there is a
        # "base_layer" inserted in the name. We need to remove that, otherwise we won't be able to match correctly
        # (in this case, "experts.weight" would not match).
        name = ".base_layer"
        while name in module_name:
            prefix, _, suffix = module_name.rpartition(name)
            module_name = prefix + suffix
        return module_name

    def create_and_replace_param(module_name, key, param_name):
        # helper function to avoid duplication
        parent, target, target_name = _get_submodules(model, module_name)
        unwrapped_module_name = strip_base_layer_from_name(module_name)
        unwrapped_module = model.get_submodule(unwrapped_module_name)
        # use the class name for checking to avoid circular import
        if isinstance(unwrapped_module, BaseTunerLayer) and unwrapped_module.__class__.__name__ != "ParamWrapper":
            raise ValueError(
                f"Trying to wrap an `nn.Parameter` of layer '{unwrapped_module_name}' of type "
                f"{type(target).__name__}, which is not a valid target. Make sure that this layer is not "
                "also targeted with `target_modules`. For some models, PEFT will do this automatically, "
                "try setting `target_modules=[]` to prevent it."
            )
        self._check_target_module_compatiblity(peft_config, model, target_name)
        ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
        with ctx():
            self._create_and_replace(
                peft_config,
                adapter_name,
                target,
                target_name,
                parent,
                current_key=key,
                parameter_name=param_name.rpartition(".")[-1],
            )

    # TODO very simple matching, might not cover all use cases
    unsorted_target_names = set(peft_config.target_parameters)
    # As the order of matching can influence the nesting of multiple params on the same module, ensure determinism
    # by sorting.
    target_names = sorted(unsorted_target_names)
    for module_name, module in model.named_modules():
        if hasattr(module, "parametrizations"):
            # Deal with the case that the parameter is already parametrized. The issue is that we would not be able
            # to match `f"{module_name}.{param_name}"`, as the parameter is now something like
            # `module.parametrization.weight`.
            for key in target_names:
                target_module_name, _, param_name = key.rpartition(".")
                if target_module_name != module_name:
                    continue
                if getattr(module, param_name, None) is None:
                    continue
                create_and_replace_param(module_name, key, param_name)
                self.targeted_parameter_names.append(key)
        else:
            # Standard case: the parameter is not already parametrized. Note, however, that the model could already
            # be nested with lora.ParamWrapper, as this is how we allow targeting multiple Parameters on the same
            # module.
            unwrapped_module_name = strip_base_layer_from_name(module_name)
            # we're interested in finding the "lowest" module that contains the parameter, hence recurse=False
            for param_name, param in module.named_parameters(recurse=False):
                key = f"{unwrapped_module_name}.{param_name}"
                if (key in target_names) or any(key.endswith(f".{target_key}") for target_key in target_names):
                    # Note: We use the unwrapped_module_name to check if the key matches, but we use the module_name
                    # for replacement, since we want to replace the wrapped module.
                    create_and_replace_param(module_name, key, param_name)
                    self.targeted_parameter_names.append(key)
def merge_adapter(self, adapter_names: Optional[list[str]] = None, safe_merge: bool = False) -> None:
"""

View File

@ -937,6 +937,11 @@ PREFIXES = {
}
def _skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs):
if (config_cls == LoraConfig) and config_kwargs.get("target_parameters"):
pytest.skip("LoRA with multiple adapters with target_parameters is not supported")
class MLP(nn.Module):
def __init__(self, bias=True):
super().__init__()
@ -1390,6 +1395,7 @@ class TestPeftCustomModel(PeftCommonTester):
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_load_model_low_cpu_mem_usage(self, test_name, model_id, config_cls, config_kwargs):
_skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
self._test_load_model_low_cpu_mem_usage(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
@ -1398,6 +1404,7 @@ class TestPeftCustomModel(PeftCommonTester):
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_load_multiple_adapters(self, test_name, model_id, config_cls, config_kwargs):
_skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
self._test_load_multiple_adapters(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
@ -2018,6 +2025,8 @@ class TestPeftCustomModel(PeftCommonTester):
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_active_adapter(self, test_name, model_id, config_cls, config_kwargs):
_skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
config = config_cls(
base_model_name_or_path=model_id,
@ -2108,10 +2117,12 @@ class TestPeftCustomModel(PeftCommonTester):
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
_skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
self._test_delete_adapter(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
_skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs)
self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
@ -2809,6 +2820,19 @@ class TestLayerRepr:
assert "lora_B" in print_output
assert "default" in print_output
def test_repr_lora_paramwrapper(self):
config = LoraConfig(target_parameters=["lin0.weight"])
model = get_peft_model(MLP(), config)
print_output = repr(model.model.lin0)
assert print_output.startswith("lora.ParamWrapper")
# important: targeted parameter should be contained:
assert "parameter_name='weight'" in print_output
assert "in_features=10" in print_output
assert "out_features=20" in print_output
assert "lora_A" in print_output
assert "lora_B" in print_output
assert "default" in print_output
class TestMultipleActiveAdapters:
"""
@ -2843,6 +2867,8 @@ class TestMultipleActiveAdapters:
def test_multiple_active_adapters_forward(
self, test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2
):
_skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs_2)
torch.manual_seed(0)
model = self.resolve_model_cls(tuner_method)
@ -2901,6 +2927,8 @@ class TestMultipleActiveAdapters:
def test_multiple_active_adapters_merge_and_unmerge(
self, test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2
):
_skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs_2)
torch.manual_seed(0)
model = self.resolve_model_cls(tuner_method)
@ -2934,6 +2962,8 @@ class TestMultipleActiveAdapters:
"test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2", MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES
)
def test_merge_layers_multi(self, test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2):
_skip_tests_with_multiple_adapters_with_target_parameters(config_cls, config_kwargs_2)
torch.manual_seed(0)
model = self.resolve_model_cls(tuner_method)

View File

@ -1421,7 +1421,7 @@ class TestLoraInitialization:
self.linear = nn.Linear(10, 10)
base_model = MyModule()
config = LoraConfig(target_modules=["linear"], target_parameters=["linear.weight"])
msg = "Trying to wrap an `nn.Parameter` of layer 'linear' of type Linear, which is not a valid target."
with pytest.raises(ValueError, match=msg):
get_peft_model(base_model, config)
@ -1460,6 +1460,26 @@ class TestLoraInitialization:
with pytest.warns(RuntimeWarning, match=msg):
get_peft_model(model, config)
def test_adding_multiple_adapters_with_target_parameters_raises(self):
model = self.get_model()
config = LoraConfig(target_modules=[], target_parameters=["linear.weight"])
model = get_peft_model(model, config)
msg = re.escape("only one LoRA adapter per model with `target_parameters` is allowed")
with pytest.raises(ValueError, match=msg):
model.add_adapter(adapter_name="other", peft_config=config)
def test_loading_loading_adapters_with_target_parameters_raises(self, tmp_path):
model = self.get_model()
config = LoraConfig(target_modules=[], target_parameters=["linear.weight"])
model = get_peft_model(model, config)
model.save_pretrained(tmp_path)
model = self.get_model()
model = PeftModel.from_pretrained(model, tmp_path)
msg = re.escape("only one LoRA adapter per model with `target_parameters` is allowed")
with pytest.raises(ValueError, match=msg):
model.load_adapter(tmp_path, adapter_name="other")
class TestLokrInitialization:
torch_device = infer_device()

View File

@ -14,6 +14,7 @@
import pytest
import torch
from torch import nn
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model
@ -22,13 +23,13 @@ from .testing_common import PeftCommonTester
from .testing_utils import hub_online_once, set_init_weights_false
PEFT_DECODER_MODELS_TO_TEST = [
"trl-internal-testing/tiny-Llama4ForCausalLM",
]
ALL_CONFIGS = [
##########
# Llama4 #
##########
# target down_proj
(
"trl-internal-testing/tiny-Llama4ForCausalLM",
LoraConfig,
{
"task_type": TaskType.CAUSAL_LM,
@ -39,8 +40,9 @@ ALL_CONFIGS = [
],
},
),
# target gate_up_proj and down_proj, but not on the same module
(
"trl-internal-testing/tiny-Llama4ForCausalLM",
LoraConfig,
{
"task_type": TaskType.CAUSAL_LM,
@ -54,6 +56,7 @@ ALL_CONFIGS = [
),
# target down_proj and gate_up_proj on the same module
(
"trl-internal-testing/tiny-Llama4ForCausalLM",
LoraConfig,
{
"task_type": "CAUSAL_LM",
@ -70,6 +73,7 @@ ALL_CONFIGS = [
),
# target q_proj, v_proj as modules, and down_proj as parameter
(
"trl-internal-testing/tiny-Llama4ForCausalLM",
LoraConfig,
{
"task_type": TaskType.CAUSAL_LM,
@ -80,6 +84,66 @@ ALL_CONFIGS = [
],
},
),
###########
# gpt-oss #
###########
# target down_proj
(
"trl-internal-testing/tiny-GptOssForCausalLM",
LoraConfig,
{
"task_type": TaskType.CAUSAL_LM,
"target_modules": [],
"lora_dropout": 0.0,
"target_parameters": [
"mlp.experts.down_proj",
],
},
),
# target gate_up_proj and down_proj, but not on the same module
(
"trl-internal-testing/tiny-GptOssForCausalLM",
LoraConfig,
{
"task_type": TaskType.CAUSAL_LM,
"target_modules": [],
"lora_dropout": 0.0,
"target_parameters": [
"0.mlp.experts.gate_up_proj",
"1.mlp.experts.down_proj",
],
},
),
# target down_proj and gate_up_proj on the same module
(
"trl-internal-testing/tiny-GptOssForCausalLM",
LoraConfig,
{
"task_type": "CAUSAL_LM",
"r": 8,
"lora_alpha": 32,
"target_modules": None,
"lora_dropout": 0.0,
"bias": "none",
"target_parameters": [
"mlp.experts.down_proj",
"mlp.experts.gate_up_proj",
],
},
),
# target q_proj, v_proj as modules, and down_proj as parameter
(
"trl-internal-testing/tiny-GptOssForCausalLM",
LoraConfig,
{
"task_type": TaskType.CAUSAL_LM,
"target_modules": ["q_proj", "v_proj"],
"lora_dropout": 0.0,
"target_parameters": [
"mlp.experts.down_proj",
],
},
),
]
@ -114,170 +178,151 @@ class TestDecoderModelsTargetParameters(PeftCommonTester):
attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_attributes_parametrized(self, model_id, config_cls, config_kwargs):
self._test_model_attr(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_adapter_name(self, model_id, config_cls, config_kwargs):
self._test_adapter_name(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_prepare_for_training_parametrized(self, model_id, config_cls, config_kwargs):
self._test_prepare_for_training(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_save_pretrained(self, model_id, config_cls, config_kwargs):
self._test_save_pretrained(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_save_pretrained_pickle(self, model_id, config_cls, config_kwargs):
self._test_save_pretrained(model_id, config_cls, config_kwargs.copy(), safe_serialization=False)
@pytest.mark.skip(reason="Multiple adapters with target_parameters are not supported yet.")
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_save_pretrained_selected_adapters(self, model_id, config_cls, config_kwargs):
self._test_save_pretrained_selected_adapters(model_id, config_cls, config_kwargs.copy())
@pytest.mark.skip(reason="Multiple adapters with target_parameters are not supported yet.")
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_save_pretrained_selected_adapters_pickle(self, model_id, config_cls, config_kwargs):
self._test_save_pretrained_selected_adapters(
model_id, config_cls, config_kwargs.copy(), safe_serialization=False
)
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_from_pretrained_config_construction(self, model_id, config_cls, config_kwargs):
self._test_from_pretrained_config_construction(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_merge_layers(self, model_id, config_cls, config_kwargs):
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_merge_layers(model_id, config_cls, config_kwargs.copy())
@pytest.mark.skip(reason="Multiple adapters with target_parameters are not supported yet.")
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_merge_layers_multi(self, model_id, config_cls, config_kwargs):
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_merge_layers_multi(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_merge_layers_nan(self, model_id, config_cls, config_kwargs):
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_merge_layers_nan(model_id, config_cls, config_kwargs.copy())
@pytest.mark.skip(reason="Multiple adapters with target_parameters are not supported yet.")
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
msg = "lora.ParamWrapper does not support mixed adapter batches yet."
with pytest.raises(ValueError, match=msg):
self._test_mixed_adapter_batches(model_id, config_cls, config_kwargs.copy())
@pytest.mark.skip(reason="Multiple adapters with target_parameters are not supported yet.")
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_generate_with_mixed_adapter_batches(self, model_id, config_cls, config_kwargs):
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
msg = "lora.ParamWrapper does not support mixed adapter batches yet."
with pytest.raises(ValueError, match=msg):
self._test_generate_with_mixed_adapter_batches_and_beam_search(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_generate(self, model_id, config_cls, config_kwargs):
self._test_generate(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_generate_pos_args(self, model_id, config_cls, config_kwargs):
self._test_generate_pos_args(model_id, config_cls, config_kwargs.copy(), raises_err=False)
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_merge_layers_fp16(self, model_id, config_cls, config_kwargs):
self._test_merge_layers_fp16(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_generate_half_prec(self, model_id, config_cls, config_kwargs):
self._test_generate_half_prec(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_training_decoders(self, model_id, config_cls, config_kwargs):
self._test_training(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_training_decoders_gradient_checkpointing(self, model_id, config_cls, config_kwargs):
self._test_training_gradient_checkpointing(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_inference_safetensors(self, model_id, config_cls, config_kwargs):
self._test_inference_safetensors(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_peft_model_device_map(self, model_id, config_cls, config_kwargs):
self._test_peft_model_device_map(model_id, config_cls, config_kwargs.copy())
@pytest.mark.skip(reason="Multiple adapters with target_parameters are not supported yet.")
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_delete_adapter(self, model_id, config_cls, config_kwargs):
self._test_delete_adapter(model_id, config_cls, config_kwargs.copy())
@pytest.mark.skip(reason="Multiple adapters with target_parameters are not supported yet.")
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_adding_multiple_adapters_with_bias_raises(self, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_unload_adapter(self, model_id, config_cls, config_kwargs):
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_unload_adapter(model_id, config_cls, config_kwargs.copy())
@pytest.mark.skip(reason="Multiple adapters with target_parameters are not supported yet.")
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
msg = "add_weighted_adapter does not support targeting nn.Parameter"
with pytest.raises(ValueError, match=msg):
self._test_weighted_combination_of_adapters(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwargs):
self._test_training_prompt_learning_tasks(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_disable_adapter(self, model_id, config_cls, config_kwargs):
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_disable_adapter(model_id, config_cls, config_kwargs.copy())
@pytest.mark.parametrize("model_id,config_cls,config_kwargs", ALL_CONFIGS)
def test_passing_input_embeds_works(self, model_id, config_cls, config_kwargs):
self._test_passing_input_embeds_works("", model_id, config_cls, config_kwargs.copy())
class TestTargetParameters:
# Tests specifically designed for target_parameters
def test_targeting_module_and_targeting_param_equivalent(self):
# Test that using LoRA with target_modules vs target_parameters yields identical results.
# note: we purposely target the gate_proj because its weight is not square (unlike q_proj, ...), this makes it
# easier to catch shape errors
torch.manual_seed(0)
@ -313,7 +358,6 @@ class TestTargetParameter:
out_lora_1 = model1(x, output_hidden_states=True).hidden_states[-1]
# sanity check: basemodel outputs should be different
atol, rtol = 1e-6, 1e-6
assert not torch.allclose(out_base, out_lora_0, atol=atol, rtol=rtol)
@ -392,3 +436,72 @@ class TestTargetParameter:
atol, rtol = 0.1, 0.1
for key in lora_weights_before.keys():
assert not torch.allclose(lora_weights_before[key], lora_weights_after[key], atol=atol, rtol=rtol)
def test_target_parameters_works_with_existing_parametrization(self):
# When a parameter is already parametrized, we want the LoRA parametrization to work with it correctly.
class MyLinear(nn.Linear):
# For testing purposes, define a linear layer with 2 parameters: weight and other_weight.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
nn.init.ones_(self.weight)
self.other_weight = nn.Parameter(torch.ones(self.weight.shape))
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.lin = MyLinear(2, 2, bias=False)
def forward(self, x):
return self.lin(x)
class MyParametrization(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x + 1
# base model
model = MyModule()
x = torch.ones((2, 2))
# sanity check: result should be 1*1 + 1*1 == 2
output_base = model(x)
assert torch.all(output_base == 2)
# add parametrization to the weight
nn.utils.parametrize.register_parametrization(model.lin, "weight", MyParametrization())
# result should be (1+1)*1 + (1+1)*1 == 4
output_parametrized = model(x)
assert torch.all(output_parametrized == 4)
# add LoRA parametrization to the weight
config = LoraConfig(r=2, lora_alpha=6, target_parameters=["lin.weight"], init_lora_weights=False)
model = get_peft_model(model, config)
# manually set LoRA weights to ones
nn.init.ones_(model.base_model.model.lin.lora_A["default"].weight)
nn.init.ones_(model.base_model.model.lin.lora_B["default"].weight)
output_lora = model(x)
# delta_weight should be: (1+1) * lora_scale = (1+1) * (alpha / rank) = 2 * (6 / 2) = 6
# result should be: (1+1+6)*1 + (1+1+6)*1 == 8 + 8 == 16
assert torch.all(output_lora == 16)
# calling twice should yield the same result
output_lora2 = model(x)
assert torch.allclose(output_lora, output_lora2)
# add another LoRA parametrization to other_weight, should have no effect on the output
config = LoraConfig(r=2, lora_alpha=6, target_parameters=["lin.other_weight"], init_lora_weights=False)
model.add_adapter("other", config)
output_other_lora = model(x)
# delta_weight should be: (1+1) * lora_scale = (1+1) * (alpha / rank) = 2 * (6 / 2) = 6
# result should be: (1+1+6)*1 + (1+1+6)*1 == 8 + 8 == 16
assert torch.all(output_other_lora == output_lora)
# after unloading, the output should be the same as before LoRA was applied
unloaded = model.unload()
output_unloaded = unloaded(x)
assert torch.all(output_unloaded == output_parametrized)

View File

@ -57,6 +57,7 @@ from peft import (
inject_adapter_in_model,
prepare_model_for_kbit_training,
)
from peft.tuners._buffer_dict import BufferDict
from peft.tuners.lora import LoraLayer
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import (
@ -822,6 +823,13 @@ class PeftCommonTester:
model.set_adapter("adapter-2")
model.eval()
# sanity check: each adapter layer with a 'default' adapter should also have 'adapter-2'
containers = (torch.nn.ModuleDict, torch.nn.ParameterDict, BufferDict)
num_default = len([m for m in model.modules() if isinstance(m, containers) and "default" in m])
num_adapter2 = len([m for m in model.modules() if isinstance(m, containers) and "adapter-2" in m])
assert num_default > 0
assert num_default == num_adapter2
with torch.inference_mode():
logits_adapter_2 = model(**dummy_input)[0]