Mirror of https://github.com/huggingface/peft.git (synced 2025-10-20 15:33:48 +08:00)
ENH Model and layer status for auxiliary modules (#2762)
Right now, get_model_status() and get_layer_status() only report on BaseTunerLayer instances, but it would be helpful if they also covered auxiliary modules. This PR includes them in the report. To facilitate this, a few attributes and methods were added to AuxiliaryTrainingWrapper and its subclasses to make them more similar to BaseTunerLayer (e.g. the adapter_layer_names attribute), since the code that determines the model and layer status assumes these attributes and methods are present.
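For illustration, a minimal usage sketch (not part of the diff) of the extended reporting; the toy SmallModel and the "default" adapter name simply mirror the test fixtures added in this PR:

import torch.nn as nn

from peft import LoraConfig, get_peft_model


class SmallModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)
        self.lin1 = nn.Linear(10, 10)


# lin0 gets a LoRA adapter, lin1 is wrapped in a ModulesToSaveWrapper (an auxiliary module)
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(SmallModel(), config)

# after this change, the layer report also contains an entry for the wrapped lin1
for status in model.get_layer_status():
    print(status.name, status.module_type, status.requires_grad)

# the aggregated model status now counts auxiliary modules as adapter layers, too
print(model.get_model_status().num_adapter_layers)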
@@ -40,9 +40,10 @@ from transformers.utils import PushToHubMixin
 from peft.tuners.lora.variants import get_alora_offsets_for_forward, get_alora_offsets_for_generate
 from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
+from peft.utils import AuxiliaryTrainingWrapper
 from peft.utils.constants import DUMMY_MODEL_CONFIG
 from peft.utils.integrations import init_empty_weights
-from peft.utils.other import create_attention_mask, set_additional_trainable_modules
+from peft.utils.other import TrainableTokensWrapper, create_attention_mask, set_additional_trainable_modules

 from . import __version__
 from .config import PeftConfig
@@ -3047,7 +3048,11 @@ def get_layer_status(model: torch.nn.Module) -> list[TunerLayerStatus]:

     layer_status: list[TunerLayerStatus] = []
     for name, module in base_model.named_modules():
-        if not isinstance(module, BaseTunerLayer):
+        if not isinstance(module, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
             continue
+        if isinstance(module, TrainableTokensWrapper):
+            # Skip TrainableTokensWrapper, since it wraps TrainableTokensLayer, which is the actual PEFT layer we're
+            # interested in.
+            continue

         # determine if all submodules/parameters if this module require grad or not
@@ -238,6 +238,13 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
     """

+    # All names of layers that may contain adapter (trainable) weights
+    adapter_layer_names: tuple[str, ...] = ()
+    # All names of other parameters that may contain adapter-related parameters
+    other_param_names: tuple[str, ...] = ()
+    # List all merged adapters
+    merged_adapters: list[str] = []
+
     def __init__(self, module_to_save, adapter_name, **kwargs):
         """Extra kwargs will be passed to `self.init_modules` and `self.update`."""
         super().__init__()
@@ -255,6 +262,10 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
         """A place to initialize PyTorch modules in `__init__` before the call to `self.update()`."""
         raise NotImplementedError

+    def _get_available_adapters(self) -> set[str]:
+        """Return all adapter names that can be found on this module."""
+        raise NotImplementedError
+
     def _error_message_name(self):
         """Returns a user friendly identifier for error messages, e.g. for type compatibility error messages from
         `check_module()` so that the user can backtrack where the error comes from. A generic "training wrapper" is
@@ -492,6 +503,9 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
 class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
     """Wraps a module that is supposed to be trained (i.e. `requires_grad_(True)`) and saved after training."""

+    # All names of layers that may contain adapter (trainable) weights
+    adapter_layer_names: tuple[str, ...] = ("modules_to_save",)
+
     def __init__(self, module_to_save, adapter_name):
         super().__init__(module_to_save, adapter_name)
@@ -700,6 +714,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):

         return new_module

+    def _get_available_adapters(self) -> set[str]:
+        """Return all adapter names that can be found on this module."""
+        return set(self.modules_to_save.keys())
+

 class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
     """Wraps a module (typically an embedding layer) that is supposed to be re-trained selectively (i.e.
@@ -709,6 +727,10 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
     `TrainableTokensLayer`.
     """

+    # All names of layers that may contain adapter (trainable) weights
+    adapter_layer_names: tuple[str, ...] = ("token_adapter.trainable_tokens_delta",)
+    other_param_names: tuple[str, ...] = ("token_adapter.token_indices", "token_adapter.trainable_tokens_original")
+
     def __init__(
         self,
         module_to_save: torch.nn.Module,
@@ -871,6 +893,10 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
             self.token_adapter.merge(safe_merge=safe_merge, adapter_names=adapter_names)
         return self.token_adapter.get_base_layer()

+    def _get_available_adapters(self) -> set[str]:
+        """Return all adapter names that can be found on this module."""
+        return set(self.token_adapter.trainable_tokens_delta.keys())
+

 def _get_input_embeddings_name(model, default=None):
     if not hasattr(model, "get_input_embeddings"):
@@ -606,15 +606,29 @@ class TestModelAndLayerStatus:
     torch_device = infer_device()

     @pytest.fixture
-    def small_model(self):
+    def small_base_model_cls(self):
         class SmallModel(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.lin0 = nn.Linear(10, 10)
                 self.lin1 = nn.Linear(10, 10)

+        return SmallModel
+
+    @pytest.fixture
+    def small_base_emb_model_cls(self):
+        class SmallEmbModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin0 = nn.Linear(10, 10)
+                self.emb = nn.Embedding(10, 10)
+
+        return SmallEmbModel
+
+    @pytest.fixture
+    def small_model(self, small_base_model_cls):
         config = LoraConfig(target_modules="lin0")
-        return get_peft_model(SmallModel(), config)
+        return get_peft_model(small_base_model_cls(), config)

     @pytest.fixture
     def large_model(self):
@@ -801,6 +815,44 @@ class TestModelAndLayerStatus:
         ]
         assert result == expected

+    def test_with_modules_to_save(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+        layer_status = model.get_layer_status()
+
+        assert len(layer_status) == 2
+        status = layer_status[1]  # for modules_to_save
+
+        assert status.name == "model.lin1"
+        assert status.module_type == "ModulesToSaveWrapper"
+        assert status.enabled is True
+        assert status.active_adapters == ["default"]
+        assert status.merged_adapters == []
+        assert status.available_adapters == ["default"]
+        assert status.requires_grad == {"default": True}
+        assert status.devices == {"default": ["cpu"]}
+
+    def test_with_trainable_tokens(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+        layer_status = model.get_layer_status()
+
+        assert len(layer_status) == 2
+        status = layer_status[1]  # for trainable tokens
+
+        assert status.name == "model.emb.token_adapter"
+        assert status.module_type == "TrainableTokensLayer"
+        assert status.enabled is True
+        assert status.active_adapters == ["default"]
+        assert status.merged_adapters == []
+        assert status.available_adapters == ["default"]
+        assert status.requires_grad == {"default": True}
+        assert status.devices == {"default": ["cpu"]}
+
     @require_non_cpu
     def test_devices_all_gpu_large(self, large_model):
         large_model.to(self.torch_device)
@@ -932,6 +984,32 @@ class TestModelAndLayerStatus:
         model_status = large_model.get_model_status()
         assert model_status.enabled == "irregular"

+    def test_model_enabled_irregular_with_modules_to_save(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+
+        # disable only lin0
+        model.lin0.enable_adapters(False)
+
+        model_status = model.get_model_status()
+        # since lin1 is still enabled, the overall model status is "irregular"
+        assert model_status.enabled == "irregular"
+
+    def test_model_enabled_irregular_with_trainable_tokens(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+
+        # disable only lin0
+        model.lin0.enable_adapters(False)
+
+        model_status = model.get_model_status()
+        # since emb is still enabled, the overall model status is "irregular"
+        assert model_status.enabled == "irregular"
+
     def test_model_active_adapters_small(self, small_model):
         model_status = small_model.get_model_status()
         assert model_status.active_adapters == ["default"]
@@ -958,6 +1036,34 @@ class TestModelAndLayerStatus:
         model_status = large_model.get_model_status()
         assert model_status.active_adapters == "irregular"

+    def test_model_active_adapters_with_modules_to_save_irregular(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+        model.add_adapter("other", config)
+
+        # switch modules_to_save to "other"
+        model.lin1.set_adapter("other")
+
+        model_status = model.get_model_status()
+        # since lin0 is still on "default", the overall model status is "irregular"
+        assert model_status.active_adapters == "irregular"
+
+    def test_model_active_adapters_with_trainable_tokens_irregular(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+        model.add_adapter("other", config)
+
+        # switch trainable tokens to "other"
+        model.emb.set_adapter("other")
+
+        model_status = model.get_model_status()
+        # since lin0 is still on "default", the overall model status is "irregular"
+        assert model_status.active_adapters == "irregular"
+
     def test_model_merged_adapters_small(self, small_model):
         model_status = small_model.get_model_status()
         assert model_status.merged_adapters == []
@@ -1021,6 +1127,32 @@ class TestModelAndLayerStatus:
         model_status = large_model.get_model_status()
         assert model_status.requires_grad == {"default": "irregular", "other": False}

+    def test_model_requires_irregular_with_modules_to_save(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+
+        # set modules_to_save to requires_grad=False
+        model.lin1.modules_to_save.default.weight.requires_grad = False
+
+        model_status = model.get_model_status()
+        # since lin1 is still requires_grad=True, the overall model status is "irregular"
+        assert model_status.requires_grad == {"default": "irregular"}
+
+    def test_model_requires_irregular_with_trainable_tokens(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+
+        # set trainable tokens to requires_grad=False
+        model.emb.token_adapter.trainable_tokens_delta.default.requires_grad = False
+
+        model_status = model.get_model_status()
+        # since emb is still requires_grad=True, the overall model status is "irregular"
+        assert model_status.requires_grad == {"default": "irregular"}
+
     def test_model_available_adapters_small(self, small_model):
         model_status = small_model.get_model_status()
         assert model_status.available_adapters == ["default"]
@@ -1075,6 +1207,50 @@ class TestModelAndLayerStatus:
         assert model_status.num_adapter_layers == 2
         assert model_status.trainable_params == 2 * (8 * 10 + 10 * 8)

+    def test_model_status_with_modules_to_save(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        num_base_params = sum(p.numel() for p in small_base_model_cls().parameters())
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+        model_status = model.get_model_status()
+
+        assert model_status.base_model_type == "SmallModel"
+        assert model_status.adapter_model_type == "LoraModel"
+        assert model_status.peft_types == {"default": "LORA"}
+        # 2 x 80 for LoRA, 100 for modules_to_save.weight, 10 for modules_to_save.bias
+        assert model_status.trainable_params == 2 * 80 + 100 + 10
+        assert model_status.total_params == 2 * 80 + 100 + 10 + num_base_params
+        assert model_status.num_adapter_layers == 2  # lin0 + lin1
+        assert model_status.enabled is True
+        assert model_status.active_adapters == ["default"]
+        assert model_status.merged_adapters == []
+        assert model_status.requires_grad == {"default": True}
+        assert model_status.available_adapters == ["default"]
+        assert model_status.devices == {"default": ["cpu"]}  # all on CPU
+
+    def test_model_status_with_trainable_tokens(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        num_base_params = sum(p.numel() for p in small_base_emb_model_cls().parameters())
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+        model_status = model.get_model_status()
+
+        assert model_status.base_model_type == "SmallEmbModel"
+        assert model_status.adapter_model_type == "LoraModel"
+        assert model_status.peft_types == {"default": "LORA"}
+        # 2 x 80 for LoRA, 3 x 10 for trainable tokens
+        assert model_status.trainable_params == 2 * 80 + 3 * 10
+        assert model_status.total_params == 2 * 80 + 3 * 10 + num_base_params
+        assert model_status.num_adapter_layers == 2
+        assert model_status.enabled is True
+        assert model_status.active_adapters == ["default"]
+        assert model_status.merged_adapters == []
+        assert model_status.requires_grad == {"default": True}
+        assert model_status.available_adapters == ["default"]
+        assert model_status.devices == {"default": ["cpu"]}  # all on CPU
+
     def test_loha_model(self):
         # ensure that this also works with non-LoRA, it's not necessary to test all tuners
         class SmallModel(nn.Module):
@@ -21,6 +21,7 @@ import shutil
 import tempfile
 import warnings
 from dataclasses import replace
+from operator import attrgetter

 import pytest
 import torch
@@ -1453,7 +1454,7 @@ class PeftCommonTester:
                     target, "other_param_names", []
                 )
                 for attr in attributes_to_check:
-                    assert adapter_to_delete not in getattr(target, attr)
+                    assert adapter_to_delete not in attrgetter(attr)(target)

             # check auxiliary modules
             for module in model.modules():
@@ -1527,7 +1528,7 @@ class PeftCommonTester:
                     target, "other_param_names", []
                 )
                 for attr in attributes_to_check:
-                    assert adapter_to_delete not in getattr(target, attr)
+                    assert adapter_to_delete not in attrgetter(attr)(target)

             # check auxiliary modules
             for module in model.modules():