Mirror of https://github.com/huggingface/peft.git, synced 2025-10-20 15:33:48 +08:00
Fix module target edge cases (#2773)
Resolves #2772

Fixes several edge cases with unusual layer names or target modules:

1. As #2772 states, if "weight" is part of a layer name, that layer would be treated incorrectly when creating the PEFT state_dict.
2. Similarly, when the adapter name itself is part of a layer name.

Some of these errors would pass silently, which is especially bad (e.g. a weight not being loaded but no error raised). I also added some tests that were not failing before, to cover some previously uncovered cases and to lay out some basic functionality.

While working on this, I also noticed that it was possible to target a BaseTunerLayer with modules_to_save and trainable_token_indices (e.g. the lora_A and lora_B nn.Linear would be replaced with a ModulesToSaveWrapper). I don't think this is ever desired, so we now raise an error if this is detected.
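To illustrate the first two points, here is a minimal standalone sketch (plain Python, with a made-up module name "default_proj" and adapter name "default", not keys taken from a real model) of why an unanchored substring replacement on state_dict keys is fragile, and how anchoring the replacement to the end of the key avoids the problem:

import re

adapter_name = "default"
# Hypothetical key: the model contains a module that happens to be named "default_proj".
key = "base_model.model.block.default_proj.lora_A.default.weight"

# Old, unanchored approach: every occurrence of ".default" is stripped, so the module
# name "default_proj" is mangled along with the adapter name.
naive = key.replace(f".{adapter_name}", "")
print(naive)  # base_model.model.block_proj.lora_A.weight

# Anchored approach, as in this commit: split off the trailing "weight" first, then only
# strip the adapter name when it is the last remaining path component.
prefix, _, suffix = key.rpartition(".")
prefix = re.sub(re.escape(f".{adapter_name}") + r"$", "", prefix)
print(f"{prefix}.{suffix}")  # base_model.model.block.default_proj.lora_A.weight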
@@ -917,6 +917,18 @@ def _get_submodules(model, key):
     return parent, target, target_name


+def _get_submodules_with_grandparent(model, key):
+    parent = model.get_submodule(".".join(key.split(".")[:-1]))
+    try:
+        grandparent = model.get_submodule(".".join(key.split(".")[:-2]))
+    except AttributeError:
+        # no grand parent
+        grandparent = None
+    target_name = key.split(".")[-1]
+    target = model.get_submodule(key)
+    return parent, grandparent, target, target_name
+
+
 def _freeze_adapter(model, adapter_name):
     for n, p in model.named_parameters():
         if adapter_name in n:
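As a quick illustration of the new helper (a hedged sketch: the helper body is copied from the hunk above so it runs standalone, and the nested ModuleDict names are made up), the grandparent is simply the module two levels above the target:

from torch import nn


def _get_submodules_with_grandparent(model, key):
    # copied from the hunk above so that this sketch is self-contained
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    try:
        grandparent = model.get_submodule(".".join(key.split(".")[:-2]))
    except AttributeError:
        grandparent = None
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, grandparent, target, target_name


model = nn.ModuleDict(
    {"bar": nn.ModuleDict({"lora_A": nn.ModuleDict({"default": nn.Linear(5, 5)})})}
)
parent, grandparent, target, target_name = _get_submodules_with_grandparent(model, "bar.lora_A.default")
print(type(parent).__name__, type(grandparent).__name__, type(target).__name__, target_name)
# ModuleDict ModuleDict Linear default

In a real LoRA model, the module at "bar" would be a lora.LoraLayer (a BaseTunerLayer), which is exactly what the new check in _set_trainable below looks for.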
@@ -948,6 +960,8 @@ def _set_trainable(

    The `active_adapter` flag indicates if this new adapter should be activated.
    """
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
    if wrapper_cls is None:
        wrapper_cls = ModulesToSaveWrapper

@@ -964,7 +978,21 @@ def _set_trainable(
     for key in key_list:
         target_module_found = any(key.endswith(target_key) for target_key in module_names)
         if target_module_found:
-            parent, target, target_name = _get_submodules(model, key)
+            parent, grandparent, target, target_name = _get_submodules_with_grandparent(model, key)
+            if isinstance(grandparent, BaseTunerLayer):
+                # This is an extreme edge case: Let's assume that there is a PEFT config with
+                # modules_to_save=["default"], which is the same name as the adapter name. The PEFT method's adapter
+                # (e.g. LoRA) is applied first. Then, when the modules_to_save matching is performed, the LoRA layer
+                # would be considered a valid target. Assuming that the name is "foo.bar.lora_A.default", it would
+                # match, with "default" being an nn.Linear and the parent, "lora_A", being an nn.ModuleDict. This by
+                # itself is not enough to prove that this is an unintended match. Therefore, we also need to check the
+                # grandparent, "bar", that would be a lora.LoraLayer. When we see this, we should raise an error.
+                raise ValueError(
+                    f"You are trying to target a module with {wrapper_cls} that is a child of {type(grandparent)}. "
+                    "This is almost certainly not the intended behavior. Please ensure that the adapter name, "
+                    f"'{adapter_name}', does not conflict with any of the targeted modules."
+                )
+
             if isinstance(target, wrapper_cls):
                 target.update(adapter_name, **wrapper_kwargs)
                 target.set_adapter(target.active_adapter, inference_mode=inference_mode)
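To make the guarded edge case concrete, here is a small sketch (made-up key names) of how the endswith-based matching in _set_trainable can accidentally pick up LoRA-internal submodules when a modules_to_save entry coincides with the adapter name:

adapter_name = "foobar"
module_names = ["foobar"]  # modules_to_save entry that happens to equal the adapter name

# Hypothetical named_modules() keys after LoRA has already been injected.
key_list = [
    "model.decoder.self_attn.q_proj.base_layer",
    "model.decoder.self_attn.q_proj.lora_A.foobar",
    "model.decoder.self_attn.q_proj.lora_B.foobar",
]

matches = [key for key in key_list if any(key.endswith(target_key) for target_key in module_names)]
print(matches)
# ['model.decoder.self_attn.q_proj.lora_A.foobar', 'model.decoder.self_attn.q_proj.lora_B.foobar']
# Wrapping these inner lora_A/lora_B Linears in a ModulesToSaveWrapper would silently break the
# adapter, which is why the grandparent check above raises a ValueError instead.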
@@ -325,7 +325,31 @@ def get_peft_model_state_dict(
            warnings.warn("Could not identify embedding layer(s) because the model is not a 🤗 transformers model.")

     # REMOVE ADAPTER NAME
-    to_return = {k.replace(f".{adapter_name}", ""): v for k, v in to_return.items()}
+    # Ensure not to replace in the middle of the key because a module happens to have the same name as the adapter.
+    pattern = re.compile(re.escape(f".{adapter_name}") + r"$")
+
+    def remove_adapter_name(key):
+        if "." not in key:
+            # nothing to do
+            return key
+
+        if key.endswith(f".{adapter_name}"):
+            # comes from an nn.Parameter, so no .weight suffix, the adapter name is directly at the end
+            return key.removesuffix(f".{adapter_name}")
+
+        # comes from an nn.Module, i.e. the adapter name is the 2nd to last element, e.g. v_proj.lora_A.default.weight
+        key, _, suffix = key.rpartition(".")  # split, e.g. v_proj.lora_A.default + weight
+
+        if (config.peft_type == PeftType.VBLORA) and suffix.startswith(f"{adapter_name}_"):
+            # special case: VBLoRA creates keys that require this replacement:
+            # base_model.model.lin0.vblora_logits_A.default_topk_indices =>
+            # base_model.model.lin0.vblora_logits_A_topk_indices
+            return key + "_" + suffix.removeprefix(f"{adapter_name}_")
+
+        key = pattern.sub("", key)  # remove adapter name, e.g. v_proj.lora_A
+        return f"{key}.{suffix}"  # stitch the suffix back, e.g., v_proj.lora_A.weight
+
+    to_return = {remove_adapter_name(k): v for k, v in to_return.items()}
     return to_return

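For readers who want to see the new key normalization in isolation, here is a trimmed, standalone version of remove_adapter_name (a sketch: the VBLoRA branch is omitted and the example keys are made up):

import re

adapter_name = "default"
pattern = re.compile(re.escape(f".{adapter_name}") + r"$")


def remove_adapter_name(key):
    if "." not in key:
        return key
    if key.endswith(f".{adapter_name}"):
        # nn.Parameter-style key: the adapter name is the final component
        return key.removesuffix(f".{adapter_name}")
    # nn.Module-style key: the adapter name is the second-to-last component
    key, _, suffix = key.rpartition(".")
    return f"{pattern.sub('', key)}.{suffix}"


print(remove_adapter_name("base_model.model.v_proj.lora_A.default.weight"))
# base_model.model.v_proj.lora_A.weight
print(remove_adapter_name("base_model.model.default_proj.lora_A.default.weight"))
# base_model.model.default_proj.lora_A.weight  -- a module named "default_proj" is left intact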
@@ -364,10 +388,12 @@ def _insert_adapter_name_into_state_dict(
     peft_model_state_dict = {}
     for key, val in state_dict.items():
         if parameter_prefix in key:
-            suffix = key.split(parameter_prefix)[1]
+            _, _, suffix = key.rpartition(parameter_prefix)
             if "." in suffix:
                 suffix_to_replace = ".".join(suffix.split(".")[1:])
-                key = key.replace(suffix_to_replace, f"{adapter_name}.{suffix_to_replace}")
+                # only replace the substring if the key ends on the substring to avoid accidental replacement inside of
+                # the key if a module happens to have a name that contains the substring
+                key = re.sub(re.escape(suffix_to_replace) + r"$", f"{adapter_name}.{suffix_to_replace}", key)
             else:
                 key = f"{key}.{adapter_name}"
             peft_model_state_dict[key] = val
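The reverse direction can be sketched the same way (a standalone approximation of the loop body above for a single key; "lora_" and the example keys are illustrative):

import re

adapter_name = "default"
parameter_prefix = "lora_"


def insert_adapter_name(key):
    if parameter_prefix not in key:
        return key
    # rpartition picks the *last* occurrence, so a module named e.g. "lora_foo" does not confuse the split
    _, _, suffix = key.rpartition(parameter_prefix)
    if "." in suffix:
        suffix_to_replace = ".".join(suffix.split(".")[1:])
        # end-anchored, so nothing inside the module path is touched
        return re.sub(re.escape(suffix_to_replace) + r"$", f"{adapter_name}.{suffix_to_replace}", key)
    return f"{key}.{adapter_name}"


print(insert_adapter_name("base_model.model.v_proj.lora_A.weight"))
# base_model.model.v_proj.lora_A.default.weight
print(insert_adapter_name("base_model.model.lora_foo.lora_A.weight"))
# base_model.model.lora_foo.lora_A.default.weight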
@@ -1101,6 +1101,42 @@ class TestLoraInitialization:
         assert model.embed.scaling["default"] == expected_scaling["embed"]
         assert model.conv2d.scaling["default"] == expected_scaling["conv2d"]

+    def test_modules_to_save_targets_lora_layer_raises(self):
+        # There is no good reason to have auxiliary modules to target a LoRA layer. As auxiliary modules are applied
+        # *after* BaseTunerLayers, a possible way for this to happen accidentally is if the
+        # modules_to_save/trainable_token_indices coincide with the adapter name, e.g. if the adapter name is "foobar",
+        # we can have a module named model.base_model.model.self_attn.lora_A.foobar. If
+        # modules_to_save/trainable_token_indices is also "foobar", there would be a match.
+        # Note: Theoretically, a lot more PEFT methods support modules_to_save, so would have to be tested, but the code
+        # path is the same for all of them, so only testing LoRA.
+        model = self.get_model()
+
+        config = LoraConfig(
+            target_modules=["linear"],
+            modules_to_save=["foobar"],
+        )
+        msg = (
+            "You are trying to target a module with <class 'peft.utils.other.ModulesToSaveWrapper'> that is a child of "
+            "<class 'peft.tuners.lora.layer.Linear'>. This is almost certainly not the intended behavior. Please "
+            "ensure that the adapter name, 'foobar', does not conflict with any of the targeted modules."
+        )
+        with pytest.raises(ValueError, match=msg):
+            get_peft_model(model, config, adapter_name="foobar")
+
+    def test_trainable_token_indices_targets_lora_layer_raises(self):
+        # Same test as test_modules_to_save_targets_lora_layer_raises, but using trainable_token_indices
+        model = self.get_model()
+
+        config = LoraConfig(target_modules=["embed"], trainable_token_indices={"foobar": [1, 2, 3]})
+        msg = (
+            "You are trying to target a module with <class 'peft.utils.other.TrainableTokensWrapper'> that is a child "
+            "of <class 'peft.tuners.lora.layer.Embedding'>. This is almost certainly not the intended behavior. Please "
+            "ensure that the adapter name, 'foobar', does not conflict with any of the targeted modules."
+        )
+        with pytest.raises(ValueError, match=msg):
+            get_peft_model(model, config, adapter_name="foobar")
+
     @require_deterministic_for_xpu
     def test_lora_use_dora_linear(self, data):
         # check that dora is a no-op when initialized
@@ -1260,6 +1296,28 @@ class TestLoraInitialization:
         assert torch.allclose(merged_mask0, merged_mask1)
         assert mask_type0 == mask_type1

+    @pytest.mark.parametrize("bias", ["none", "all", "lora_only", "invalid"])
+    def test_lora_with_bias_argument(self, bias):
+        model = self.get_model()
+        config = LoraConfig(target_modules=["linear", "conv2d"], bias=bias)
+
+        if bias == "invalid":
+            with pytest.raises(NotImplementedError):
+                get_peft_model(model, config)
+            return
+
+        model = get_peft_model(model, config)  # does not raise
+        for name, param in model.named_parameters():
+            if not name.endswith("bias"):
+                continue
+            if bias == "none":
+                assert param.requires_grad is False
+            elif bias == "all":
+                assert param.requires_grad is True
+            elif bias == "lora_only":
+                # only layers targeted with target_modules
+                assert param.requires_grad is (("linear" in name) or ("conv2d" in name))
+
     def test_lora_with_bias_extra_params(self):
         # lora with lora_bias=True
         model = self.get_model()
@@ -20,6 +20,7 @@ import re
 import pytest
 import torch
 from diffusers import StableDiffusionPipeline
+from torch import nn
 from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification

 from peft import (
@@ -28,8 +29,10 @@ from peft import (
     LoKrConfig,
     LoraConfig,
     RandLoraConfig,
     get_peft_model,
+    get_peft_model_state_dict,
     inject_adapter_in_model,
+    set_peft_model_state_dict,
 )
 from peft.tuners import lora
 from peft.utils import ModulesToSaveWrapper
@@ -389,3 +392,232 @@ class TestInjectAdapterFromStateDict:
         assert sd_before.keys() == sd_after.keys()
         for key in sd_before.keys():
             assert sd_before[key].shape == sd_after[key].shape
+
+
+class TestPeftStateDict:
+    # Test some edge cases around getting and setting the PEFT state_dict. There are potential sources of errors there
+    # because the adapter_name is removed from/added to the state_dict keys.
+    def test_get_peft_model_state_dict_removes_adapter_name(self):
+        # ensure that the adapter name, "default", is removed from the state_dict
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        with hub_online_once(model_id):
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        # note: lora targets q_proj and v_proj; add in an auxiliary module for good measure
+        model = get_peft_model(model, LoraConfig(modules_to_save=["lm_head"]))
+        sd = get_peft_model_state_dict(model)
+        assert len(sd) > 1  # sanity check
+        assert not any("default" in key for key in sd)
+
+    def test_get_peft_model_state_dict_removes_non_default_adapter_name(self):
+        # ensure that the adapter name is removed from the state_dict, even if it's not "default"
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        with hub_online_once(model_id):
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        model = get_peft_model(model, LoraConfig(modules_to_save=["lm_head"]), adapter_name="other")
+        sd = get_peft_model_state_dict(model, adapter_name="other")
+        assert len(sd) > 1  # sanity check
+        assert not any("other" in key for key in sd)
+
+    def test_get_peft_model_state_dict_removes_adapter_name_when_same_as_module_name(self):
+        # here the adapter is named "v_proj", which is the same name as some modules targeted with lora in the model,
+        # which is nefarious
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        with hub_online_once(model_id):
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+
+        config = LoraConfig(modules_to_save=["lm_head"], target_modules=["v_proj"])
+        model = get_peft_model(model, config, adapter_name="v_proj")
+        sd = get_peft_model_state_dict(model, adapter_name="v_proj")
+        assert len(sd) > 1  # sanity check
+        for key in sd:
+            # assert that the adapter_name was indeed removed
+            assert not key.endswith("lora_A.v_proj.weight")
+            assert not key.endswith("lora_B.v_proj.weight")
+            assert not key.endswith("modules_to_save.v_proj.weight")
+            # assert that the module name was not stripped completely from the key
+            assert ("v_proj" in key) or ("q_proj" in key) or ("lm_head" in key)
+
+    def check_peft_model_weights_loaded_correctly(self, inner_model_cls, config, nested, adapter_name="default"):
+        # Runs checks that a roundtrip of get_peft_model_state_dict and set_peft_model_state_dict results in the same
+        # model (same outputs and same weights).
+        class Outer(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.inner = inner_model_cls()
+
+            def forward(self, x):
+                return self.inner(x)
+
+        if nested:
+            # add another layer of nesting
+            model_cls = Outer
+        else:
+            model_cls = inner_model_cls
+
+        x = torch.randn(1, 5)
+
+        torch.manual_seed(0)
+        base_model = model_cls()
+        with torch.inference_mode():
+            base_out = base_model(x)
+
+        torch.manual_seed(42)
+        model = get_peft_model(base_model, config, adapter_name=adapter_name)
+        with torch.inference_mode():
+            peft_out = model(x)
+        # sanity check: peft adapter has an effect
+        assert not torch.allclose(base_out, peft_out, atol=1e-6)
+
+        sd = get_peft_model_state_dict(model, adapter_name=adapter_name)
+
+        torch.manual_seed(0)
+        base_model = model_cls()
+        torch.manual_seed(42 + 1)  # ensure we start with a different, randomly initialized PEFT model
+        model_new = get_peft_model(base_model, config, adapter_name=adapter_name)
+        with torch.inference_mode():
+            peft_new = model_new(x)
+        assert not torch.allclose(peft_out, peft_new, atol=1e-6)
+
+        set_peft_model_state_dict(model_new, sd, adapter_name=adapter_name)
+        with torch.inference_mode():
+            peft_out_loaded = model_new(x)
+        assert torch.allclose(peft_out, peft_out_loaded, atol=1e-6)
+
+        sd_new = get_peft_model_state_dict(model, adapter_name=adapter_name)
+        assert sd.keys() == sd_new.keys()
+        for key, val in sd.items():
+            val_new = sd_new[key]
+            assert torch.allclose(val, val_new)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_normal_names(self, nested):
+        # In this test, there is no edge case. Therefore, this test is basically the "control group" for the subsequent
+        # tests (if this test were to fail, it means the testing code itself is wrong).
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo_linear = nn.Linear(5, 5)
+                self.foo_baz = nn.Linear(5, 5)
+                self.baz_foo = nn.Linear(5, 5)
+                self.foo_baz_foo = nn.Linear(5, 5)
+                self.baz_foo_baz = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo_linear(x)
+                x = self.foo_baz(x)
+                x = self.baz_foo(x)
+                x = self.foo_baz_foo(x)
+                x = self.baz_foo_baz(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo_linear", "foo_baz", "baz_foo", "foo_baz_foo", "baz_foo_baz"], init_lora_weights=False
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_peft_prefix_in_module_name(self, nested):
+        # Here we have a model with some modules containing "lora" in their name.
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo_linear = nn.Linear(5, 5)
+                self.foo_lora = nn.Linear(5, 5)
+                self.lora_foo = nn.Linear(5, 5)
+                self.foo_lora_foo = nn.Linear(5, 5)
+                self.lora_foo_lora = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo_linear(x)
+                x = self.foo_lora(x)
+                x = self.lora_foo(x)
+                x = self.foo_lora_foo(x)
+                x = self.lora_foo_lora(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo_linear", "foo_lora", "lora_foo", "foo_lora_foo", "lora_foo_lora"],
+            init_lora_weights=False,
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_weight_in_module_name(self, nested):
+        # Here we have a model with some modules containing "weight" in their name.
+        # See #2772
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo_linear = nn.Linear(5, 5)
+                self.foo_weight = nn.Linear(5, 5)
+                self.weight_foo = nn.Linear(5, 5)
+                self.foo_weight_foo = nn.Linear(5, 5)
+                self.weight_foo_weight = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo_linear(x)
+                x = self.foo_weight(x)
+                x = self.weight_foo(x)
+                x = self.foo_weight_foo(x)
+                x = self.weight_foo_weight(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo_linear", "foo_weight", "weight_foo", "foo_weight_foo", "weight_foo_weight"],
+            init_lora_weights=False,
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_bias_in_module_name(self, nested):
+        # Here we have a model with some modules containing "bias" in their name.
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo_linear = nn.Linear(5, 5)
+                self.foo_bias = nn.Linear(5, 5)
+                self.bias_foo = nn.Linear(5, 5)
+                self.foo_bias_foo = nn.Linear(5, 5)
+                self.bias_foo_bias = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo_linear(x)
+                x = self.foo_bias(x)
+                x = self.bias_foo(x)
+                x = self.foo_bias_foo(x)
+                x = self.bias_foo_bias(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo_linear", "foo_bias", "bias_foo", "foo_bias_foo", "bias_foo_bias"],
+            init_lora_weights=False,
+            bias="lora_only",
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested)
+
+    @pytest.mark.parametrize("nested", [False, True])
+    def test_get_and_set_peft_model_state_dict_adapter_name_same_as_module_name(self, nested):
+        # Here we choose a module name that is identical to the name of one of the adapters.
+        class MyModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.foo = nn.Linear(5, 5)
+                self.foo_baz = nn.Linear(5, 5)
+                self.baz_foo = nn.Linear(5, 5)
+                self.foo_baz_foo = nn.Linear(5, 5)
+                self.baz_foo_baz = nn.Linear(5, 5)
+
+            def forward(self, x):
+                x = self.foo(x)
+                x = self.foo_baz(x)
+                x = self.baz_foo(x)
+                x = self.foo_baz_foo(x)
+                x = self.baz_foo_baz(x)
+                return x
+
+        config = LoraConfig(
+            target_modules=["foo", "foo_baz", "baz_foo", "foo_baz_foo", "baz_foo_baz"], init_lora_weights=False
+        )
+        self.check_peft_model_weights_loaded_correctly(MyModel, config, nested=nested, adapter_name="foo")