Fix module target edge cases (#2773)

Resolves #2772

Fixes several edge cases with unusual layer names or target modules.

1. As described in #2772, if "weight" is part of a layer name, the corresponding keys are handled incorrectly when creating the PEFT state_dict.
2. The same happens when the adapter name itself is part of a layer name.

Some of these errors would pass silently, which is especially bad (e.g.
a weight not being loaded but no error raised).
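To make the failure mode concrete, here is a rough sketch (using a hypothetical module name, `default_proj`, that contains the adapter name) of how the old, unanchored replacement mangles keys, compared to the end-anchored replacement used by the fix below; the same kind of mangling occurred on the loading side when "weight" was part of a module name:

```python
import re

adapter_name = "default"
# hypothetical key: the module "default_proj" happens to contain the adapter name
key = "base_model.model.default_proj.lora_A.default.weight"

# old behavior: every ".default" occurrence is stripped, mangling the module name
print(key.replace(f".{adapter_name}", ""))
# -> base_model.model_proj.lora_A.weight

# fixed behavior: only the adapter name at the end of the key (just before ".weight") is removed
prefix, _, suffix = key.rpartition(".")  # "base_model.model.default_proj.lora_A.default", "weight"
pattern = re.compile(re.escape(f".{adapter_name}") + r"$")
print(f"{pattern.sub('', prefix)}.{suffix}")
# -> base_model.model.default_proj.lora_A.weight
```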

I also added some tests that were not failing before, either to cover previously untested cases or to document some basic functionality.

While working on this, I also noticed that it was possible to target a
BaseTunerLayer with modules_to_save or trainable_token_indices (e.g. the
lora_A and lora_B nn.Linear modules would be replaced with a
ModulesToSaveWrapper). I don't think this is ever desired, so we now
raise an error when this is detected.
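
A minimal sketch of the newly raised error (with a hypothetical toy model; it mirrors the test added below): if the adapter name collides with a modules_to_save entry, the key "linear.lora_A.foobar" matches, and we now fail loudly instead of wrapping the LoRA sub-layer:

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Toy(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

# LoRA is applied to "linear" first, creating e.g. "linear.lora_A.foobar".
# Because modules_to_save also matches "foobar", the inner nn.Linear of lora_A
# would previously have been wrapped in a ModulesToSaveWrapper; now a
# ValueError is raised instead.
config = LoraConfig(target_modules=["linear"], modules_to_save=["foobar"])
try:
    get_peft_model(Toy(), config, adapter_name="foobar")
except ValueError as exc:
    print(exc)
```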
Benjamin Bossan authored on 2025-09-30 11:09:44 +02:00 (committed by GitHub)
parent 046e32bf16, commit e596112b7b
4 changed files with 348 additions and 4 deletions


@@ -917,6 +917,18 @@ def _get_submodules(model, key):
return parent, target, target_name
def _get_submodules_with_grandparent(model, key):
parent = model.get_submodule(".".join(key.split(".")[:-1]))
try:
grandparent = model.get_submodule(".".join(key.split(".")[:-2]))
except AttributeError:
# no grandparent
grandparent = None
target_name = key.split(".")[-1]
target = model.get_submodule(key)
return parent, grandparent, target, target_name
def _freeze_adapter(model, adapter_name):
for n, p in model.named_parameters():
if adapter_name in n:
@@ -948,6 +960,8 @@ def _set_trainable(
The `active_adapter` flag indicates if this new adapter should be activated.
"""
from peft.tuners.tuners_utils import BaseTunerLayer
if wrapper_cls is None:
wrapper_cls = ModulesToSaveWrapper
@@ -964,7 +978,21 @@ def _set_trainable(
for key in key_list:
target_module_found = any(key.endswith(target_key) for target_key in module_names)
if target_module_found:
parent, target, target_name = _get_submodules(model, key)
parent, grandparent, target, target_name = _get_submodules_with_grandparent(model, key)
if isinstance(grandparent, BaseTunerLayer):
# This is an extreme edge case: Let's assume that there is a PEFT config with
# modules_to_save=["default"], which is the same name as the adapter name. The PEFT method's adapter
# (e.g. LoRA) is applied first. Then, when the modules_to_save matching is performed, the LoRA layer
# would be considered a valid target. Assuming that the name is "foo.bar.lora_A.default", it would
# match, with "default" being an nn.Linear and the parent, "lora_A", being an nn.ModuleDict. This by
itself is not enough to prove that this is an unintended match. Therefore, we also need to check the
# grandparent, "bar", that would be a lora.LoraLayer. When we see this, we should raise an error.
raise ValueError(
f"You are trying to target a module with {wrapper_cls} that is a child of {type(grandparent)}. "
"This is almost certainly not the intended behavior. Please ensure that the adapter name, "
f"'{adapter_name}', does not conflict with any of the targeted modules."
)
if isinstance(target, wrapper_cls):
target.update(adapter_name, **wrapper_kwargs)
target.set_adapter(target.active_adapter, inference_mode=inference_mode)


@@ -325,7 +325,31 @@ def get_peft_model_state_dict(
warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")
# REMOVE ADAPTER NAME
to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
# Ensure not to replace in the middle of the key in case a module happens to have the same name as the adapter.
pattern = re.compile(re.escape(f".{adapter_name}") + r"$")
def remove_adapter_name(key):
if "." not in key:
# nothing to do
return key
if key.endswith(f".{adapter_name}"):
# comes from an nn.Parameter, so no .weight suffix, the adapter name is directly at the end
return key.removesuffix(f".{adapter_name}")
# comes from an nn.Module, i.e. the adapter name is the 2nd to last element, e.g. v_proj.lora_A.default.weight
key, _, suffix = key.rpartition(".") # split, e.g. v_proj.lora_A.default + weight
if (config.peft_type == PeftType.VBLORA) and suffix.startswith(f"{adapter_name}_"):
# special case: VBLoRA creates keys that require this replacement:
# base_model.model.lin0.vblora_logits_A.default_topk_indices =>
# base_model.model.lin0.vblora_logits_A_topk_indices
return key + "_" + suffix.removeprefix(f"{adapter_name}_")
key = pattern.sub("", key) # remove adapter name, e.g. v_proj.lora_A
return f"{key}.{suffix}" # stitch the suffix back, e.g, v_proj.lora_A.weight
to_return = {remove_adapter_name(k): v for k, v in to_return.items()}
return to_return
@@ -364,10 +388,12 @@ def _insert_adapter_name_into_state_dict(
peft_model_state_dict = {}
for key, val in state_dict.items():
if parameter_prefix in key:
suffix = key.split(parameter_prefix)[1]
_, _, suffix = key.rpartition(parameter_prefix)
if "." in suffix:
suffix_to_replace = ".".join(suffix.split(".")[1:])
key = key.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
# only replace the substring if the key ends with it, to avoid accidental replacement inside the key if a
# module happens to have a name that contains the substring
key = re.sub(re.escape(suffix_to_replace) + r"$", f"{adapter_name}.{suffix_to_replace}", key)
else:
key = f"{key}.{adapter_name}"
peft_model_state_dict[key] = val


@@ -1101,6 +1101,42 @@ class TestLoraInitialization:
assert model.embed.scaling["default"] == expected_scaling["embed"]
assert model.conv2d.scaling["default"] == expected_scaling["conv2d"]
def test_modules_to_save_targets_lora_layer_raises(self):
# There is no good reason to have auxiliary modules to target a LoRA layer. As auxiliary modules are applied
# *after* BaseTunerLayers, a possible way for this to happen accidentally is if the
# modules_to_save/trainable_token_indices coincide with the adapter name, e.g. if the adapter name is "foobar",
# we can have a module named model.base_model.model.self_attn.lora_A.foobar. If
# modules_to_save/trainable_token_indices is also "foobar", there would be a match.
# Note: In theory, many more PEFT methods support modules_to_save and would have to be tested, but the code
# path is the same for all of them, so we only test LoRA.
model = self.get_model()
config = LoraConfig(
target_modules=["linear"],
modules_to_save=["foobar"],
)
msg = (
"You are trying to target a module with <class 'peft.utils.other.ModulesToSaveWrapper'> that is a child of "
"<class 'peft.tuners.lora.layer.Linear'>. This is almost certainly not the intended behavior. Please "
"ensure that the adapter name, 'foobar', does not conflict with any of the targeted modules."
)
with pytest.raises(ValueError, match=msg):
get_peft_model(model, config, adapter_name="foobar")
def test_trainable_token_indices_targets_lora_layer_raises(self):
# Same test as test_modules_to_save_targets_lora_layer_raises, but using trainable_token_indices
model = self.get_model()
config = LoraConfig(target_modules=["embed"], trainable_token_indices={"foobar": [1, 2, 3]})
msg = (
"You are trying to target a module with <class 'peft.utils.other.TrainableTokensWrapper'> that is a child "
"of <class 'peft.tuners.lora.layer.Embedding'>. This is almost certainly not the intended behavior. Please "
"ensure that the adapter name, 'foobar', does not conflict with any of the targeted modules."
)
with pytest.raises(ValueError, match=msg):
get_peft_model(model, config, adapter_name="foobar")
@require_deterministic_for_xpu
def test_lora_use_dora_linear(self, data):
# check that dora is a no-op when initialized
@@ -1260,6 +1296,28 @@ class TestLoraInitialization:
assert torch.allclose(merged_mask0, merged_mask1)
assert mask_type0 == mask_type1
@pytest.mark.parametrize("bias", ["none", "all", "lora_only", "invalid"])
def test_lora_with_bias_argument(self, bias):
model = self.get_model()
config = LoraConfig(target_modules=["linear", "conv2d"], bias=bias)
if bias == "invalid":
with pytest.raises(NotImplementedError):
get_peft_model(model, config)
return
model = get_peft_model(model, config) # does not raise
for name, param in model.named_parameters():
if not name.endswith("bias"):
continue
if bias == "none":
assert param.requires_grad is False
elif bias == "all":
assert param.requires_grad is True
elif bias == "lora_only":
# only layers targeted with target_modules
assert param.requires_grad is ("linear" in name) or ("conv2d" in name)
def test_lora_with_bias_extra_params(self):
# lora with lora_bias=True
model = self.get_model()


@@ -20,6 +20,7 @@ import re
import pytest
import torch
from diffusers import StableDiffusionPipeline
from torch import nn
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from peft import (
@@ -28,8 +29,10 @@ from peft import (
LoKrConfig,
LoraConfig,
RandLoraConfig,
get_peft_model,
get_peft_model_state_dict,
inject_adapter_in_model,
set_peft_model_state_dict,
)
from peft.tuners import lora
from peft.utils import ModulesToSaveWrapper
@@ -389,3 +392,232 @@ class TestInjectAdapterFromStateDict:
assert sd_before.keys() == sd_after.keys()
for key in sd_before.keys():
assert sd_before[key].shape == sd_after[key].shape
class TestPeftStateDict:
# Test some edge cases around getting and setting the PEFT state_dict. There are potential sources of errors there
# because the adapter_name is removed from/added to the state_dict keys.
def test_get_peft_model_state_dict_removes_adapter_name(self):
# ensure that the adapter name, "default", is removed from the state_dict
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
with hub_online_once(model_id):
model = AutoModelForCausalLM.from_pretrained(model_id)
# note: lora targets q_proj and v_proj; add in an auxiliary module for good measure
model = get_peft_model(model, LoraConfig(modules_to_save=["lm_head"]))
sd = get_peft_model_state_dict(model)
assert len(sd) > 1 # sanity check
assert not any("default" in key for key in sd)
def test_get_peft_model_state_dict_removes_non_default_adapter_name(self):
# ensure that the adapter name is removed from the state_dict, even if it's not "default"
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
with hub_online_once(model_id):
model = AutoModelForCausalLM.from_pretrained(model_id)
model = get_peft_model(model, LoraConfig(modules_to_save=["lm_head"]), adapter_name="other")
sd = get_peft_model_state_dict(model, adapter_name="other")
assert len(sd) > 1 # sanity check
assert not any("other" in key for key in sd)
def test_get_peft_model_state_dict_removes_adapter_name_when_same_as_module_name(self):
# here the adapter is named "v_proj", which is the same name as some modules targeted with lora in the model,
# which is nefarious
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
with hub_online_once(model_id):
model = AutoModelForCausalLM.from_pretrained(model_id)
config = LoraConfig(modules_to_save=["lm_head"], target_modules=["v_proj"])
model = get_peft_model(model, config, adapter_name="v_proj")
sd = get_peft_model_state_dict(model, adapter_name="v_proj")
assert len(sd) > 1 # sanity check
for key in sd:
# assert that the adapter_name was indeed removed
assert not key.endswith("lora_A.v_proj.weight")
assert not key.endswith("lora_B.v_proj.weight")
assert not key.endswith("modules_to_save.v_proj.weight")
# assert that the module name was not stripped completely from the key
assert ("v_proj" in key) or ("q_proj" in key) or ("lm_head") in key
def check_peft_model_weights_loaded_correctly(self, inner_model_cls, config, nested, adapter_name="default"):
# Runs checks that a roundtrip of get_peft_model_state_dict and set_peft_model_state_dict results in the same
# model (same outputs and same weights).
class Outer(nn.Module):
def __init__(self):
super().__init__()
self.inner = inner_model_cls()
def forward(self, x):
return self.inner(x)
if nested:
# add another layer of nesting
model_cls = Outer
else:
model_cls = inner_model_cls
x = torch.randn(1, 5)
torch.manual_seed(0)
base_model = model_cls()
with torch.inference_mode():
base_out = base_model(x)
torch.manual_seed(42)
model = get_peft_model(base_model, config, adapter_name=adapter_name)
with torch.inference_mode():
peft_out = model(x)
# sanity check: peft adapter has an effect
assert not torch.allclose(base_out, peft_out, atol=1e-6)
sd = get_peft_model_state_dict(model, adapter_name=adapter_name)
torch.manual_seed(0)
base_model = model_cls()
torch.manual_seed(42 + 1) # ensure we start with a different, randomly initialized PEFT model
model_new = get_peft_model(base_model, config, adapter_name=adapter_name)
with torch.inference_mode():
peft_new = model_new(x)
assert not torch.allclose(peft_out, peft_new, atol=1e-6)
set_peft_model_state_dict(model_new, sd, adapter_name=adapter_name)
with torch.inference_mode():
peft_out_loaded = model_new(x)
assert torch.allclose(peft_out, peft_out_loaded, atol=1e-6)
sd_new = get_peft_model_state_dict(model_new, adapter_name=adapter_name)
assert sd.keys() == sd_new.keys()
for key, val in sd.items():
val_new = sd_new[key]
assert torch.allclose(val, val_new)
@pytest.mark.parametrize("nested", [False, True])
def test_get_and_set_peft_model_state_dict_normal_names(self, nested):
# In this test, there is no edge case. Therefore, this test is basically the "control group" for the subsequent
# tests (if this test were to fail, it means the testing code itself is wrong).
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.foo_linear = nn.Linear(5, 5)
self.foo_baz = nn.Linear(5, 5)
self.baz_foo = nn.Linear(5, 5)
self.foo_baz_foo = nn.Linear(5, 5)
self.baz_foo_baz = nn.Linear(5, 5)
def forward(self, x):
x = self.foo_linear(x)
x = self.foo_baz(x)
x = self.baz_foo(x)
x = self.foo_baz_foo(x)
x = self.baz_foo_baz(x)
return x
config = LoraConfig(
target_modules=["foo_linear", "foo_baz", "baz_foo", "foo_baz_foo", "baz_foo_baz"], init_lora_weights=False
)
self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
@pytest.mark.parametrize("nested", [False, True])
def test_get_and_set_peft_model_state_dict_peft_prefix_in_module_name(self, nested):
# Here we have a model with some modules containing "lora" in their name.
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.foo_linear = nn.Linear(5, 5)
self.foo_lora = nn.Linear(5, 5)
self.lora_foo = nn.Linear(5, 5)
self.foo_lora_foo = nn.Linear(5, 5)
self.lora_foo_lora = nn.Linear(5, 5)
def forward(self, x):
x = self.foo_linear(x)
x = self.foo_lora(x)
x = self.lora_foo(x)
x = self.foo_lora_foo(x)
x = self.lora_foo_lora(x)
return x
config = LoraConfig(
target_modules=["foo_linear", "foo_lora", "lora_foo", "foo_lora_foo", "lora_foo_lora"],
init_lora_weights=False,
)
self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
@pytest.mark.parametrize("nested", [False, True])
def test_get_and_set_peft_model_state_dict_weight_in_module_name(self, nested):
# Here we have a model with some modules containing "weight" in their name.
# See #2772
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.foo_linear = nn.Linear(5, 5)
self.foo_weight = nn.Linear(5, 5)
self.weight_foo = nn.Linear(5, 5)
self.foo_weight_foo = nn.Linear(5, 5)
self.weight_foo_weight = nn.Linear(5, 5)
def forward(self, x):
x = self.foo_linear(x)
x = self.foo_weight(x)
x = self.weight_foo(x)
x = self.foo_weight_foo(x)
x = self.weight_foo_weight(x)
return x
config = LoraConfig(
target_modules=["foo_linear", "foo_weight", "weight_foo", "foo_weight_foo", "weight_foo_weight"],
init_lora_weights=False,
)
self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
@pytest.mark.parametrize("nested", [False, True])
def test_get_and_set_peft_model_state_dict_bias_in_module_name(self, nested):
# Here we have a model with some modules containing "bias" in their name.
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.foo_linear = nn.Linear(5, 5)
self.foo_bias = nn.Linear(5, 5)
self.bias_foo = nn.Linear(5, 5)
self.foo_bias_foo = nn.Linear(5, 5)
self.bias_foo_bias = nn.Linear(5, 5)
def forward(self, x):
x = self.foo_linear(x)
x = self.foo_bias(x)
x = self.bias_foo(x)
x = self.foo_bias_foo(x)
x = self.bias_foo_bias(x)
return x
config = LoraConfig(
target_modules=["foo_linear", "foo_bias", "bias_foo", "foo_bias_foo", "bias_foo_bias"],
init_lora_weights=False,
bias="lora_only",
)
self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
@pytest.mark.parametrize("nested", [False, True])
def test_get_and_set_peft_model_state_dict_adapter_name_same_as_module_name(self, nested):
# Here we choose a module name that is identical to the name of one of the adapters.
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.foo = nn.Linear(5, 5)
self.foo_baz = nn.Linear(5, 5)
self.baz_foo = nn.Linear(5, 5)
self.foo_baz_foo = nn.Linear(5, 5)
self.baz_foo_baz = nn.Linear(5, 5)
def forward(self, x):
x = self.foo(x)
x = self.foo_baz(x)
x = self.baz_foo(x)
x = self.foo_baz_foo(x)
x = self.baz_foo_baz(x)
return x
config = LoraConfig(
target_modules=["foo", "foo_baz", "baz_foo", "foo_baz_foo", "baz_foo_baz"], init_lora_weights=False
)
self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested, adapter_name="foo")