diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index ce08865a..5cb1f7e4 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -40,9 +40,10 @@ from transformers.utils import PushToHubMixin
 
 from peft.tuners.lora.variants import get_alora_offsets_for_forward, get_alora_offsets_for_generate
 from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer
+from peft.utils import AuxiliaryTrainingWrapper
 from peft.utils.constants import DUMMY_MODEL_CONFIG
 from peft.utils.integrations import init_empty_weights
-from peft.utils.other import create_attention_mask, set_additional_trainable_modules
+from peft.utils.other import TrainableTokensWrapper, create_attention_mask, set_additional_trainable_modules
 
 from . import __version__
 from .config import PeftConfig
@@ -3047,7 +3048,11 @@ def get_layer_status(model: torch.nn.Module) -> list[TunerLayerStatus]:
 
     layer_status: list[TunerLayerStatus] = []
     for name, module in base_model.named_modules():
-        if not isinstance(module, BaseTunerLayer):
+        if not isinstance(module, (BaseTunerLayer, AuxiliaryTrainingWrapper)):
+            continue
+        if isinstance(module, TrainableTokensWrapper):
+            # Skip TrainableTokensWrapper, since it wraps TrainableTokensLayer, which is the actual PEFT layer we're
+            # interested in.
             continue
 
         # determine if all submodules/parameters if this module require grad or not
diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py
index 0e9db349..29878bcc 100644
--- a/src/peft/utils/other.py
+++ b/src/peft/utils/other.py
@@ -238,6 +238,13 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
 
     """
 
+    # All names of layers that may contain adapter (trainable) weights
+    adapter_layer_names: tuple[str, ...] = ()
+    # All names of other parameters that may contain adapter-related parameters
+    other_param_names: tuple[str, ...] = ()
+    # List all merged adapters
+    merged_adapters: list[str] = []
+
     def __init__(self, module_to_save, adapter_name, **kwargs):
         """Extra kwargs will be passed to `self.init_modules` and `self.update`."""
         super().__init__()
@@ -255,6 +262,10 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
         """A place to initialize PyTorch modules in `__init__` before the call to `self.update()`."""
         raise NotImplementedError
 
+    def _get_available_adapters(self) -> set[str]:
+        """Return all adapter names that can be found on this module."""
+        raise NotImplementedError
+
     def _error_message_name(self):
         """Returns a user friendly identifier for error messages, e.g. for type compatibility error messages from
         `check_module()` so that the user can backtrack where the error comes from. A generic "training wrapper" is
@@ -492,6 +503,9 @@ class AuxiliaryTrainingWrapper(torch.nn.Module):
 
 class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
     """Wraps a module that is supposed to be trained (i.e. `requires_grad_(True)`) and saved after training."""
 
+    # All names of layers that may contain adapter (trainable) weights
+    adapter_layer_names: tuple[str, ...] = ("modules_to_save",)
+
     def __init__(self, module_to_save, adapter_name):
         super().__init__(module_to_save, adapter_name)
@@ -700,6 +714,10 @@ class ModulesToSaveWrapper(AuxiliaryTrainingWrapper):
 
         return new_module
 
+    def _get_available_adapters(self) -> set[str]:
+        """Return all adapter names that can be found on this module."""
+        return set(self.modules_to_save.keys())
+
 
 class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
     """Wraps a module (typically an embedding layer) that is supposed to be re-trained selectively (i.e.
@@ -709,6 +727,10 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
     `TrainableTokensLayer`.
     """
 
+    # All names of layers that may contain adapter (trainable) weights
+    adapter_layer_names: tuple[str, ...] = ("token_adapter.trainable_tokens_delta",)
+    other_param_names: tuple[str, ...] = ("token_adapter.token_indices", "token_adapter.trainable_tokens_original")
+
     def __init__(
         self,
         module_to_save: torch.nn.Module,
@@ -871,6 +893,10 @@ class TrainableTokensWrapper(AuxiliaryTrainingWrapper):
         self.token_adapter.merge(safe_merge=safe_merge, adapter_names=adapter_names)
         return self.token_adapter.get_base_layer()
 
+    def _get_available_adapters(self) -> set[str]:
+        """Return all adapter names that can be found on this module."""
+        return set(self.token_adapter.trainable_tokens_delta.keys())
+
 
 def _get_input_embeddings_name(model, default=None):
     if not hasattr(model, "get_input_embeddings"):
diff --git a/tests/test_tuners_utils.py b/tests/test_tuners_utils.py
index 441ecff0..de48ee02 100644
--- a/tests/test_tuners_utils.py
+++ b/tests/test_tuners_utils.py
@@ -606,15 +606,29 @@ class TestModelAndLayerStatus:
     torch_device = infer_device()
 
     @pytest.fixture
-    def small_model(self):
+    def small_base_model_cls(self):
         class SmallModel(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.lin0 = nn.Linear(10, 10)
                 self.lin1 = nn.Linear(10, 10)
 
+        return SmallModel
+
+    @pytest.fixture
+    def small_base_emb_model_cls(self):
+        class SmallEmbModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin0 = nn.Linear(10, 10)
+                self.emb = nn.Embedding(10, 10)
+
+        return SmallEmbModel
+
+    @pytest.fixture
+    def small_model(self, small_base_model_cls):
         config = LoraConfig(target_modules="lin0")
-        return get_peft_model(SmallModel(), config)
+        return get_peft_model(small_base_model_cls(), config)
 
     @pytest.fixture
     def large_model(self):
@@ -801,6 +815,44 @@ class TestModelAndLayerStatus:
         ]
         assert result == expected
 
+    def test_with_modules_to_save(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+        layer_status = model.get_layer_status()
+
+        assert len(layer_status) == 2
+        status = layer_status[1]  # for modules_to_save
+
+        assert status.name == "model.lin1"
+        assert status.module_type == "ModulesToSaveWrapper"
+        assert status.enabled is True
+        assert status.active_adapters == ["default"]
+        assert status.merged_adapters == []
+        assert status.available_adapters == ["default"]
+        assert status.requires_grad == {"default": True}
+        assert status.devices == {"default": ["cpu"]}
+
+    def test_with_trainable_tokens(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+        layer_status = model.get_layer_status()
+
+        assert len(layer_status) == 2
+        status = layer_status[1]  # for trainable tokens
+
+        assert status.name == "model.emb.token_adapter"
+        assert status.module_type == "TrainableTokensLayer"
+        assert status.enabled is True
+        assert status.active_adapters == ["default"]
+        assert status.merged_adapters == []
+        assert status.available_adapters == ["default"]
+        assert status.requires_grad == {"default": True}
+        assert status.devices == {"default": ["cpu"]}
+
     @require_non_cpu
     def test_devices_all_gpu_large(self, large_model):
         large_model.to(self.torch_device)
@@ -932,6 +984,32 @@ class TestModelAndLayerStatus:
         model_status = large_model.get_model_status()
         assert model_status.enabled == "irregular"
 
+    def test_model_enabled_irregular_with_modules_to_save(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+
+        # disable only lin0
+        model.lin0.enable_adapters(False)
+
+        model_status = model.get_model_status()
+        # since lin1 is still enabled, the overall model status is "irregular"
+        assert model_status.enabled == "irregular"
+
+    def test_model_enabled_irregular_with_trainable_tokens(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+
+        # disable only lin0
+        model.lin0.enable_adapters(False)
+
+        model_status = model.get_model_status()
+        # since emb is still enabled, the overall model status is "irregular"
+        assert model_status.enabled == "irregular"
+
     def test_model_active_adapters_small(self, small_model):
         model_status = small_model.get_model_status()
         assert model_status.active_adapters == ["default"]
@@ -958,6 +1036,34 @@ class TestModelAndLayerStatus:
         model_status = large_model.get_model_status()
         assert model_status.active_adapters == "irregular"
 
+    def test_model_active_adapters_with_modules_to_save_irregular(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+        model.add_adapter("other", config)
+
+        # switch modules_to_save to "other"
+        model.lin1.set_adapter("other")
+
+        model_status = model.get_model_status()
+        # since lin0 is still on "default", the overall model status is "irregular"
+        assert model_status.active_adapters == "irregular"
+
+    def test_model_active_adapters_with_trainable_tokens_irregular(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+        model.add_adapter("other", config)
+
+        # switch trainable tokens to "other"
+        model.emb.set_adapter("other")
+
+        model_status = model.get_model_status()
+        # since lin0 is still on "default", the overall model status is "irregular"
+        assert model_status.active_adapters == "irregular"
+
     def test_model_merged_adapters_small(self, small_model):
         model_status = small_model.get_model_status()
         assert model_status.merged_adapters == []
@@ -1021,6 +1127,32 @@ class TestModelAndLayerStatus:
         model_status = large_model.get_model_status()
         assert model_status.requires_grad == {"default": "irregular", "other": False}
 
+    def test_model_requires_irregular_with_modules_to_save(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+
+        # set modules_to_save to requires_grad=False
+        model.lin1.modules_to_save.default.weight.requires_grad = False
+
+        model_status = model.get_model_status()
+        # since lin0 is still requires_grad=True, the overall model status is "irregular"
+        assert model_status.requires_grad == {"default": "irregular"}
+
+    def test_model_requires_irregular_with_trainable_tokens(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+
+        # set trainable tokens to requires_grad=False
+        model.emb.token_adapter.trainable_tokens_delta.default.requires_grad = False
+
+        model_status = model.get_model_status()
+        # since lin0 is still requires_grad=True, the overall model status is "irregular"
+        assert model_status.requires_grad == {"default": "irregular"}
+
     def test_model_available_adapters_small(self, small_model):
         model_status = small_model.get_model_status()
         assert model_status.available_adapters == ["default"]
@@ -1075,6 +1207,50 @@ class TestModelAndLayerStatus:
         assert model_status.num_adapter_layers == 2
         assert model_status.trainable_params == 2 * (8 * 10 + 10 * 8)
 
+    def test_model_status_with_modules_to_save(self, small_base_model_cls):
+        # check that modules_to_save are correctly reported in layer status
+        model = small_base_model_cls()
+        num_base_params = sum(p.numel() for p in small_base_model_cls().parameters())
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(model, config)
+        model_status = model.get_model_status()
+
+        assert model_status.base_model_type == "SmallModel"
+        assert model_status.adapter_model_type == "LoraModel"
+        assert model_status.peft_types == {"default": "LORA"}
+        # 2 x 80 for LoRA, 100 for modules_to_save.weight, 10 for modules_to_save.bias
+        assert model_status.trainable_params == 2 * 80 + 100 + 10
+        assert model_status.total_params == 2 * 80 + 100 + 10 + num_base_params
+        assert model_status.num_adapter_layers == 2  # lin0 + lin1
+        assert model_status.enabled is True
+        assert model_status.active_adapters == ["default"]
+        assert model_status.merged_adapters == []
+        assert model_status.requires_grad == {"default": True}
+        assert model_status.available_adapters == ["default"]
+        assert model_status.devices == {"default": ["cpu"]}  # all on CPU
+
+    def test_model_status_with_trainable_tokens(self, small_base_emb_model_cls):
+        # check that trainable_token_indices are correctly reported in layer status
+        model = small_base_emb_model_cls()
+        num_base_params = sum(p.numel() for p in small_base_emb_model_cls().parameters())
+        config = LoraConfig(target_modules=["lin0"], trainable_token_indices={"emb": [0, 1, 2]})
+        model = get_peft_model(model, config)
+        model_status = model.get_model_status()
+
+        assert model_status.base_model_type == "SmallEmbModel"
+        assert model_status.adapter_model_type == "LoraModel"
+        assert model_status.peft_types == {"default": "LORA"}
+        # 2 x 80 for LoRA, 3 x 10 for trainable tokens
+        assert model_status.trainable_params == 2 * 80 + 3 * 10
+        assert model_status.total_params == 2 * 80 + 3 * 10 + num_base_params
+        assert model_status.num_adapter_layers == 2
+        assert model_status.enabled is True
+        assert model_status.active_adapters == ["default"]
+        assert model_status.merged_adapters == []
+        assert model_status.requires_grad == {"default": True}
+        assert model_status.available_adapters == ["default"]
+        assert model_status.devices == {"default": ["cpu"]}  # all on CPU
["cpu"]} # all on CPU + def test_loha_model(self): # ensure that this also works with non-LoRA, it's not necessary to test all tuners class SmallModel(nn.Module): diff --git a/tests/testing_common.py b/tests/testing_common.py index a219dff4..7cdc2c04 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -21,6 +21,7 @@ import shutil import tempfile import warnings from dataclasses import replace +from operator import attrgetter import pytest import torch @@ -1453,7 +1454,7 @@ class PeftCommonTester: target, "other_param_names", [] ) for attr in attributes_to_check: - assert adapter_to_delete not in getattr(target, attr) + assert adapter_to_delete not in attrgetter(attr)(target) # check auxiliary modules for module in model.modules(): @@ -1527,7 +1528,7 @@ class PeftCommonTester: target, "other_param_names", [] ) for attr in attributes_to_check: - assert adapter_to_delete not in getattr(target, attr) + assert adapter_to_delete not in attrgetter(attr)(target) # check auxiliary modules for module in model.modules():