From e596112b7b0a352d50552f2bcc869e02f233f2bd Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Tue, 30 Sep 2025 11:09:44 +0200
Subject: [PATCH] Fix module target edge cases (#2773)

Resolves #2772

Fixes several edge cases with unusual layer names or target modules.

1. As #2772 stated, if "weight" is part of a layer name, it would be treated
   incorrectly when creating the PEFT state_dict.
2. Similarly, when the adapter name itself is part of a layer name.

Some of these errors would pass silently, which is especially bad (e.g. a
weight not being loaded but no error raised). I also added some tests that
were not failing before, to cover some previously uncovered cases and to pin
down some basic functionality.

While working on this, I also noticed that it was possible to target a
BaseTunerLayer with modules_to_save and trainable_token_indices (e.g. the
lora_A and lora_B nn.Linear would be replaced with ModulesToSaveWrapper). I
don't think this is ever desired, so we now raise an error if this is
detected.
---
 src/peft/utils/other.py         |  30 ++++-
 src/peft/utils/save_and_load.py |  32 ++++-
 tests/test_initialization.py    |  58 ++++++++
 tests/test_low_level_api.py     | 232 ++++++++++++++++++++++++++++++++
 4 files changed, 348 insertions(+), 4 deletions(-)

diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 29878bcc..a0ccb233 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -917,6 +917,18 @@ def _get_submodules(model, key):
     return parent, target, target_name


+def _get_submodules_with_grandparent(model, key):
+    parent = model.get_submodule(".".join(key.split(".")[:-1]))
+    try:
+        grandparent = model.get_submodule(".".join(key.split(".")[:-2]))
+    except AttributeError:
+        # no grandparent
+        grandparent = None
+    target_name = key.split(".")[-1]
+    target = model.get_submodule(key)
+    return parent, grandparent, target, target_name
+
+
 def _freeze_adapter(model, adapter_name):
     for n, p in model.named_parameters():
         if adapter_name in n:
@@ -948,6 +960,8 @@ def _set_trainable(
     The `active_adapter` flag indicates if this new adapter should be activated.

     """
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
     if wrapper_cls is None:
         wrapper_cls = ModulesToSaveWrapper

@@ -964,7 +978,21 @@ def _set_trainable(
     for key in key_list:
         target_module_found = any(key.endswith(target_key) for target_key in module_names)
         if target_module_found:
-            parent, target, target_name = _get_submodules(model, key)
+            parent, grandparent, target, target_name = _get_submodules_with_grandparent(model, key)
+            if isinstance(grandparent, BaseTunerLayer):
+                # This is an extreme edge case: Let's assume that there is a PEFT config with
+                # modules_to_save=["default"], which is the same name as the adapter name. The PEFT method's adapter
+                # (e.g. LoRA) is applied first. Then, when the modules_to_save matching is performed, the LoRA layer
+                # would be considered a valid target. Assuming that the name is "foo.bar.lora_A.default", it would
+                # match, with "default" being an nn.Linear and the parent, "lora_A", being an nn.ModuleDict. This by
+                # itself is not enough to prove that this is an unintended match. Therefore, we also need to check the
+                # grandparent, "bar", which would be a lora.LoraLayer. When we see this, we should raise an error.
+                raise ValueError(
+                    f"You are trying to target a module with {wrapper_cls} that is a child of {type(grandparent)}. "
+                    "This is almost certainly not the intended behavior. Please ensure that the adapter name, "
+                    f"'{adapter_name}', does not conflict with any of the targeted modules."
+                )
+
             if isinstance(target, wrapper_cls):
                 target.update(adapter_name, **wrapper_kwargs)
                 target.set_adapter(target.active_adapter, inference_mode=inference_mode)
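
For illustration, here is a minimal sketch of the conflict that the new check in _set_trainable reports. The toy model and the adapter name "foobar" are invented for this sketch and mirror the new test added further below; assuming this patch is applied, the call is expected to raise a ValueError instead of silently wrapping the LoRA submodule.

import torch.nn as nn

from peft import LoraConfig, get_peft_model


class TinyModel(nn.Module):
    # toy model invented for this sketch
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


# The adapter is named "foobar" and modules_to_save also lists "foobar", so the matching
# would otherwise hit the LoRA child module "linear.lora_A.foobar".
config = LoraConfig(target_modules=["linear"], modules_to_save=["foobar"])
try:
    get_peft_model(TinyModel(), config, adapter_name="foobar")
except ValueError as exc:
    # with this patch applied, the name clash is reported instead of passing silently
    print(exc)
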
diff --git a/src/peft/utils/save_and_load.py b/src/peft/utils/save_and_load.py
index 43e8fe2d..778e6964 100644
--- a/src/peft/utils/save_and_load.py
+++ b/src/peft/utils/save_and_load.py
@@ -325,7 +325,31 @@ def get_peft_model_state_dict(
             warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")

     # REMOVE ADAPTER NAME
-    to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
+    # Ensure not to replace in the middle of the key because a module happens to have the same name as the adapter.
+    pattern = re.compile(re.escape(f".{adapter_name}") + r"$")
+
+    def remove_adapter_name(key):
+        if "." not in key:
+            # nothing to do
+            return key
+
+        if key.endswith(f".{adapter_name}"):
+            # comes from an nn.Parameter, so no .weight suffix, the adapter name is directly at the end
+            return key.removesuffix(f".{adapter_name}")
+
+        # comes from an nn.Module, i.e. the adapter name is the 2nd to last element, e.g. v_proj.lora_A.default.weight
+        key, _, suffix = key.rpartition(".")  # split, e.g. v_proj.lora_A.default + weight
+
+        if (config.peft_type == PeftType.VBLORA) and suffix.startswith(f"{adapter_name}_"):
+            # special case: VBLoRA creates keys that require this replacement:
+            # base_model.model.lin0.vblora_logits_A.default_topk_indices =>
+            # base_model.model.lin0.vblora_logits_A_topk_indices
+            return key + "_" + suffix.removeprefix(f"{adapter_name}_")
+
+        key = pattern.sub("", key)  # remove adapter name, e.g. v_proj.lora_A
+        return f"{key}.{suffix}"  # stitch the suffix back, e.g. v_proj.lora_A.weight
+
+    to_return = {remove_adapter_name(k): v for k, v in to_return.items()}

     return to_return

@@ -364,10 +388,12 @@ def _insert_adapter_name_into_state_dict(
     peft_model_state_dict = {}
     for key, val in state_dict.items():
         if parameter_prefix in key:
-            suffix = key.split(parameter_prefix)[1]
+            _, _, suffix = key.rpartition(parameter_prefix)
             if "." in suffix:
                 suffix_to_replace = ".".join(suffix.split(".")[1:])
-                key = key.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
+                # only replace the substring if the key ends with the substring, to avoid accidental replacement
+                # inside the key if a module happens to have a name that contains the substring
+                key = re.sub(re.escape(suffix_to_replace) + r"$", f"{adapter_name}.{suffix_to_replace}", key)
             else:
                 key = f"{key}.{adapter_name}"
             peft_model_state_dict[key] = val
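
For illustration, a simplified standalone sketch of the renaming rule enforced above. This is not the actual PEFT helper and the key and module names are invented: the adapter name is only stripped when it is the final path component, optionally followed by a parameter suffix such as .weight, and never when a module name merely contains it. The VBLoRA special case is omitted here.

import re


def strip_adapter_name(key: str, adapter_name: str = "default") -> str:
    # simplified sketch; the real helper in the patch also handles the VBLoRA special case
    if "." not in key:
        return key
    if key.endswith(f".{adapter_name}"):
        # nn.Parameter-style key: the adapter name sits at the very end
        return key.removesuffix(f".{adapter_name}")
    prefix, _, suffix = key.rpartition(".")  # e.g. "v_proj.lora_A.default" + "weight"
    return re.sub(re.escape(f".{adapter_name}") + r"$", "", prefix) + f".{suffix}"


assert strip_adapter_name("model.v_proj.lora_A.default.weight") == "model.v_proj.lora_A.weight"
# a module whose name merely contains the adapter name is left alone
assert strip_adapter_name("model.default_proj.lora_B.default.weight") == "model.default_proj.lora_B.weight"
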
diff --git a/tests/test_initialization.py b/tests/test_initialization.py
index f37e0c2c..4cbeab38 100644
--- a/tests/test_initialization.py
+++ b/tests/test_initialization.py
@@ -1101,6 +1101,42 @@ class TestLoraInitialization:
         assert model.embed.scaling["default"] == expected_scaling["embed"]
         assert model.conv2d.scaling["default"] == expected_scaling["conv2d"]

+    def test_modules_to_save_targets_lora_layer_raises(self):
+        # There is no good reason for auxiliary modules to target a LoRA layer. As auxiliary modules are applied
+        # *after* BaseTunerLayers, a possible way for this to happen accidentally is if the
+        # modules_to_save/trainable_token_indices coincide with the adapter name, e.g. if the adapter name is "foobar",
+        # we can have a module named model.base_model.model.self_attn.lora_A.foobar. If
+        # modules_to_save/trainable_token_indices is also "foobar", there would be a match.
+        # Note: Theoretically, many more PEFT methods support modules_to_save and would have to be tested, but the
+        # code path is the same for all of them, so only LoRA is tested here.
+        model = self.get_model()
+
+        config = LoraConfig(
+            target_modules=["linear"],
+            modules_to_save=["foobar"],
+        )
+        msg = (  # the class reprs in the error message are matched loosely via ".*"
+            "You are trying to target a module with .* that is a child of .*"
+            "This is almost certainly not the intended behavior. Please "
+            "ensure that the adapter name, 'foobar', does not conflict with any of the targeted modules."
+        )
+        with pytest.raises(ValueError, match=msg):
+            get_peft_model(model, config, adapter_name="foobar")
+
+    def test_trainable_token_indices_targets_lora_layer_raises(self):
+        # Same test as test_modules_to_save_targets_lora_layer_raises, but using trainable_token_indices
+        model = self.get_model()
+
+        # this time, target the embedding layer and use trainable_token_indices instead of modules_to_save
+        config = LoraConfig(target_modules=["embed"], trainable_token_indices={"foobar": [1, 2, 3]})
+        msg = (  # the class reprs in the error message are matched loosely via ".*"
+            "You are trying to target a module with .* that is a child "
+            "of .*This is almost certainly not the intended behavior. Please "
+            "ensure that the adapter name, 'foobar', does not conflict with any of the targeted modules."
+        )
+        with pytest.raises(ValueError, match=msg):
+            get_peft_model(model, config, adapter_name="foobar")
+
     @require_deterministic_for_xpu
     def test_lora_use_dora_linear(self, data):
         # check that dora is a no-op when initialized
@@ -1260,6 +1296,28 @@ class TestLoraInitialization:
         assert torch.allclose(merged_mask0, merged_mask1)
         assert mask_type0 == mask_type1

+    @pytest.mark.parametrize("bias", ["none", "all", "lora_only", "invalid"])
+    def test_lora_with_bias_argument(self, bias):
+        model = self.get_model()
+        config = LoraConfig(target_modules=["linear", "conv2d"], bias=bias)
+
+        if bias == "invalid":
+            with pytest.raises(NotImplementedError):
+                get_peft_model(model, config)
+            return
+
+        model = get_peft_model(model, config)  # does not raise
+        for name, param in model.named_parameters():
+            if not name.endswith("bias"):
+                continue
+            if bias == "none":
+                assert param.requires_grad is False
+            elif bias == "all":
+                assert param.requires_grad is True
+            elif bias == "lora_only":
+                # only layers targeted with target_modules
+                assert param.requires_grad is (("linear" in name) or ("conv2d" in name))
+
     def test_lora_with_bias_extra_params(self):
         # lora with lora_bias=True
         model = self.get_model()
diff --git a/tests/test_low_level_api.py b/tests/test_low_level_api.py
index 7f4347b0..0a097e2d 100644
--- a/tests/test_low_level_api.py
+++ b/tests/test_low_level_api.py
@@ -20,6 +20,7 @@ import re
 import pytest
 import torch
 from diffusers import StableDiffusionPipeline
+from torch import nn
 from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification

 from peft import (
@@ -28,8 +29,10 @@ from peft import (
     LoKrConfig,
     LoraConfig,
     RandLoraConfig,
+    get_peft_model,
     get_peft_model_state_dict,
     inject_adapter_in_model,
+    set_peft_model_state_dict,
 )
 from peft.tuners import lora
 from peft.utils import ModulesToSaveWrapper
@@ -389,3 +392,232 @@ class TestInjectAdapterFromStateDict:
         assert sd_before.keys() == sd_after.keys()
         for key in sd_before.keys():
             assert sd_before[key].shape == sd_after[key].shape
+
+
+class TestPeftStateDict:
+    # Test some edge cases around getting and setting the PEFT state_dict. There are potential sources of errors there
+    # because the adapter_name is removed from/added to the state_dict keys.
+    def test_get_peft_model_state_dict_removes_adapter_name(self):
+        # ensure that the adapter name, "default", is removed from the state_dict
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        with hub_online_once(model_id):
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        # note: lora targets q_proj and v_proj; add in an auxiliary module for good measure
+        model = get_peft_model(model, LoraConfig(modules_to_save=["lm_head"]))
+        sd = get_peft_model_state_dict(model)
+        assert len(sd) > 1  # sanity check
+        assert not any("default" in key for key in sd)
+
+    def test_get_peft_model_state_dict_removes_non_default_adapter_name(self):
+        # ensure that the adapter name is removed from the state_dict, even if it's not "default"
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        with hub_online_once(model_id):
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        model = get_peft_model(model, LoraConfig(modules_to_save=["lm_head"]), adapter_name="other")
+        sd = get_peft_model_state_dict(model, adapter_name="other")
+        assert len(sd) > 1  # sanity check
+        assert not any("other" in key for key in sd)
+
+    def test_get_peft_model_state_dict_removes_adapter_name_when_same_as_module_name(self):
+        # here the adapter is named "v_proj", which is the same name as some of the modules targeted with LoRA in
+        # the model, which is asking for trouble
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        with hub_online_once(model_id):
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        config = LoraConfig(modules_to_save=["lm_head"], target_modules=["v_proj"])
+        model = get_peft_model(model, config, adapter_name="v_proj")
+        sd = get_peft_model_state_dict(model, adapter_name="v_proj")
+        assert len(sd) > 1  # sanity check
+        for key in sd:
+            # assert that the adapter_name was indeed removed
+            assert not key.endswith("lora_A.v_proj.weight")
+            assert not key.endswith("lora_B.v_proj.weight")
+            assert not key.endswith("modules_to_save.v_proj.weight")
+            # assert that the module name was not stripped completely from the key
+            assert ("v_proj" in key) or ("q_proj" in key) or ("lm_head" in key)
+
+    def check_peft_model_weights_loaded_correctly(self, inner_model_cls, config, nested, adapter_name="default"):
+        # Runs checks that a round trip of get_peft_model_state_dict and set_peft_model_state_dict results in the same
+        # model (same outputs and same weights).
+        class Outer(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.inner = inner_model_cls()
+
+            def forward(self, x):
+                return self.inner(x)
+
+        if nested:
+            # add another layer of nesting
+            model_cls = Outer
+        else:
+            model_cls = inner_model_cls
+
+        x = torch.randn(1, 5)
+
+        torch.manual_seed(0)
+        base_model = model_cls()
+        with torch.inference_mode():
+            base_out = base_model(x)
+
+        torch.manual_seed(42)
+        model = get_peft_model(base_model, config, adapter_name=adapter_name)
+        with torch.inference_mode():
+            peft_out = model(x)
+        # sanity check: peft adapter has an effect
+        assert not torch.allclose(base_out, peft_out, atol=1e-6)
+
+        sd = get_peft_model_state_dict(model, adapter_name=adapter_name)
+
+        torch.manual_seed(0)
+        base_model = model_cls()
+        torch.manual_seed(42 + 1)  # ensure we start with a different, randomly initialized PEFT model
+        model_new = get_peft_model(base_model, config, adapter_name=adapter_name)
+        with torch.inference_mode():
+            peft_new = model_new(x)
+        assert not torch.allclose(peft_out, peft_new, atol=1e-6)
+
+        set_peft_model_state_dict(model_new, sd, adapter_name=adapter_name)
+        with torch.inference_mode():
+            peft_out_loaded = model_new(x)
+        assert torch.allclose(peft_out, peft_out_loaded, atol=1e-6)
+
+        sd_new = get_peft_model_state_dict(model_new, adapter_name=adapter_name)
+        assert sd.keys() == sd_new.keys()
+        for key, val in sd.items():
+            val_new = sd_new[key]
+            assert torch.allclose(val, val_new)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_normal_names(self, nested):
+        # In this test, there is no edge case. Therefore, this test is basically the "control group" for the
+        # subsequent tests (if this test were to fail, it means the testing code itself is wrong).
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo_linear = nn.Linear(5, 5)
+                self.foo_baz = nn.Linear(5, 5)
+                self.baz_foo = nn.Linear(5, 5)
+                self.foo_baz_foo = nn.Linear(5, 5)
+                self.baz_foo_baz = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo_linear(x)
+                x = self.foo_baz(x)
+                x = self.baz_foo(x)
+                x = self.foo_baz_foo(x)
+                x = self.baz_foo_baz(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo_linear", "foo_baz", "baz_foo", "foo_baz_foo", "baz_foo_baz"], init_lora_weights=False
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_peft_prefix_in_module_name(self, nested):
+        # Here we have a model with some modules containing "lora" in their name.
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo_linear = nn.Linear(5, 5)
+                self.foo_lora = nn.Linear(5, 5)
+                self.lora_foo = nn.Linear(5, 5)
+                self.foo_lora_foo = nn.Linear(5, 5)
+                self.lora_foo_lora = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo_linear(x)
+                x = self.foo_lora(x)
+                x = self.lora_foo(x)
+                x = self.foo_lora_foo(x)
+                x = self.lora_foo_lora(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo_linear", "foo_lora", "lora_foo", "foo_lora_foo", "lora_foo_lora"],
+            init_lora_weights=False,
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_weight_in_module_name(self, nested):
+        # Here we have a model with some modules containing "weight" in their name.
+        # See #2772
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo_linear = nn.Linear(5, 5)
+                self.foo_weight = nn.Linear(5, 5)
+                self.weight_foo = nn.Linear(5, 5)
+                self.foo_weight_foo = nn.Linear(5, 5)
+                self.weight_foo_weight = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo_linear(x)
+                x = self.foo_weight(x)
+                x = self.weight_foo(x)
+                x = self.foo_weight_foo(x)
+                x = self.weight_foo_weight(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo_linear", "foo_weight", "weight_foo", "foo_weight_foo", "weight_foo_weight"],
+            init_lora_weights=False,
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_bias_in_module_name(self, nested):
+        # Here we have a model with some modules containing "bias" in their name.
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo_linear = nn.Linear(5, 5)
+                self.foo_bias = nn.Linear(5, 5)
+                self.bias_foo = nn.Linear(5, 5)
+                self.foo_bias_foo = nn.Linear(5, 5)
+                self.bias_foo_bias = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo_linear(x)
+                x = self.foo_bias(x)
+                x = self.bias_foo(x)
+                x = self.foo_bias_foo(x)
+                x = self.bias_foo_bias(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo_linear", "foo_bias", "bias_foo", "foo_bias_foo", "bias_foo_bias"],
+            init_lora_weights=False,
+            bias="lora_only",
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_adapter_name_same_as_module_name(self, nested):
+        # Here we choose a module name that is identical to the name of one of the adapters.
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo = nn.Linear(5, 5)
+                self.foo_baz = nn.Linear(5, 5)
+                self.baz_foo = nn.Linear(5, 5)
+                self.foo_baz_foo = nn.Linear(5, 5)
+                self.baz_foo_baz = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo(x)
+                x = self.foo_baz(x)
+                x = self.baz_foo(x)
+                x = self.foo_baz_foo(x)
+                x = self.baz_foo_baz(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo", "foo_baz", "baz_foo", "foo_baz_foo", "baz_foo_baz"], init_lora_weights=False
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested, adapter_name="foo")
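
For illustration, a minimal end-to-end sketch of the #2772 scenario that the new tests cover, assuming this patch is applied. The model, the module name foo_weight, the expected key suffix, and the seeds are invented and mirror check_peft_model_weights_loaded_correctly above.

import torch
from torch import nn

from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict


class WeightyModel(nn.Module):
    # toy model whose module name contains "weight"; names invented for this sketch
    def __init__(self):
        super().__init__()
        self.foo_weight = nn.Linear(5, 5)

    def forward(self, x):
        return self.foo_weight(x)


config = LoraConfig(target_modules=["foo_weight"], init_lora_weights=False)

torch.manual_seed(0)
base = WeightyModel()
torch.manual_seed(42)
model = get_peft_model(base, config)
sd = get_peft_model_state_dict(model)
# the adapter name ("default") is stripped while the module name stays intact
assert any(key.endswith("foo_weight.lora_A.weight") for key in sd)

torch.manual_seed(0)
base_new = WeightyModel()  # same base weights as above
torch.manual_seed(43)      # but a different random LoRA init, overwritten by the loaded state_dict
model_new = get_peft_model(base_new, config)
set_peft_model_state_dict(model_new, sd)

x = torch.randn(1, 5)
with torch.inference_mode():
    # after loading, the two models should produce identical outputs
    assert torch.allclose(model(x), model_new(x), atol=1e-6)
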