Compare commits

...

4 Commits

Author SHA1 Message Date
6392935921 Add prompt tuning experiment with sample vocab (#2824)
A new initialization method was added to prompt tuning in #2815. This PR
adds an experiment config for this method to the MetaMathQA benchmark.

In local testing, this reached a test accuracy of 36%, compared to 25% with
random initialization.
2025-10-13 16:54:45 +02:00
25f97e663a ENH: Add set_requires_grad method (#2807)
This PR adds the set_requires_grad method to PEFT models (both PeftModel
and BaseTuner). As the name suggests, this is a method to set the
requires_grad attribute of the specified PEFT adapters.

For more general context, this is mostly relevant when dealing with
multiple adapters. As is, users can already set the active adapter(s)
with set_adapter, which automatically adjusts the requires_grad
attribute too, so that only the active adapters have grads enabled.
However, there are situations where the activity status and the desired
requires_grad state differ. Right now, users would need to set
requires_grad manually to deal with that, which is error prone (e.g.
forgetting modules_to_save). This PR closes this gap in the API.

As this functionality is quite general purpose, I added a
set_requires_grad function to functional.py for easier integration.

Note: The set_requires_grad method will raise an error when called with
prompt learning methods like prompt tuning. This is because these
methods are not built on the common base classes (BaseTuner and
BaseTunerLayer) where this API is added. Moreover, they only support a
single adapter at a time, so there is little need for this method in the
first place.

A side effect of not supporting prompt learning is that on the
PeftModel, set_requires_grad is free to accept more than one adapter,
which would otherwise be difficult, because prompt learning only allows
one adapter.
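For illustration, a minimal usage sketch of the new method (the base model and the adapter names are placeholders, not taken from this PR):

from peft import LoraConfig, get_peft_model

model = get_peft_model(base_model, LoraConfig(target_modules=["lin0"]), adapter_name="adapter0")
model.add_adapter("adapter1", LoraConfig(target_modules=["lin0"]))

# only adapter0 is active, but enable grads on both adapters, e.g. to train them jointly
model.set_requires_grad(adapter_names=["adapter0", "adapter1"], requires_grad=True)
# later, freeze adapter1 again without changing which adapter is active
model.set_requires_grad(adapter_names="adapter1", requires_grad=False)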
2025-10-13 16:54:16 +02:00
61a11f9180 CI Testing transformers deprecations (#2817)
Check whether PEFT triggers a transformers FutureWarning or DeprecationWarning
by converting these warnings into test failures.
2025-10-13 16:53:35 +02:00
2f9f759587 Add num_trainable_params column to gradio app (#2819)
While memory usage correlates with the number of trainable params, having this number directly
makes it easier to verify that methods use similar numbers of trainable params, and outliers
can be inspected more easily.
2025-10-13 14:36:58 +02:00
14 changed files with 279 additions and 2 deletions

View File

@ -28,6 +28,10 @@ The functions provided here can be considered "public API" of PEFT and hence are
[[autodoc]] functional.set_adapter
- all
## Set the `requires_grad` attribute of the specified adapters
[[autodoc]] functional.set_requires_grad
- all
## Load the weights of the PEFT state dict into the model
[[autodoc]] functional.set_peft_model_state_dict
- all

View File

@ -0,0 +1,17 @@
{
"auto_mapping": null,
"base_model_name_or_path": null,
"inference_mode": false,
"num_attention_heads": 24,
"num_layers": 28,
"num_transformer_submodules": 1,
"num_virtual_tokens": 200,
"peft_type": "PROMPT_TUNING",
"prompt_tuning_init": "SAMPLE_VOCAB",
"prompt_tuning_init_text": null,
"revision": null,
"task_type": "CAUSAL_LM",
"token_dim": 3072,
"tokenizer_kwargs": null,
"tokenizer_name_or_path": null
}
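For reference, a rough Python equivalent of this config, as a sketch; it assumes PromptTuningInit.SAMPLE_VOCAB is the enum value added in #2815:

from peft import PromptTuningConfig, PromptTuningInit

config = PromptTuningConfig(
    task_type="CAUSAL_LM",
    prompt_tuning_init=PromptTuningInit.SAMPLE_VOCAB,
    num_virtual_tokens=200,
    num_attention_heads=24,
    num_layers=28,
    token_dim=3072,
)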

View File

@ -0,0 +1,6 @@
{
"optimizer_kwargs": {
"lr": 1e-3
}
}

View File

@ -33,6 +33,7 @@ metric_preferences = {
"file_size": "lower",
"test_accuracy": "higher",
"train_loss": "lower",
"num_trainable_params": "lower",
}

View File

@ -51,6 +51,7 @@ def preprocess(rows, task_name: str, print_fn=print):
"total_time": run_info["total_time"],
"train_time": train_info["train_time"],
"file_size": train_info["file_size"],
"num_trainable_params": train_info["num_trainable_params"],
"test_accuracy": train_metrics["test accuracy"],
"train_loss": train_metrics["train loss"],
"train_samples": train_metrics["train samples"],
@ -103,6 +104,7 @@ def load_df(path, task_name, print_fn=print):
"train_loss": float,
"train_samples": int,
"train_total_tokens": int,
"num_trainable_params": int,
"peft_version": "string",
"peft_branch": "string",
"transformers_version": "string",
@ -131,6 +133,7 @@ def load_df(path, task_name, print_fn=print):
"accelerator_memory_max",
"accelerator_memory_reserved_99th",
"accelerator_memory_reserved_avg",
"num_trainable_params",
"file_size",
"created_at",
"task_name",
@ -138,7 +141,6 @@ def load_df(path, task_name, print_fn=print):
other_columns = [col for col in df if col not in important_columns]
df = df[important_columns + other_columns]
size_before_drop_dups = len(df)
columns = ["experiment_name", "model_id", "peft_type", "created_at"]
# we want to keep only the most recent run for each experiment
df = df.sort_values("created_at").drop_duplicates(columns, keep="last")

View File

@ -49,3 +49,8 @@ markers = [
"regression: whether to run regression suite test",
"bitsandbytes: select bitsandbytes integration tests"
]
filterwarnings = [
"error::DeprecationWarning:transformers",
"error::FutureWarning:transformers",
]
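These entries follow Python's warning-filter syntax (action::category:module). A rough runtime equivalent of what they do, as a sketch:

import warnings

# escalate transformers deprecation warnings to errors
warnings.filterwarnings("error", category=DeprecationWarning, module="transformers")
warnings.filterwarnings("error", category=FutureWarning, module="transformers")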

View File

@ -19,7 +19,7 @@ provide PEFT integrations.
"""
from peft.mapping import inject_adapter_in_model
from peft.tuners.tuners_utils import cast_adapter_dtype, delete_adapter, set_adapter
from peft.tuners.tuners_utils import cast_adapter_dtype, delete_adapter, set_adapter, set_requires_grad
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
@ -30,4 +30,5 @@ __all__ = [
"inject_adapter_in_model",
"set_adapter",
"set_peft_model_state_dict",
"set_requires_grad",
]
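For libraries that integrate PEFT without using PeftModel, a minimal sketch of the functional API (the base model and adapter name are placeholders):

from peft import LoraConfig
from peft.functional import inject_adapter_in_model, set_requires_grad

config = LoraConfig(target_modules=["lin0"])
model = inject_adapter_in_model(config, base_model, adapter_name="adapter0")
# freeze the injected adapter without touching the rest of the model
set_requires_grad(model, adapter_names="adapter0", requires_grad=False)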

View File

@ -19,6 +19,7 @@ import copy
import inspect
import os
import warnings
from collections.abc import Sequence
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
@ -1458,6 +1459,26 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
# handle auxiliary modules
_set_adapter(self, adapter_name)
def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).
Note: Not supported for prompt learning methods like prompt tuning.
Args:
adapter_names (`str` or `Sequence[str]`):
The name of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, the default) or disable (`False`) gradients.
"""
if self.active_peft_config.is_prompt_learning:
raise TypeError(
"Setting `requires_grad` is not supported for prompt learning methods like "
f"{self.active_peft_config.peft_type.value}."
)
self.base_model.set_requires_grad(adapter_names=adapter_names, requires_grad=requires_grad)
@property
def base_model_torch_dtype(self):
return getattr(self.base_model, "dtype", None)

View File

@ -20,6 +20,7 @@ import re
import textwrap
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import contextmanager, nullcontext
from typing import Any, Optional, Union, overload
@ -483,6 +484,18 @@ class BaseTuner(nn.Module, ABC):
)
self.active_adapter = new_adapter or []
def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).
Args:
adapter_names (`str` or `Sequence[str]`):
The name of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, the default) or disable (`False`) gradients.
"""
set_requires_grad(self.model, adapter_names=adapter_names, requires_grad=requires_grad)
def _check_new_adapter_config(self, config: PeftConfig) -> None:
"""
A helper method to check the config of a new adapter being added.
@ -1353,6 +1366,27 @@ class BaseTunerLayer(ABC):
)
self.set_adapter(remaining_adapters[0])
def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).
Args:
adapter_names (`str` or `Sequence[str]`):
The name of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, the default) or disable (`False`) gradients.
"""
if isinstance(adapter_names, str):
adapter_names_set = {adapter_names}
else:
adapter_names_set = set(adapter_names)
for layer_name in self.adapter_layer_names:
module_dict = getattr(self, layer_name)
for key, layer in module_dict.items():
if key in adapter_names_set:
layer.requires_grad_(requires_grad)
def _move_adapter_to_device_of_base_layer(self, adapter_name: str, device: Optional[torch.device] = None) -> None:
"""
Move the adapter of the given name to the device of the base layer.
@ -1877,3 +1911,20 @@ def cast_adapter_dtype(model: nn.Module, adapter_name: str, autocast_adapter_dty
for param in submodule[adapter_name].parameters():
if param.dtype in dtypes_to_convert_to_fp32:
param.data = param.data.to(torch.float32)
def set_requires_grad(model, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).
Args:
model (`nn.Module`):
The model whose adapters' requires_grad attribute should be set.
adapter_names (`str` or `Sequence[str]`):
The name of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, the default) or disable (`False`) gradients.
"""
for module in model.modules():
if isinstance(module, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
module.set_requires_grad(adapter_names=adapter_names, requires_grad=requires_grad)

View File

@ -21,6 +21,7 @@ import re
import warnings
from collections.abc import Sequence
from contextlib import nullcontext
from operator import attrgetter
from typing import Any, Optional, Union
import accelerate
@ -480,6 +481,28 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
"""Delete an adapter from the layer, set a new active adapter if necessary"""
raise NotImplementedError
def set_requires_grad(self, adapter_names: str | Sequence[str], requires_grad: bool = True) -> None:
"""
Enable or disable gradients on the given adapter(s).
Args:
adapter_names (`str` or `Sequence[str]`):
The name of the adapter(s) whose gradients should be enabled/disabled.
requires_grad (`bool`, *optional*):
Whether to enable (`True`, the default) or disable (`False`) gradients.
"""
if isinstance(adapter_names, str):
adapter_names_set = {adapter_names}
else:
adapter_names_set = set(adapter_names)
for layer_name in self.adapter_layer_names:
# use attrgetter, as it resolves `.` in the attribute name
module_dict = attrgetter(layer_name)(self)
for key, layer in module_dict.items():
if key in adapter_names_set:
layer.requires_grad_(requires_grad)
def adapter_state_dict(self, adapter_name):
"""Return the state dict of this module for a given adapter."""
raise NotImplementedError

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import platform
import re
@ -25,6 +26,22 @@ def pytest_addoption(parser):
def pytest_configure(config):
config.addinivalue_line("markers", "regression: mark regression tests")
# Turn transformers deprecation log messages into errors
logger = logging.getLogger("transformers")
class ErrorOnDeprecation(logging.Handler):
def emit(self, record):
msg = record.getMessage().lower()
if "deprecat" in msg or "future" in msg:
if "torch_dtype" not in msg:
# let's ignore the torch_dtype => dtype deprecation for now
raise AssertionError(f"**Transformers Deprecation**: {msg}")
# Add our handler
handler = ErrorOnDeprecation()
logger.addHandler(handler)
logger.setLevel(logging.WARNING)
def pytest_collection_modifyitems(config, items):
if config.getoption("--regression"):
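To illustrate what the handler above catches, a hypothetical logger-based deprecation message (the message text is made up for this example):

import logging

logging.getLogger("transformers").warning("`foo` is deprecated and will be removed in a future version")
# -> ErrorOnDeprecation.emit raises AssertionError("**Transformers Deprecation**: ...")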

View File

@ -2756,6 +2756,101 @@ class TestPeftCustomModel(PeftCommonTester):
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
@staticmethod
def _check_requires_grad(module, adapter_name, requires_grad):
# a bit of a clumsy way to test requires_grad on the PEFT parameters
for name in module.adapter_layer_names:
module_dict = getattr(module, name)
if adapter_name not in module_dict:
continue
attr = module_dict[adapter_name]
if isinstance(attr, nn.Module):
for param in attr.parameters():
assert param.requires_grad == requires_grad
else: # it's an nn.Parameter
assert attr.requires_grad == requires_grad
@pytest.mark.parametrize("config_cls", ALL_PEFT_CONFIG_CLASSES)
def test_set_requires_grad(self, config_cls):
# checks that the model.set_requires_grad method works as expected
if config_cls == TrainableTokensConfig:
pytest.skip(
"TrainableTokensConfig has a separate test for set_requires_grad, as it needs a different model."
)
config_kwargs = {"target_modules": ["layers.0.lin0"]}
if config_cls == IA3Config:
config_kwargs["feedforward_modules"] = []
config0 = config_cls(**config_kwargs)
model = DeepMLP(size=256) # a size that works with all adapters
model = get_peft_model(model, config0, adapter_name="adapter0").eval()
# check that it works with a single adapter
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)
# add another adapter with two target modules and with modules_to_save
config_kwargs["target_modules"] = ["layers.0.lin0", "layers.1.lin0"]
config_kwargs["modules_to_save"] = ["layers.2.lin0"]
config1 = config_cls(**config_kwargs)
model.add_adapter("adapter1", config1)
# adapter0 still has requires_grad=True, adapter1 has requires_grad=False
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[1].lin0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[2].lin0, adapter_name="adapter1", requires_grad=False)
# enable grad for adapter1; adapter0 is unaffected
model.set_requires_grad(adapter_names="adapter1")
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=True)
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter1", requires_grad=True)
self._check_requires_grad(model.base_model.model.layers[1].lin0, adapter_name="adapter1", requires_grad=True)
self._check_requires_grad(model.base_model.model.layers[2].lin0, adapter_name="adapter1", requires_grad=True)
# disable grads for both adapters
model.set_requires_grad(adapter_names=["adapter0", "adapter1"], requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter0", requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[0].lin0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.layers[1].lin0, adapter_name="adapter1", requires_grad=False)
def test_set_requires_grad_trainable_tokens(self):
# same as test_set_requires_grad for trainable tokens
class EmbModel(nn.Module):
def __init__(self):
super().__init__()
self.emb0 = nn.Embedding(10, 10)
self.emb1 = nn.Embedding(10, 10)
config_kwargs = {"target_modules": ["emb0"], "token_indices": [0, 2, 4]}
config0 = TrainableTokensConfig(**config_kwargs)
model = EmbModel()
model = get_peft_model(model, config0, adapter_name="adapter0").eval()
# check that it works with a single adapter
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=True)
# add another adapter which targets 2 embedding layers
config_kwargs["target_modules"] = ["emb0", "emb1"]
config1 = TrainableTokensConfig(**config_kwargs)
model.add_adapter("adapter1", config1)
# adapter0 still has requires_grad=True, adapter1 has requires_grad=False
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=True)
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.emb1, adapter_name="adapter1", requires_grad=False)
# enable grad for adapter1; adapter0 is unaffected
model.set_requires_grad(adapter_names="adapter1")
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=True)
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter1", requires_grad=True)
self._check_requires_grad(model.base_model.model.emb1, adapter_name="adapter1", requires_grad=True)
# disable grads for both adapters
model.set_requires_grad(adapter_names=["adapter0", "adapter1"], requires_grad=False)
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter0", requires_grad=False)
self._check_requires_grad(model.base_model.model.emb0, adapter_name="adapter1", requires_grad=False)
self._check_requires_grad(model.base_model.model.emb1, adapter_name="adapter1", requires_grad=False)
def test_weight_bias_attributes(self):
model = MLP()
config = LoraConfig(target_modules=["lin0"])

View File

@ -795,3 +795,23 @@ class TestDecoderModels(PeftCommonTester):
)
else:
assert not contains_embedding
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_set_requires_grad_prompt_learning_raises(self, config_cls, config_kwargs):
# Test that for prompt learning, calling set_requires_grad raises an error with an appropriate error message.
# Note that for non-prompt-learning methods, set_requires_grad is tested with custom models, so there is no
# specific test here.
model_id = PEFT_DECODER_MODELS_TO_TEST[0] # it's enough to test this with one model
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
if not config.is_prompt_learning:
pytest.skip("This test is only for prompt learning methods.")
with hub_online_once(model_id + config_kwargs.get("tokenizer_name_or_path", "")):
model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
model = get_peft_model(model, config)
msg = "Setting `requires_grad` is not supported for prompt learning methods like"
with pytest.raises(TypeError, match=msg):
model.set_requires_grad(adapter_names="adpater0")

View File

@ -771,6 +771,14 @@ class TestModelAndLayerStatus:
expected = [{"default": False, "other": True}, {"default": False}, {"other": True}, {"default": False}]
assert result == expected
# change requires_grad; it is now inconsistent with the active/inactive adapter status
large_model.set_requires_grad("default", requires_grad=True)
large_model.set_requires_grad("other", requires_grad=False)
layer_status = large_model.get_layer_status()
result = [status.requires_grad for status in layer_status]
expected = [{"default": True, "other": False}, {"default": True}, {"other": False}, {"default": True}]
assert result == expected
def test_requires_grad_irregular(self, large_model):
# inject an embedding layer with requires_grad=False
# this is an invalid state, but we should still test it
@ -1114,6 +1122,12 @@ class TestModelAndLayerStatus:
model_status = large_model.get_model_status()
assert model_status.requires_grad == {"default": False, "other": True}
# change requires_grad; it is now inconsistent with the active/inactive adapter status
large_model.set_requires_grad("default", requires_grad=True)
large_model.set_requires_grad("other", requires_grad=False)
model_status = large_model.get_model_status()
assert model_status.requires_grad == {"default": True, "other": False}
def test_model_requires_grad_model_irregular(self, large_model):
# inject an embedding layer with requires_grad=False
# this is an invalid state, but we should still test it