From ae671baec9c58994001009f17811d82bc8718c9e Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Thu, 25 Sep 2025 17:59:45 +0200
Subject: [PATCH] FIX PEFT layers expose in_features, out_features (#2784)

Resolves #2783.

Most PEFT layers (BaseTunerLayers) expose the in_features and out_features
attributes. Therefore, other packages like diffusers may expect these
attributes to exist. However, there were a few PEFT methods where these
attributes were missing:

- LoHa
- LoKr
- LN Tuning
- Trainable Tokens

The layers of these methods now also expose the attributes.

Implementation

To avoid code duplication, I factored out the whole code block in the LoRA
layers that extracts these attributes, since LoRA has the most exhaustive list
of checks. The new utility function has the exact same functionality and can
now be used by other PEFT methods.

I updated the four PEFT methods mentioned above to use this new function, but
I did not update the PEFT methods that already handled it, as there was no
real need (they check one or two layer types at most, so there is little
duplication).
---
 src/peft/tuners/ln_tuning/layer.py        |  6 ++-
 src/peft/tuners/lora/layer.py             | 58 +--------------------
 src/peft/tuners/lycoris_utils.py          | 11 +++-
 src/peft/tuners/trainable_tokens/layer.py |  6 ++-
 src/peft/tuners/tuners_utils.py           | 63 +++++++++++++++++++++++
 tests/test_custom_models.py               | 14 +++++
 6 files changed, 99 insertions(+), 59 deletions(-)

diff --git a/src/peft/tuners/ln_tuning/layer.py b/src/peft/tuners/ln_tuning/layer.py
index f78ca99e..e29149f2 100644
--- a/src/peft/tuners/ln_tuning/layer.py
+++ b/src/peft/tuners/ln_tuning/layer.py
@@ -19,7 +19,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
 
 
 class LNTuningLayer(nn.Module, BaseTunerLayer):
@@ -37,6 +37,10 @@ class LNTuningLayer(nn.Module, BaseTunerLayer):
         self._active_adapter = adapter_name
         self.merged_adapters = []
 
+        in_features, out_features = _get_in_out_features(self.get_base_layer())
+        self.in_features = in_features
+        self.out_features = out_features
+
     def update_layer(self, layer: nn.Module, adapter_name: str, inference_mode: bool = False, **kwargs):
         self.ln_tuning_layers[adapter_name] = deepcopy(layer)
         self.set_adapter(adapter_name, inference_mode=inference_mode)
diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py
index a7d652ec..b01e87e6 100644
--- a/src/peft/tuners/lora/layer.py
+++ b/src/peft/tuners/lora/layer.py
@@ -21,12 +21,11 @@ from typing import Any, Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from packaging import version
 from torch import svd_lowrank
 from transformers.pytorch_utils import Conv1D
 
 from peft.tuners._buffer_dict import BufferDict
-from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
 from peft.utils.integrations import (
     dequantize_module_weight,
     gather_params_ctx,
@@ -124,60 +123,7 @@ class LoraLayer(BaseTunerLayer):
         self.kwargs = kwargs
 
         base_layer = self.get_base_layer()
-        if isinstance(base_layer, nn.Linear):
-            torch_supports_dtensor = version.parse(torch.__version__) >= version.parse("2.5.0")
-            if torch_supports_dtensor and isinstance(self.base_layer.weight, torch.distributed.tensor.DTensor):
-                # If Tensor Parallel is used, the weight is sharded, so we need to get the local shape
-                out_features, in_features = self.base_layer.weight.to_local().shape
-            else:
-                in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif isinstance(base_layer, nn.Conv1d):
-            in_features, out_features = base_layer.in_channels, base_layer.out_channels
-        elif isinstance(base_layer, nn.Conv2d):
-            in_features, out_features = base_layer.in_channels, base_layer.out_channels
-        elif isinstance(base_layer, nn.Conv3d):
-            in_features, out_features = base_layer.in_channels, base_layer.out_channels
-        elif isinstance(base_layer, nn.Embedding):
-            in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
-        elif isinstance(base_layer, Conv1D):
-            in_features, out_features = (
-                base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
-            )
-        elif isinstance(base_layer, nn.MultiheadAttention):
-            if not base_layer._qkv_same_embed_dim:
-                raise ValueError(f"Only same dim for query/key/value is supported as of now for {self.__class__}.")
-            in_features, out_features = base_layer.embed_dim, 3 * base_layer.embed_dim
-        elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
-            # QuantLinear
-            in_features, out_features = base_layer.infeatures, base_layer.outfeatures
-        elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
-            # Megatron ColumnParallelLinear,RowParallelLinear
-            in_features, out_features = base_layer.input_size, base_layer.output_size
-        elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear":
-            # AQLM QuantLinear
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
-            # Awq layers
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif base_layer.__class__.__name__ == "EetqLinear":
-            # Eetq layers
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
-            # HQQ layers
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif base_layer.__class__.__name__ == "PatchedLinear":
-            # INC layers
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        else:
-            # possibly support user provided custom layer types using dynamic dispatch
-            if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
-                in_features, out_features = base_layer.in_features, base_layer.out_features
-            else:
-                in_features, out_features = None, None
-            warnings.warn(
-                f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning
-            )
-
+        in_features, out_features = _get_in_out_features(base_layer)
         self.in_features = in_features
         self.out_features = out_features
 
diff --git a/src/peft/tuners/lycoris_utils.py b/src/peft/tuners/lycoris_utils.py
index 09b1c919..7ea7a260 100644
--- a/src/peft/tuners/lycoris_utils.py
+++ b/src/peft/tuners/lycoris_utils.py
@@ -23,7 +23,12 @@ import torch.nn as nn
 
 from peft.config import PeftConfig
 
-from .tuners_utils import BaseTuner, BaseTunerLayer, check_adapters_to_merge
+from .tuners_utils import (
+    BaseTuner,
+    BaseTunerLayer,
+    _get_in_out_features,
+    check_adapters_to_merge,
+)
 
 
 @dataclass
@@ -75,6 +80,10 @@ class LycorisLayer(BaseTunerLayer):
         # flag to enable/disable casting of input to weight dtype during forward call
         self.cast_input_dtype_enabled = True
 
+        in_features, out_features = _get_in_out_features(self.get_base_layer())
+        self.in_features = in_features
+        self.out_features = out_features
+
     @property
     @abstractmethod
     def _available_adapters(self) -> set[str]: ...
diff --git a/src/peft/tuners/trainable_tokens/layer.py b/src/peft/tuners/trainable_tokens/layer.py
index a6f17056..0f354622 100644
--- a/src/peft/tuners/trainable_tokens/layer.py
+++ b/src/peft/tuners/trainable_tokens/layer.py
@@ -23,7 +23,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from peft.tuners._buffer_dict import BufferDict
-from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
 from peft.utils.integrations import check_deepspeed_zero3_enabled, gather_params_ctx
 
 
@@ -69,6 +69,10 @@ class TrainableTokensLayer(nn.Module, BaseTunerLayer):
         # Mark the weight as unmerged
         self.merged_adapters = []
 
+        in_features, out_features = _get_in_out_features(self.get_base_layer())
+        self.in_features = in_features
+        self.out_features = out_features
+
     @property
     def tied_adapter(self):
         if self._tied_adapter:
diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py
index 8ecd8421..4fd0d128 100644
--- a/src/peft/tuners/tuners_utils.py
+++ b/src/peft/tuners/tuners_utils.py
@@ -26,6 +26,7 @@ from typing import Any, Optional, Union, overload
 import torch
 from accelerate.hooks import AlignDevicesHook
 from accelerate.utils import named_module_tensors, offload_state_dict
+from packaging import version
 from torch import nn
 from tqdm import tqdm
 from transformers import PreTrainedModel
@@ -144,6 +145,68 @@ def _check_lora_target_modules_mamba(peft_config: PeftConfig, model: nn.Module,
         )
 
 
+def _get_in_out_features(module: nn.Module) -> tuple[int, int] | tuple[None, None]:
+    """
+    Get the in_features and out_features of the layer.
+
+    Returns in_features and out_features as a tuple. If they cannot be determined, return a tuple of None and None.
+    This function covers a broad range of layers, some of which the caller might not support. Therefore, just because
+    this function returns a valid result does not imply that the layer type is supported.
+ """ + if isinstance(module, nn.Linear): + torch_supports_dtensor = version.parse(torch.__version__) >= version.parse("2.5.0") + if torch_supports_dtensor and isinstance(module.weight, torch.distributed.tensor.DTensor): + # If Tensor Parallel is used, the weight is sharded, so we need to get the local shape + out_features, in_features = module.weight.to_local().shape + else: + in_features, out_features = module.in_features, module.out_features + elif isinstance(module, nn.Conv1d): + in_features, out_features = module.in_channels, module.out_channels + elif isinstance(module, nn.Conv2d): + in_features, out_features = module.in_channels, module.out_channels + elif isinstance(module, nn.Conv3d): + in_features, out_features = module.in_channels, module.out_channels + elif isinstance(module, nn.Embedding): + in_features, out_features = module.num_embeddings, module.embedding_dim + elif isinstance(module, Conv1D): + in_features, out_features = ( + module.weight.ds_shape if hasattr(module.weight, "ds_shape") else module.weight.shape + ) + elif isinstance(module, nn.MultiheadAttention): + if not module._qkv_same_embed_dim: + raise ValueError("Only same dim for query/key/value is supported as of now for MultiheadAttention.") + in_features, out_features = module.embed_dim, 3 * module.embed_dim + elif hasattr(module, "infeatures") and hasattr(module, "outfeatures"): + # QuantLinear + in_features, out_features = module.infeatures, module.outfeatures + elif hasattr(module, "input_size") and hasattr(module, "output_size"): + # Megatron ColumnParallelLinear,RowParallelLinear + in_features, out_features = module.input_size, module.output_size + elif hasattr(module, "codebooks") and module.__class__.__name__ == "QuantizedLinear": + # AQLM QuantLinear + in_features, out_features = module.in_features, module.out_features + elif hasattr(module, "w_bit") and module.__class__.__name__ == "WQLinear_GEMM": + # Awq layers + in_features, out_features = module.in_features, module.out_features + elif module.__class__.__name__ == "EetqLinear": + # Eetq layers + in_features, out_features = module.in_features, module.out_features + elif hasattr(module, "W_q") and module.__class__.__name__ == "HQQLinear": + # HQQ layers + in_features, out_features = module.in_features, module.out_features + elif module.__class__.__name__ == "PatchedLinear": + # INC layers + in_features, out_features = module.in_features, module.out_features + else: + # possibly support user provided custom layer types using dynamic dispatch + if hasattr(module, "in_features") and hasattr(module, "out_features"): + in_features, out_features = module.in_features, module.out_features + else: + in_features, out_features = None, None + warnings.warn(f"Unsupported layer type '{type(module)}' encountered, proceed at your own risk.", UserWarning) + return in_features, out_features + + class BaseTuner(nn.Module, ABC): r""" A base tuner model that provides the common methods and attributes for all tuners that are injectable into a diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index a2af81c9..33cded41 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -1668,6 +1668,20 @@ class TestPeftCustomModel(PeftCommonTester): def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs): self._test_peft_model_device_map(model_id, config_cls, config_kwargs) + @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) + def test_in_features_out_features_exposed(self, 
test_name, model_id, config_cls, config_kwargs): + # the PEFT layer should expose the .in_features and .out_features attributes + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + model = get_peft_model(model, config) + for module in model.modules(): + if isinstance(module, BaseTunerLayer): + assert hasattr(module, "in_features") + assert hasattr(module, "out_features") + @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES) def test_forward_output_finite(self, test_name, model_id, config_cls, config_kwargs): X = self.prepare_inputs_for_testing()
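
For illustration, a minimal usage sketch (not part of the patch) of what this
change enables: after the fix, layers injected by LoHa, LoKr, LN Tuning, and
Trainable Tokens expose in_features and out_features just like LoRA layers, so
downstream code can introspect them uniformly. The MLP model and the target
module names below are illustrative examples, not taken from the patch.

    import torch.nn as nn

    from peft import LoHaConfig, get_peft_model
    from peft.tuners.tuners_utils import BaseTunerLayer


    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin0 = nn.Linear(10, 20)
            self.lin1 = nn.Linear(20, 2)

        def forward(self, x):
            return self.lin1(self.lin0(x))


    # inject LoHa adapters into the two linear layers
    config = LoHaConfig(target_modules=["lin0", "lin1"])
    peft_model = get_peft_model(MLP(), config)

    for name, module in peft_model.named_modules():
        if isinstance(module, BaseTunerLayer):
            # before this patch, LoHa layers did not expose these attributes
            print(name, module.in_features, module.out_features)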