ENH Model and layer status for auxiliary modules (#2762)

Right now, get_model_status() and get_layer_status() only report on
BaseTunerLayers, but it would be helpful if they also reported on
auxiliary modules. This PR adds that.

To facilitate this, a few attributes and methods were added to
AuxiliaryTrainingWrapper and its subclasses to make them more similar
to BaseTunerLayer (e.g. the adapter_layer_names attribute). The code
that determines the model and layer status now relies on these
attributes and methods being present.
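
To illustrate, here is a minimal sketch (not part of this commit) of the new reporting, reusing the SmallModel setup from the tests below; it assumes a PEFT version that includes this change:

import torch.nn as nn

from peft import LoraConfig, get_peft_model


class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)
        self.lin1 = nn.Linear(10, 10)


config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(SmallModel(), config)

for status in model.get_layer_status():
    # besides the LoRA layer on lin0, the list now also contains an entry with
    # name "model.lin1" and module_type "ModulesToSaveWrapper"
    print(status.name, status.module_type, status.available_adapters)

model_status = model.get_model_status()
print(model_status.num_adapter_layers)  # 2: lin0 (LoRA) + lin1 (modules_to_save)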
Benjamin Bossan, 2025-09-25 18:00:11 +02:00 (committed by GitHub)
commit 6030f9160e, parent ae671baec9
4 changed files with 214 additions and 6 deletions


@@ -40,9 +40,10 @@ from transformers.utils import PushToHubMixin
from peft.tuners.lora.variants import get_alora_offsets_for_forward, get_alora_offsets_for_generate
from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
from peft.utils import AuxiliaryTrainingWrapper
from peft.utils.constants import DUMMY_MODEL_CONFIG
from peft.utils.integrations import init_empty_weights
from peft.utils.other import create_attention_mask, set_additional_trainable_modules
from peft.utils.other import TrainableTokensWrapper, create_attention_mask, set_additional_trainable_modules
from . import __version__
from .config import PeftConfig
@@ -3047,7 +3048,11 @@ def get_layer_status(model: torch.nn.Module) -> list[TunerLayerStatus]:
layer_status: list[TunerLayerStatus] = []
for name, module in base_model.named_modules():
if not isinstance(module, BaseTunerLayer):
if not isinstance(module, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
continue
if isinstance(module, TrainableTokensWrapper):
# Skip TrainableTokensWrapper, since it wraps TrainableTokensLayer, which is the actual PEFT layer we're
# interested in.
continue
# determine if all submodules/parameters of this module require grad or not


@@ -238,6 +238,13 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
"""
# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ()
# All names of other parameters that may contain adapter-related parameters
other_param_names: tuple[str, ...] = ()
# List all merged adapters
merged_adapters: list[str] = []
def __init__(self, module_to_save, adapter_name, **kwargs):
"""Extra kwargs will be passed to `self.init_modules` and `self.update`."""
super().__init__()
@@ -255,6 +262,10 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
"""A place to initialize PyTorch modules in `__init__` before the call to `self.update()`."""
raise NotImplementedError
def _get_available_adapters(self) -> set[str]:
"""Return all adapter names that can be found on this module."""
raise NotImplementedError
def _error_message_name(self):
"""Returns a user friendly identifier for error messages, e.g. for type compatibility error messages from
`check_module()` so that the user can backtrack where the error comes from. A generic "training wrapper" is
@@ -492,6 +503,9 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
"""Wraps a module that is supposed to be trained (i.e. `requires_grad_(True)`) and saved after training."""
# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ("modules_to_save",)
def __init__(self, module_to_save, adapter_name):
super().__init__(module_to_save, adapter_name)
@@ -700,6 +714,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
return new_module
def _get_available_adapters(self) -> set[str]:
"""Return all adapter names that can be found on this module."""
return set(self.modules_to_save.keys())
class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
"""Wraps a module (typically an embedding layer) that is supposed to be re-trained selectively (i.e.
@@ -709,6 +727,10 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
`TrainableTokensLayer`.
"""
# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str, ...] = ("token_adapter.trainable_tokens_delta",)
other_param_names: tuple[str, ...] = ("token_adapter.token_indices", "token_adapter.trainable_tokens_original")
def __init__(
self,
module_to_save: torch.nn.Module,
@@ -871,6 +893,10 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
self.token_adapter.merge(safe_merge=safe_merge, adapter_names=adapter_names)
return self.token_adapter.get_base_layer()
def _get_available_adapters(self) -> set[str]:
"""Return all adapter names that can be found on this module."""
return set(self.token_adapter.trainable_tokens_delta.keys())
def _get_input_embeddings_name(model, default=None):
if not hasattr(model, "get_input_embeddings"):


@@ -606,15 +606,29 @@ class TestModelAndLayerStatus:
torch_device = infer_device()
@pytest.fixture
def small_model(self):
def small_base_model_cls(self):
class SmallModel(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(10, 10)
self.lin1 = nn.Linear(10, 10)
return SmallModel
@pytest.fixture
def small_base_emb_model_cls(self):
class SmallEmbModel(nn.Module):
def __init__(self):
super().__init__()
self.lin0 = nn.Linear(10, 10)
self.emb = nn.Embedding(10, 10)
return SmallEmbModel
@pytest.fixture
def small_model(self, small_base_model_cls):
config = LoraConfig(target_modules="lin0")
return get_peft_model(SmallModel(), config)
return get_peft_model(small_base_model_cls(), config)
@pytest.fixture
def large_model(self):
@@ -801,6 +815,44 @@ class TestModelAndLayerStatus:
]
assert result == expected
def test_with_modules_to_save(self, small_base_model_cls):
# check that modules_to_save are correctly reported in layer status
model = small_base_model_cls()
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)
layer_status = model.get_layer_status()
assert len(layer_status) == 2
status = layer_status[1] # for modules_to_save
assert status.name == "model.lin1"
assert status.module_type == "ModulesToSaveWrapper"
assert status.enabled is True
assert status.active_adapters == ["default"]
assert status.merged_adapters == []
assert status.available_adapters == ["default"]
assert status.requires_grad == {"default": True}
assert status.devices == {"default": ["cpu"]}
def test_with_trainable_tokens(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in layer status
model = small_base_emb_model_cls()
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)
layer_status = model.get_layer_status()
assert len(layer_status) == 2
status = layer_status[1] # for trainable tokens
assert status.name == "model.emb.token_adapter"
assert status.module_type == "TrainableTokensLayer"
assert status.enabled is True
assert status.active_adapters == ["default"]
assert status.merged_adapters == []
assert status.available_adapters == ["default"]
assert status.requires_grad == {"default": True}
assert status.devices == {"default": ["cpu"]}
@require_non_cpu
def test_devices_all_gpu_large(self, large_model):
large_model.to(self.torch_device)
@@ -932,6 +984,32 @@ class TestModelAndLayerStatus:
model_status = large_model.get_model_status()
assert model_status.enabled == "irregular"
def test_model_enabled_irregular_with_modules_to_save(self, small_base_model_cls):
# check that modules_to_save are correctly reported in the model status
model = small_base_model_cls()
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)
# disable only lin0
model.lin0.enable_adapters(False)
model_status = model.get_model_status()
# since lin1 is still enabled, the overall model status is "irregular"
assert model_status.enabled == "irregular"
def test_model_enabled_irregular_with_trainable_tokens(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in the model status
model = small_base_emb_model_cls()
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)
# disable only lin0
model.lin0.enable_adapters(False)
model_status = model.get_model_status()
# since emb is still enabled, the overall model status is "irregular"
assert model_status.enabled == "irregular"
def test_model_active_adapters_small(self, small_model):
model_status = small_model.get_model_status()
assert model_status.active_adapters == ["default"]
@@ -958,6 +1036,34 @@ class TestModelAndLayerStatus:
model_status = large_model.get_model_status()
assert model_status.active_adapters == "irregular"
def test_model_active_adapters_with_modules_to_save_irregular(self, small_base_model_cls):
# check that modules_to_save are correctly reported in the model status
model = small_base_model_cls()
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)
model.add_adapter("other", config)
# switch modules_to_save to "other"
model.lin1.set_adapter("other")
model_status = model.get_model_status()
# since lin0 is still on "default", the overall model status is "irregular"
assert model_status.active_adapters == "irregular"
def test_model_active_adapters_with_trainable_tokens_irregular(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in the model status
model = small_base_emb_model_cls()
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)
model.add_adapter("other", config)
# switch trainable tokens to "other"
model.emb.set_adapter("other")
model_status = model.get_model_status()
# since lin0 is still on "default", the overall model status is "irregular"
assert model_status.active_adapters == "irregular"
def test_model_merged_adapters_small(self, small_model):
model_status = small_model.get_model_status()
assert model_status.merged_adapters == []
@@ -1021,6 +1127,32 @@ class TestModelAndLayerStatus:
model_status = large_model.get_model_status()
assert model_status.requires_grad == {"default": "irregular", "other": False}
def test_model_requires_irregular_with_modules_to_save(self, small_base_model_cls):
# check that modules_to_save are correctly reported in the model status
model = small_base_model_cls()
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)
# set modules_to_save to requires_grad=False
model.lin1.modules_to_save.default.weight.requires_grad = False
model_status = model.get_model_status()
# since lin0 still has requires_grad=True, the overall model status is "irregular"
assert model_status.requires_grad == {"default": "irregular"}
def test_model_requires_irregular_with_trainable_tokens(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in the model status
model = small_base_emb_model_cls()
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)
# set trainable tokens to requires_grad=False
model.emb.token_adapter.trainable_tokens_delta.default.requires_grad = False
model_status = model.get_model_status()
# since lin0 still has requires_grad=True, the overall model status is "irregular"
assert model_status.requires_grad == {"default": "irregular"}
def test_model_available_adapters_small(self, small_model):
model_status = small_model.get_model_status()
assert model_status.available_adapters == ["default"]
@@ -1075,6 +1207,50 @@ class TestModelAndLayerStatus:
assert model_status.num_adapter_layers == 2
assert model_status.trainable_params == 2 * (8 * 10 + 10 * 8)
def test_model_status_with_modules_to_save(self, small_base_model_cls):
# check that modules_to_save are correctly reported in the model status
model = small_base_model_cls()
num_base_params = sum(p.numel() for p in small_base_model_cls().parameters())
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(model, config)
model_status = model.get_model_status()
assert model_status.base_model_type == "SmallModel"
assert model_status.adapter_model_type == "LoraModel"
assert model_status.peft_types == {"default": "LORA"}
# 2 x 80 for LoRA, 100 for modules_to_save.weight, 10 for modules_to_save.bias
assert model_status.trainable_params == 2 * 80 + 100 + 10
assert model_status.total_params == 2 * 80 + 100 + 10 + num_base_params
assert model_status.num_adapter_layers == 2 # lin0 + lin1
assert model_status.enabled is True
assert model_status.active_adapters == ["default"]
assert model_status.merged_adapters == []
assert model_status.requires_grad == {"default": True}
assert model_status.available_adapters == ["default"]
assert model_status.devices == {"default": ["cpu"]} # all on CPU
def test_model_status_with_trainable_tokens(self, small_base_emb_model_cls):
# check that trainable_token_indices are correctly reported in the model status
model = small_base_emb_model_cls()
num_base_params = sum(p.numel() for p in small_base_emb_model_cls().parameters())
config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
model = get_peft_model(model, config)
model_status = model.get_model_status()
assert model_status.base_model_type == "SmallEmbModel"
assert model_status.adapter_model_type == "LoraModel"
assert model_status.peft_types == {"default": "LORA"}
# 2 x 80 for LoRA, 3 x 10 for trainable tokens
assert model_status.trainable_params == 2 * 80 + 3 * 10
assert model_status.total_params == 2 * 80 + 3 * 10 + num_base_params
assert model_status.num_adapter_layers == 2
assert model_status.enabled is True
assert model_status.active_adapters == ["default"]
assert model_status.merged_adapters == []
assert model_status.requires_grad == {"default": True}
assert model_status.available_adapters == ["default"]
assert model_status.devices == {"default": ["cpu"]} # all on CPU
def test_loha_model(self):
# ensure that this also works with non-LoRA, it's not necessary to test all tuners
class SmallModel(nn.Module):


@@ -21,6 +21,7 @@ import shutil
import tempfile
import warnings
from dataclasses import replace
from operator import attrgetter
import pytest
import torch
@@ -1453,7 +1454,7 @@ class PeftCommonTester:
target, "other_param_names", []
)
for attr in attributes_to_check:
assert adapter_to_delete not in getattr(target, attr)
assert adapter_to_delete not in attrgetter(attr)(target)
# check auxiliary modules
for module in model.modules():
@@ -1527,7 +1528,7 @@ class PeftCommonTester:
target, "other_param_names", []
)
for attr in attributes_to_check:
assert adapter_to_delete not in getattr(target, attr)
assert adapter_to_delete not in attrgetter(attr)(target)
# check auxiliary modules
for module in model.modules():
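
As a side note, a small sketch (not part of the diff) of why the test helper switches from getattr to operator.attrgetter: the new adapter_layer_names/other_param_names entries on TrainableTokensWrapper are dotted paths such as "token_adapter.trainable_tokens_delta", which attrgetter can resolve but getattr cannot. The TokenAdapter/Wrapper classes below are hypothetical stand-ins:

from operator import attrgetter

import torch
import torch.nn as nn


class TokenAdapter(nn.Module):
    # hypothetical stand-in for TrainableTokensLayer, which holds per-adapter params
    def __init__(self):
        super().__init__()
        self.trainable_tokens_delta = nn.ParameterDict(
            {"default": nn.Parameter(torch.zeros(3, 10))}
        )


class Wrapper(nn.Module):
    # hypothetical stand-in for TrainableTokensWrapper
    def __init__(self):
        super().__init__()
        self.token_adapter = TokenAdapter()


wrapper = Wrapper()
attr_name = "token_adapter.trainable_tokens_delta"
print("default" in attrgetter(attr_name)(wrapper))  # True
# getattr(wrapper, attr_name) would raise AttributeError because the name contains a dot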