CHORE: Clean up config kwargs in custom model tests (#2736)

Resolves #2695

For some PEFT methods, the way the init_weights argument was set in
test_custom_models.py was inconsistent. The default kwargs for the
tests should initialize the PEFT method as an identity transform, and
specific tests disable that where needed. Most PEFT methods are
initialized as identity transforms by default, which is why the
argument usually does not need to be set explicitly, but this is not
true for all PEFT methods.
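
To make the identity-transform convention concrete, here is a minimal
sketch of the property the default test kwargs are meant to guarantee.
The sketch is illustrative and not part of the diff: the small MLP, its
layer names, and the tensor shapes are made up, while FourierFTConfig,
get_peft_model, and the init_weights semantics come from the changes
below.

import torch
from torch import nn

from peft import FourierFTConfig, get_peft_model


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(torch.relu(self.lin0(x)))


base = MLP().eval()
x = torch.randn(4, 10)
with torch.inference_mode():
    out_base = base(x)

# init_weights=True initializes the FourierFT spectrum to zeros, so the
# freshly added adapter contributes nothing and the wrapped model still
# behaves like the base model.
config = FourierFTConfig(n_frequency=10, target_modules=["lin0"], init_weights=True)
peft_model = get_peft_model(base, config).eval()
with torch.inference_mode():
    out_peft = peft_model(x)

assert torch.allclose(out_base, out_peft)

The default test kwargs should make this assertion hold for every PEFT
method; tests that need a non-identity adapter flip the relevant switch
via the set_init_weights_false helper instead.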

With this PR, the test configs for SHiRA, C3A, and FourierFT now
follow this convention. This made it possible to remove some extra
handling of these methods that was intermingled with certain tests.

Moreover, test_custom_models.py now uses the set_init_weights_false
helper function where appropriate.
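
Concretely, the helper maps each config class to its own "disable
identity init" switch. Below is a hedged sketch of that mapping,
assuming it is run from inside the tests directory where
testing_utils.py lives; the LoRA and IA3 branches are visible in the
testing_utils.py hunk at the end of this diff, and the fall-through for
other methods mirrors the if/elif chains the helper replaces.

from peft import IA3Config, LoraConfig, ShiraConfig

from testing_utils import set_init_weights_false

base_kwargs = {"target_modules": ["lin0"]}

# LoRA-family configs use init_lora_weights, IA3 uses init_ia3_weights.
assert set_init_weights_false(LoraConfig, base_kwargs) == {
    "target_modules": ["lin0"],
    "init_lora_weights": False,
}
assert set_init_weights_false(IA3Config, base_kwargs) == {
    "target_modules": ["lin0"],
    "init_ia3_weights": False,
}

# SHiRA, LN Tuning, VBLoRA, and prompt learning configs are returned
# unchanged; most other methods get init_weights=False.
assert set_init_weights_false(ShiraConfig, base_kwargs) == base_kwargs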

While working on this, I also cleaned up the docs for the
init_weights arguments of these PEFT methods where appropriate.

I added some clarifying comments.

For test_unload_adapter, I simplified a config type check and
rewrote it to load the base model only once.
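
Schematically, the reworked _test_unload_adapter (see the
testing_common.py hunk further down) records the base model's logits
before wrapping it, replaces the long allow-list of PEFT type strings
with a single isinstance check, and asserts that unloading restores the
original parameters and outputs. The following is a condensed
paraphrase of the new test body, not the literal code:

model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
num_params_base = len(model.state_dict())
dummy_input = self.prepare_inputs_for_testing()
with torch.inference_mode():
    # base model is loaded and evaluated only once
    logits_transformers = model(**dummy_input)[0]

model = get_peft_model(model, config).to(self.torch_device)
if isinstance(config, PromptLearningConfig):
    # prompt learning does not support unloading
    with pytest.raises(AttributeError):
        model.unload()
else:
    with torch.inference_mode():
        logits_with_adapter = model(**dummy_input)[0]
    model.eval()
    model = model.unload()
    with torch.inference_mode():
        logits_unload = model(**dummy_input)[0]
    # PEFT layers are gone, parameter count and outputs match the base model
    assert len(model.state_dict()) == num_params_base
    assert not torch.allclose(logits_with_adapter, logits_unload, atol=1e-10, rtol=1e-10)
    assert torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4)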

---------

Co-authored-by: githubnemo <githubnemo@users.noreply.github.com>
Benjamin Bossan authored 2025-08-19 11:55:25 +02:00; committed by GitHub
parent 480929537f
commit b5ace6a8c4
7 changed files with 85 additions and 115 deletions

View File

@ -46,8 +46,8 @@ class C3AConfig(PeftConfig):
The mapping from layer names or regexp expression to block_size which are different from the default
specified. For example, `{"model.decoder.layers.0.encoder_attn.k_proj": 1280`}
init_weights (`Union[bool, Literal["gaussian", "kaiming_uniform", "xavier_uniform"]]`):
The initialization of the C3A weights. Set this to False if the weights should be initialized to a commonly
used distribution. Set this to True if the weights should be initialized to zeros.
Defaults to 'xavier_uniform'. Setting this to `False` also uses 'xavier_uniform'. To set the weights to
zeros (thus making C3A a no-op), set the value to `True`.
"""
block_size: int = field(
@ -116,10 +116,8 @@ class C3AConfig(PeftConfig):
default="xavier_uniform",
metadata={
"help": (
"The initialization of the C3A weights. Leave it as default or"
" set it to False if the weights should be initialized with Xavier uniform,"
" which is experimentally suitable for C3A."
" Set this to True if the weights should be initialized to zeros."
"Defaults to 'xavier_uniform'. Setting this to `False` also uses 'xavier_uniform'. To set the weights "
"to zeros (thus making C3A a no-op), set the value to `True`."
)
},
)
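
As a quick illustration of the reworded help text, a hedged sketch,
assuming C3AConfig is exported from the top-level peft namespace as the
test file further down imports it:

from peft import C3AConfig

# The default ("xavier_uniform") and False both mean Xavier-uniform
# initialization: the adapter starts as a random, non-identity transform.
random_init = C3AConfig(block_size=2, target_modules=["lin0"])

# init_weights=True zeroes the C3A kernel, making the adapter a no-op at
# initialization; the test convention described in the commit message
# relies on this.
identity_init = C3AConfig(block_size=2, target_modules=["lin0"], init_weights=True)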

View File

@ -76,7 +76,9 @@ class C3ALayer(BaseTunerLayer):
self.out_features // block_size,
self.in_features // block_size,
block_size,
dtype=torch.float32, # Currently, only fp32 is widely supported for FFT (fp16 is only supported on GPU with shapes of powers of 2, bf16 lacks FFT support)
# Currently, only fp32 is widely supported for FFT (fp16 is only supported on GPU with shapes of powers
# of 2, bf16 lacks FFT support)
dtype=torch.float32,
device=weight.device,
)
)
@ -93,7 +95,7 @@ class C3ALayer(BaseTunerLayer):
if adapter_name in self.c3a_kernel.keys():
if init_weights == "gaussian":
nn.init.normal_(self.c3a_kernel[adapter_name])
elif init_weights in ["xavier_uniform", False]: # Support test cases where False presents
elif init_weights in ["xavier_uniform", False]:
fan_in, fan_out = self.in_features, self.out_features
std = 1.0 * math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std

View File

@ -76,8 +76,8 @@ class FourierFTConfig(PeftConfig):
The mapping from layer names or regexp expression to n_frequency which are different from the default
specified. For example, `{model.decoder.layers.0.encoder_attn.k_proj: 1000`}.
init_weights (`bool`):
The initialization of the Fourier weights. Set this to False if the spectrum are initialized to a standard
normal distribution. Set this to True if the spectrum are initialized to zeros.
The initialization of the Fourier weights. Set this to False (the default) if the spectrum are initialized
to a standard normal distribution. Set this to True if the spectrum are initialized to zeros.
"""
n_frequency: int = field(
@ -178,8 +178,9 @@ class FourierFTConfig(PeftConfig):
default=False,
metadata={
"help": (
"The initialization of the Fourier weights. Set this to False if the spectrum should be initialized to a standard normal distribution."
"Set this to True if the spectrum should be initialized to zeros."
"The initialization of the Fourier weights. Set this to False (the default) if the spectrum should be "
"initialized to a standard normal distribution. Set this to True if the spectrum should be initialized "
"to zeros."
)
},
)

View File

@ -59,7 +59,7 @@ from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import AuxiliaryTrainingWrapper, infer_device
from .testing_common import PeftCommonTester
from .testing_utils import get_state_dict, require_non_cpu
from .testing_utils import get_state_dict, require_non_cpu, set_init_weights_false
# MLP is a vanilla FF network with only linear layers
@ -552,20 +552,20 @@ TEST_CASES = [
#########
# SHiRA #
#########
("Vanilla MLP 1 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": "lin0", "init_weights": False}),
("Vanilla MLP 2 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": ["lin0"], "init_weights": False}),
("Vanilla MLP 3 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": ["lin1"], "init_weights": False}),
("Vanilla MLP 1 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": "lin0"}),
("Vanilla MLP 2 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": ["lin0"]}),
("Vanilla MLP 3 SHiRA", "MLP", ShiraConfig, {"r": 1, "target_modules": ["lin1"]}),
(
"Vanilla MLP 4 SHiRA",
"MLP",
ShiraConfig,
{"r": 1, "target_modules": ["lin0", "lin1"], "random_seed": 56, "init_weights": False},
{"r": 1, "target_modules": ["lin0", "lin1"], "random_seed": 56},
),
(
"Vanilla MLP 5 SHiRA",
"MLP",
ShiraConfig,
{"r": 1, "target_modules": ["lin0"], "init_weights": False},
{"r": 1, "target_modules": ["lin0"]},
),
########
# VeRA #
@ -586,23 +586,39 @@ TEST_CASES = [
VeraConfig,
{"target_modules": ["conv1d"]},
),
########
#############
# FourierFT #
########
("Vanilla MLP 1 FourierFT", "MLP", FourierFTConfig, {"n_frequency": 10, "target_modules": "lin0"}),
("Vanilla MLP 2 FourierFT", "MLP", FourierFTConfig, {"n_frequency": 10, "target_modules": ["lin0"]}),
("Vanilla MLP 3 FourierFT", "MLP", FourierFTConfig, {"n_frequency": 10, "target_modules": ["lin1"]}),
#############
# FourierFT is not initialized as an identity transform by default, hence set init_weights=True
(
"Vanilla MLP 1 FourierFT",
"MLP",
FourierFTConfig,
{"n_frequency": 10, "target_modules": "lin0", "init_weights": True},
),
(
"Vanilla MLP 2 FourierFT",
"MLP",
FourierFTConfig,
{"n_frequency": 10, "target_modules": ["lin0"], "init_weights": True},
),
(
"Vanilla MLP 3 FourierFT",
"MLP",
FourierFTConfig,
{"n_frequency": 10, "target_modules": ["lin1"], "init_weights": True},
),
(
"Vanilla MLP 5 FourierFT",
"MLP",
FourierFTConfig,
{"n_frequency": 10, "target_modules": ["lin0"], "modules_to_save": ["lin1"]},
{"n_frequency": 10, "target_modules": ["lin0"], "modules_to_save": ["lin1"], "init_weights": True},
),
(
"Vanilla MLP 6 FourierFT",
"MLP",
FourierFTConfig,
{"n_frequency": 10, "target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"]},
{"n_frequency": 10, "target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"], "init_weights": True},
),
(
"Vanilla MLP 7 FourierFT",
@ -612,6 +628,7 @@ TEST_CASES = [
"n_frequency_pattern": {"lin0": 5, "lin1": 10},
"target_modules": ["lin0", "lin1"],
"modules_to_save": ["lin1"],
"init_weights": True,
},
),
##########
@ -676,20 +693,21 @@ TEST_CASES = [
#######
# C3A #
#######
("Vanilla MLP 1 C3A", "MLP", C3AConfig, {"block_size": 2, "target_modules": "lin0"}),
("Vanilla MLP 2 C3A", "MLP", C3AConfig, {"block_size": 2, "target_modules": ["lin0"]}),
("Vanilla MLP 3 C3A", "MLP", C3AConfig, {"block_size": 2, "target_modules": ["lin1"]}),
# note: C3A is not initialized as an identity transform by default, hence set init_weights=True
("Vanilla MLP 1 C3A", "MLP", C3AConfig, {"block_size": 2, "target_modules": "lin0", "init_weights": True}),
("Vanilla MLP 2 C3A", "MLP", C3AConfig, {"block_size": 2, "target_modules": ["lin0"], "init_weights": True}),
("Vanilla MLP 3 C3A", "MLP", C3AConfig, {"block_size": 2, "target_modules": ["lin1"], "init_weights": True}),
(
"Vanilla MLP 5 C3A",
"MLP",
C3AConfig,
{"block_size": 10, "target_modules": ["lin0"], "modules_to_save": ["lin1"]},
{"block_size": 10, "target_modules": ["lin0"], "modules_to_save": ["lin1"], "init_weights": True},
),
(
"Vanilla MLP 6 C3A",
"MLP",
C3AConfig,
{"block_size": 10, "target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"]},
{"block_size": 10, "target_modules": ["lin0", "lin1"], "modules_to_save": ["lin1"], "init_weights": True},
),
(
"Vanilla MLP 7 C3A",
@ -699,6 +717,7 @@ TEST_CASES = [
"block_size_pattern": {"lin0": 5, "lin1": 10},
"target_modules": ["lin0", "lin1"],
"modules_to_save": ["lin1"],
"init_weights": True,
},
),
]
@ -1415,19 +1434,7 @@ class TestPeftCustomModel(PeftCommonTester):
f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
)
config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
config_kwargs["init_lora_weights"] = False
elif issubclass(config_cls, IA3Config):
config_kwargs["init_ia3_weights"] = False
elif issubclass(config_cls, LNTuningConfig):
pass
elif issubclass(config_cls, VBLoRAConfig):
pass
elif issubclass(config_cls, TrainableTokensConfig):
pass
else:
config_kwargs["init_weights"] = False
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_merge_layers(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
@ -1438,27 +1445,20 @@ class TestPeftCustomModel(PeftCommonTester):
f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
)
config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
config_kwargs["init_lora_weights"] = False
elif issubclass(config_cls, IA3Config):
config_kwargs["init_ia3_weights"] = False
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_merge_layers_fp16(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_merge_layers_is_idempotent(self, test_name, model_id, config_cls, config_kwargs):
# calling merge twice with the same arguments should not change the output
# https://github.com/huggingface/peft/pull/2403
if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
pytest.skip(
f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
)
# calling merge twice with the same arguments should not change the output
config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
config_kwargs["init_lora_weights"] = False
elif issubclass(config_cls, IA3Config):
config_kwargs["init_ia3_weights"] = False
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_merge_layers_is_idempotent(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
@ -1469,20 +1469,7 @@ class TestPeftCustomModel(PeftCommonTester):
f"Skipping test for {model_id} as merging is not supported. (See https://github.com/huggingface/peft/pull/2403 for details)"
)
# calling merge twice with the same arguments should not change the output
config_kwargs = config_kwargs.copy()
if issubclass(config_cls, LoraConfig):
config_kwargs["init_lora_weights"] = False
elif issubclass(config_cls, IA3Config):
config_kwargs["init_ia3_weights"] = False
elif issubclass(config_cls, LNTuningConfig):
# LNTuning do not take init_weights
pass
elif issubclass(config_cls, VBLoRAConfig):
# VBLoRA do not take init_weights
pass
else:
config_kwargs["init_weights"] = False
config_kwargs = set_init_weights_false(config_cls, config_kwargs)
self._test_safe_merge(model_id, config_cls, config_kwargs)
@pytest.mark.parametrize("safe_merge", [False, True])
@ -1814,19 +1801,16 @@ class TestPeftCustomModel(PeftCommonTester):
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
# Test that it's possible to disable the adapter, in which case the model output should be identical to that of
# the base model.
X = self.prepare_inputs_for_testing()
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device).eval()
outputs_base = model(**X)
if issubclass(config_cls, (FourierFTConfig, TrainableTokensConfig, C3AConfig)):
if issubclass(config_cls, (TrainableTokensConfig,)):
config_kwargs = config_kwargs.copy()
# override the default value and make PEFT operation a no-op
config_kwargs["init_weights"] = True
if issubclass(config_cls, (ShiraConfig,)):
# for SHiRA, setting this to default value of True will turn the PEFT operation into a no-op
# because SHiRA is always initialized to zeros. Configs declared in the test file had set init_weights
# to False (to make sure all other tests have a randn SHiRA initialization). Setting it back to True here
# as required by this test.
config_kwargs["init_weights"] = True
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
@ -1880,6 +1864,8 @@ class TestPeftCustomModel(PeftCommonTester):
@pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
def test_disable_adapters_with_merging(self, test_name, model_id, config_cls, config_kwargs):
# Same test as test_disable_adapters, but additionally merge the trained adapter.
# https://github.com/huggingface/peft/pull/2403
if model_id in ["Conv2dGroups", "Conv2dGroups2"]:
pytest.skip(
@ -1889,9 +1875,6 @@ class TestPeftCustomModel(PeftCommonTester):
# same as test_disable_adapters, but with merging
X = self.prepare_inputs_for_testing()
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
if issubclass(config_cls, (FourierFTConfig, C3AConfig)):
config_kwargs = config_kwargs.copy()
config_kwargs["init_weights"] = True
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,

View File

@ -74,6 +74,8 @@ SMALL_GRID_MODELS = [
# TODO Missing from this list are LoKr, LoHa, LN Tuning, add them
# Note: If the PEFT method offers an initialization option to make it an identity transform (typically via the
# init_weights argument), then this option should be set here, if it's not already the default.
ALL_CONFIGS = [
(
AdaLoraConfig,

View File

@ -1557,8 +1557,11 @@ class PeftCommonTester:
def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
with hub_online_once(model_id):
model = self.transformers_class.from_pretrained(model_id)
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
num_params_base = len(model.state_dict())
dummy_input = self.prepare_inputs_for_testing()
with torch.inference_mode():
logits_transformers = model(**dummy_input)[0]
config = config_cls(
base_model_name_or_path=model_id,
@ -1567,44 +1570,26 @@ class PeftCommonTester:
model = get_peft_model(model, config)
model = model.to(self.torch_device)
if config.peft_type not in (
"LORA",
"ADALORA",
"IA3",
"BOFT",
"OFT",
"VERA",
"FOURIERFT",
"HRA",
"VBLORA",
"RANDLORA",
"SHIRA",
"BONE",
"C3A",
"MISS",
):
if isinstance(config, PromptLearningConfig):
# prompt learning does not support unloading
with pytest.raises(AttributeError):
model = model.unload()
else:
self.perturb_trainable_token_weights_if_used(model, config_kwargs)
with torch.inference_mode():
logits_with_adapter = model(**dummy_input)[0]
dummy_input = self.prepare_inputs_for_testing()
logits_with_adapter = model(**dummy_input)[0]
with hub_online_once(model_id):
transformers_model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
logits_transformers = transformers_model(**dummy_input)[0]
model.eval()
model = model.unload()
model.eval()
model = model.unload()
num_params_unloaded = len(model.state_dict())
with torch.inference_mode():
logits_unload = model(**dummy_input)[0]
num_params_unloaded = len(model.state_dict())
# check that PEFT layers are completely removed
assert not any(isinstance(module, BaseTunerLayer) for module in model.modules())
assert not torch.allclose(logits_with_adapter, logits_unload, atol=1e-10, rtol=1e-10)
assert torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4)
assert num_params_base == num_params_unloaded
# check that PEFT layers are completely removed
assert not any(isinstance(module, BaseTunerLayer) for module in model.modules())
assert not torch.allclose(logits_with_adapter, logits_unload, atol=1e-10, rtol=1e-10)
assert torch.allclose(logits_transformers, logits_unload, atol=1e-4, rtol=1e-4)
assert num_params_base == num_params_unloaded
def _test_weighted_combination_of_adapters_lora(self, model, config, adapter_list, weight_list):
model.add_adapter(adapter_list[1], config)

View File

@ -26,9 +26,9 @@ from datasets import load_dataset
from peft import (
AdaLoraConfig,
IA3Config,
LNTuningConfig,
LoraConfig,
PromptLearningConfig,
ShiraConfig,
VBLoRAConfig,
)
from peft.import_utils import (
@ -231,16 +231,15 @@ def load_cat_image():
def set_init_weights_false(config_cls, kwargs):
# helper function that sets the config kwargs such that the model is *not* initialized as an identity transform
kwargs = kwargs.copy()
if issubclass(config_cls, PromptLearningConfig):
return kwargs
if issubclass(config_cls, ShiraConfig):
return kwargs
if config_cls == VBLoRAConfig:
if config_cls in (LNTuningConfig, VBLoRAConfig):
return kwargs
if (config_cls == LoraConfig) or (config_cls == AdaLoraConfig):
if config_cls in (LoraConfig, AdaLoraConfig):
kwargs["init_lora_weights"] = False
elif config_cls == IA3Config:
kwargs["init_ia3_weights"] = False