FIX PEFT layers expose in_features, out_features (#2784)

Resolves #2783.

Most PEFT layers (BaseTunerLayers) expose the in_features and
out_features attributes. Therefore, other packages like diffusers may
expect these attributes to exist. However, there were a few PEFT methods
where these attributes were missing:

- LoHa
- LoKr
- LN Tuning
- Trainable Tokens

The layers of these methods now also expose the attributes.
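
For illustration only (this snippet is not part of the commit; the toy model
and the target_modules/r values are made up), this is roughly how the change
can be observed on one of the affected methods, e.g. LoHa:

import torch.nn as nn

from peft import LoHaConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer

# toy base model; "0" and "2" are the module names of the two Linear layers
base_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
config = LoHaConfig(target_modules=["0", "2"], r=4)
peft_model = get_peft_model(base_model, config)

for module in peft_model.modules():
    if isinstance(module, BaseTunerLayer):
        # before this change, LoHa layers did not provide these attributes
        print(module.in_features, module.out_features)  # 16 32, then 32 4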

Implementation

To avoid code duplication, I factored out the code block in the LoRA
layer that extracts these attributes into a utility function, since LoRA
has the most exhaustive list of checks. The new utility function has the
exact same functionality and can now be used by other PEFT methods.

I updated the four PEFT methods mentioned above to use this new
function. I did not update the PEFT methods that already handled these
attributes, as there was no real need: they check at most one or two
layer types, so there is little duplication.
Author: Benjamin Bossan
Date: 2025-09-25 17:59:45 +02:00
Committed by: GitHub
Parent: 7b2a5b1f02
Commit: ae671baec9
6 changed files with 99 additions and 59 deletions


@@ -19,7 +19,7 @@ from typing import Optional
 import torch
 import torch.nn as nn

-from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge


 class LNTuningLayer(nn.Module, BaseTunerLayer):
@@ -37,6 +37,10 @@ class LNTuningLayer(nn.Module, BaseTunerLayer):
         self._active_adapter = adapter_name
         self.merged_adapters = []

+        in_features, out_features = _get_in_out_features(self.get_base_layer())
+        self.in_features = in_features
+        self.out_features = out_features
+
     def update_layer(self, layer: nn.Module, adapter_name: str, inference_mode: bool = False, **kwargs):
         self.ln_tuning_layers[adapter_name] = deepcopy(layer)
         self.set_adapter(adapter_name, inference_mode=inference_mode)


@@ -21,12 +21,11 @@ from typing import Any, Optional, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from packaging import version
 from torch import svd_lowrank
 from transformers.pytorch_utils import Conv1D

 from peft.tuners._buffer_dict import BufferDict
-from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
 from peft.utils.integrations import (
     dequantize_module_weight,
     gather_params_ctx,
@@ -124,60 +123,7 @@ class LoraLayer(BaseTunerLayer):
         self.kwargs = kwargs

         base_layer = self.get_base_layer()
-        if isinstance(base_layer, nn.Linear):
-            torch_supports_dtensor = version.parse(torch.__version__) >= version.parse("2.5.0")
-            if torch_supports_dtensor and isinstance(self.base_layer.weight, torch.distributed.tensor.DTensor):
-                # If Tensor Parallel is used, the weight is sharded, so we need to get the local shape
-                out_features, in_features = self.base_layer.weight.to_local().shape
-            else:
-                in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif isinstance(base_layer, nn.Conv1d):
-            in_features, out_features = base_layer.in_channels, base_layer.out_channels
-        elif isinstance(base_layer, nn.Conv2d):
-            in_features, out_features = base_layer.in_channels, base_layer.out_channels
-        elif isinstance(base_layer, nn.Conv3d):
-            in_features, out_features = base_layer.in_channels, base_layer.out_channels
-        elif isinstance(base_layer, nn.Embedding):
-            in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim
-        elif isinstance(base_layer, Conv1D):
-            in_features, out_features = (
-                base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape
-            )
-        elif isinstance(base_layer, nn.MultiheadAttention):
-            if not base_layer._qkv_same_embed_dim:
-                raise ValueError(f"Only same dim for query/key/value is supported as of now for {self.__class__}.")
-            in_features, out_features = base_layer.embed_dim, 3 * base_layer.embed_dim
-        elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"):
-            # QuantLinear
-            in_features, out_features = base_layer.infeatures, base_layer.outfeatures
-        elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"):
-            # Megatron ColumnParallelLinear,RowParallelLinear
-            in_features, out_features = base_layer.input_size, base_layer.output_size
-        elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear":
-            # AQLM QuantLinear
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM":
-            # Awq layers
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif base_layer.__class__.__name__ == "EetqLinear":
-            # Eetq layers
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif hasattr(base_layer, "W_q") and base_layer.__class__.__name__ == "HQQLinear":
-            # HQQ layers
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        elif base_layer.__class__.__name__ == "PatchedLinear":
-            # INC layers
-            in_features, out_features = base_layer.in_features, base_layer.out_features
-        else:
-            # possibly support user provided custom layer types using dynamic dispatch
-            if hasattr(base_layer, "in_features") and hasattr(base_layer, "out_features"):
-                in_features, out_features = base_layer.in_features, base_layer.out_features
-            else:
-                in_features, out_features = None, None
-            warnings.warn(
-                f"Unsupported layer type '{type(base_layer)}' encountered, proceed at your own risk.", UserWarning
-            )
+        in_features, out_features = _get_in_out_features(base_layer)
         self.in_features = in_features
         self.out_features = out_features


@@ -23,7 +23,12 @@ import torch.nn as nn
 from peft.config import PeftConfig

-from .tuners_utils import BaseTuner, BaseTunerLayer, check_adapters_to_merge
+from .tuners_utils import (
+    BaseTuner,
+    BaseTunerLayer,
+    _get_in_out_features,
+    check_adapters_to_merge,
+)


 @dataclass
@@ -75,6 +80,10 @@ class LycorisLayer(BaseTunerLayer):
         # flag to enable/disable casting of input to weight dtype during forward call
         self.cast_input_dtype_enabled = True

+        in_features, out_features = _get_in_out_features(self.get_base_layer())
+        self.in_features = in_features
+        self.out_features = out_features
+
     @property
     @abstractmethod
     def _available_adapters(self) -> set[str]: ...


@@ -23,7 +23,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from peft.tuners._buffer_dict import BufferDict
-from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
+from peft.tuners.tuners_utils import BaseTunerLayer, _get_in_out_features, check_adapters_to_merge
 from peft.utils.integrations import check_deepspeed_zero3_enabled, gather_params_ctx
@@ -69,6 +69,10 @@ class TrainableTokensLayer(nn.Module, BaseTunerLayer):
         # Mark the weight as unmerged
         self.merged_adapters = []

+        in_features, out_features = _get_in_out_features(self.get_base_layer())
+        self.in_features = in_features
+        self.out_features = out_features
+
     @property
     def tied_adapter(self):
         if self._tied_adapter:


@@ -26,6 +26,7 @@ from typing import Any, Optional, Union, overload
 import torch
 from accelerate.hooks import AlignDevicesHook
 from accelerate.utils import named_module_tensors, offload_state_dict
+from packaging import version
 from torch import nn
 from tqdm import tqdm
 from transformers import PreTrainedModel
@@ -144,6 +145,68 @@ def _check_lora_target_modules_mamba(peft_config: PeftConfig, model: nn.Module,
            )


+def _get_in_out_features(module: nn.Module) -> tuple[int, int] | tuple[None, None]:
+    """
+    Get the in_features and out_features of the layer.
+
+    Returns in_features and out_features as a tuple. If they cannot be determined, return a tuple of None and None.
+
+    This function covers a broad range of layers, some of which the caller might not support. Therefore, just because
+    this function returns a valid result does not imply that the layer type is supported.
+    """
+    if isinstance(module, nn.Linear):
+        torch_supports_dtensor = version.parse(torch.__version__) >= version.parse("2.5.0")
+        if torch_supports_dtensor and isinstance(module.weight, torch.distributed.tensor.DTensor):
+            # If Tensor Parallel is used, the weight is sharded, so we need to get the local shape
+            out_features, in_features = module.weight.to_local().shape
+        else:
+            in_features, out_features = module.in_features, module.out_features
+    elif isinstance(module, nn.Conv1d):
+        in_features, out_features = module.in_channels, module.out_channels
+    elif isinstance(module, nn.Conv2d):
+        in_features, out_features = module.in_channels, module.out_channels
+    elif isinstance(module, nn.Conv3d):
+        in_features, out_features = module.in_channels, module.out_channels
+    elif isinstance(module, nn.Embedding):
+        in_features, out_features = module.num_embeddings, module.embedding_dim
+    elif isinstance(module, Conv1D):
+        in_features, out_features = (
+            module.weight.ds_shape if hasattr(module.weight, "ds_shape") else module.weight.shape
+        )
+    elif isinstance(module, nn.MultiheadAttention):
+        if not module._qkv_same_embed_dim:
+            raise ValueError("Only same dim for query/key/value is supported as of now for MultiheadAttention.")
+        in_features, out_features = module.embed_dim, 3 * module.embed_dim
+    elif hasattr(module, "infeatures") and hasattr(module, "outfeatures"):
+        # QuantLinear
+        in_features, out_features = module.infeatures, module.outfeatures
+    elif hasattr(module, "input_size") and hasattr(module, "output_size"):
+        # Megatron ColumnParallelLinear,RowParallelLinear
+        in_features, out_features = module.input_size, module.output_size
+    elif hasattr(module, "codebooks") and module.__class__.__name__ == "QuantizedLinear":
+        # AQLM QuantLinear
+        in_features, out_features = module.in_features, module.out_features
+    elif hasattr(module, "w_bit") and module.__class__.__name__ == "WQLinear_GEMM":
+        # Awq layers
+        in_features, out_features = module.in_features, module.out_features
+    elif module.__class__.__name__ == "EetqLinear":
+        # Eetq layers
+        in_features, out_features = module.in_features, module.out_features
+    elif hasattr(module, "W_q") and module.__class__.__name__ == "HQQLinear":
+        # HQQ layers
+        in_features, out_features = module.in_features, module.out_features
+    elif module.__class__.__name__ == "PatchedLinear":
+        # INC layers
+        in_features, out_features = module.in_features, module.out_features
+    else:
+        # possibly support user provided custom layer types using dynamic dispatch
+        if hasattr(module, "in_features") and hasattr(module, "out_features"):
+            in_features, out_features = module.in_features, module.out_features
+        else:
+            in_features, out_features = None, None
+        warnings.warn(f"Unsupported layer type '{type(module)}' encountered, proceed at your own risk.", UserWarning)
+    return in_features, out_features
+
+
 class BaseTuner(nn.Module, ABC):
     r"""
     A base tuner model that provides the common methods and attributes for all tuners that are injectable into a


@@ -1668,6 +1668,20 @@ class TestPeftCustomModel(PeftCommonTester):
     def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwargs):
         self._test_peft_model_device_map(model_id, config_cls, config_kwargs)

+    @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
+    def test_in_features_out_features_exposed(self, test_name, model_id, config_cls, config_kwargs):
+        # the PEFT layer should expose the .in_features and .out_features attributes
+        model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        for module in model.modules():
+            if isinstance(module, BaseTunerLayer):
+                assert hasattr(module, "in_features")
+                assert hasattr(module, "out_features")
+
     @pytest.mark.parametrize("test_name, model_id, config_cls, config_kwargs", TEST_CASES)
     def test_forward_output_finite(self, test_name, model_id, config_cls, config_kwargs):
         X = self.prepare_inputs_for_testing()