FIX #2295: Warn when user reloads modified model (#2306)

When modifying a model with `get_peft_model` that was already modified
in the same way, even specifying a different config may not change
the trainable parameter count, e.g. when specifying target modules that
are only a subset of the previous target modules.

With this patch a warning will be issued with a hint to `.unload()`
when calling `get_peft_model` on an already modified model.
This commit is contained in:
githubnemo
2025-01-07 18:10:07 +01:00
committed by GitHub
parent d967f6394c
commit 3d2bf9a8b2
2 changed files with 65 additions and 1 deletions

View File

@ -70,7 +70,7 @@ from .tuners import (
VeraModel,
XLoraConfig,
)
from .tuners.tuners_utils import BaseTuner
from .tuners.tuners_utils import BaseTuner, BaseTunerLayer
from .utils import _prepare_prompt_learning_config
from .utils.constants import PEFT_TYPE_TO_PREFIX_MAPPING
@ -182,6 +182,15 @@ def get_peft_model(
new_name = model.__dict__.get("name_or_path", None)
peft_config.base_model_name_or_path = new_name
# Especially in notebook environments there could be a case that a user wants to experiment with different
# configuration values. However, it is likely that there won't be any changes for new configs on an already
# initialized PEFT model. The best we can do is warn the user about it.
if any(isinstance(module, BaseTunerLayer) for module in model.modules()):
warnings.warn(
"You are trying to modify a model with PEFT for a second time. If you want to reload the model with a "
"different config, make sure to call `.unload()` before."
)
if (old_name is not None) and (old_name != new_name):
warnings.warn(
f"The PEFT config's `base_model_name_or_path` was renamed from '{old_name}' to '{new_name}'. "

55
tests/test_mapping.py Normal file
View File

@ -0,0 +1,55 @@
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from peft import LoraConfig, get_peft_model
class TestGetPeftModel:
RELOAD_WARNING_EXPECTED_MATCH = r"You are trying to modify a model .*"
@pytest.fixture
def lora_config_0(self):
return LoraConfig(target_modules="0")
@pytest.fixture
def base_model(self):
return torch.nn.Sequential(torch.nn.Linear(10, 2), torch.nn.Linear(2, 10))
def test_get_peft_model_warns_when_reloading_model(self, lora_config_0, base_model):
get_peft_model(base_model, lora_config_0)
with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
get_peft_model(base_model, lora_config_0)
def test_get_peft_model_proposed_fix_in_warning_helps(self, lora_config_0, base_model, recwarn):
peft_model = get_peft_model(base_model, lora_config_0)
peft_model.unload()
get_peft_model(base_model, lora_config_0)
warning_checker = pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH)
for warning in recwarn:
if warning_checker.matches(warning):
pytest.fail("Warning raised even though model was unloaded.")
def test_get_peft_model_repeated_invocation(self, lora_config_0, base_model):
peft_model = get_peft_model(base_model, lora_config_0)
# use direct-addressing of the other layer to accomodate for the nested model
lora_config_1 = LoraConfig(target_modules="base_model.model.1")
with pytest.warns(UserWarning, match=self.RELOAD_WARNING_EXPECTED_MATCH):
get_peft_model(peft_model, lora_config_1)