ENH Allow disabling input dtype casting for LoRA (#2353)

Provides the disable_input_dtype_casting to prevent the input dtype to
be cast during the forward call of a PEFT layer.

Normally, the dtype of the weight and input need to match, which is why
the dtype is cast. However, in certain circumustances, this is handled
by forward hooks, e.g. when using layerwise casting in diffusers. In
that case, PEFT casting the dtype interferes with the layerwise casting,
which is why the option to disable it is given.

Right now, this only supports LoRA. LoKr and LoHa don't cast the input
dtype anyway. Therefore, the PEFT methods most relevant for diffusers
are covered.
This commit is contained in:
Benjamin Bossan
2025-02-04 17:32:29 +01:00
committed by GitHub
parent 2825774d2d
commit db9dd3f4db
14 changed files with 181 additions and 30 deletions

View File

@ -14,4 +14,9 @@ A collection of helper functions for PEFT.
## Temporarily Rescaling Adapter Scale in LoraLayer Modules
[[autodoc]] helpers.rescale_adapter_scale
- all
- all
## Context manager to disable input dtype casting in the `forward` method of LoRA layers
[[autodoc]] helpers.disable_input_dtype_casting
- all

View File

@ -18,8 +18,10 @@ from copy import deepcopy
from functools import update_wrapper
from types import MethodType
from torch import nn
from .peft_model import PeftConfig, PeftModel
from .tuners.lora.layer import LoraLayer
from .tuners.lora import LoraLayer
def update_forward_signature(model: PeftModel) -> None:
@ -209,3 +211,42 @@ def rescale_adapter_scale(model, multiplier):
# restore original scaling values after exiting the context
for module, scaling in original_scaling.items():
module.scaling = scaling
@contextmanager
def disable_input_dtype_casting(model: nn.Module, active: bool = True):
"""
Context manager disables input dtype casting to the dtype of the weight.
Currently specifically works for LoRA.
Parameters:
model (nn.Module):
The model containing PEFT modules whose input dtype casting is to be adjusted.
active (bool):
Whether the context manager is active (default) or inactive.
"""
# Additional info: Normally, the dtype of the weight and input need to match, which is why the dtype is cast.
# However, in certain circumustances, this is handled by forward hooks, e.g. when using layerwise casting in
# diffusers. In that case, PEFT casting the dtype interferes with the layerwise casting, which is why the option to
# disable it is given.
if not active:
yield
return
original_values = {}
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
continue
original_values[name] = module.cast_input_dtype_enabled
module.cast_input_dtype_enabled = False
try:
yield
finally:
for name, module in model.named_modules():
if not isinstance(module, LoraLayer):
continue
if name in original_values:
module.cast_input_dtype_enabled = original_values[name]

View File

@ -129,9 +129,7 @@ if is_bnb_4bit_available():
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.dtype)
output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
if requires_conversion:

View File

@ -55,8 +55,7 @@ class SVDQuantLinear(torch.nn.Module, AdaLoraLayer):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
if x.dtype != torch.float32:
x = x.float()
x = self._cast_input_dtype(x, torch.float32)
output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
# TODO: here, the dtype conversion is applied on the *whole expression*,

View File

@ -180,7 +180,7 @@ class SVDLinear(nn.Module, AdaLoraLayer):
scaling = self.scaling[active_adapter]
ranknum = self.ranknum[active_adapter] + 1e-5
x = x.to(lora_A.dtype)
x = self._cast_input_dtype(x, lora_A.dtype)
result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
return result

View File

@ -75,7 +75,7 @@ class AqlmLoraLinear(torch.nn.Module, LoraLayer):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
output = lora_B(lora_A(dropout(x)))
if requires_conversion:

View File

@ -75,7 +75,7 @@ class AwqLoraLinear(torch.nn.Module, LoraLayer):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
output = lora_B(lora_A(dropout(x)))
if requires_conversion:

View File

@ -204,9 +204,7 @@ if is_bnb_available():
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
@ -243,9 +241,7 @@ if is_bnb_available():
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling
@ -470,7 +466,7 @@ if is_bnb_4bit_available():
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
@ -514,7 +510,7 @@ if is_bnb_4bit_available():
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
output = lora_B(lora_A(dropout(x))) * scaling

View File

@ -76,7 +76,7 @@ if is_eetq_available():
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
output = lora_B(lora_A(dropout(x)))
if requires_conversion:

View File

@ -75,7 +75,7 @@ class QuantLinear(torch.nn.Module, LoraLayer):
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
output = lora_B(lora_A(dropout(x)))
if requires_conversion:

View File

@ -178,9 +178,7 @@ if is_hqq_available():
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
# getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
# layer output
@ -218,9 +216,7 @@ if is_hqq_available():
requires_conversion = not torch.is_autocast_enabled()
if requires_conversion:
expected_dtype = result.dtype
compute_dtype = lora_A.weight.dtype
if x.dtype != compute_dtype:
x = x.to(compute_dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling

View File

@ -62,6 +62,8 @@ class LoraLayer(BaseTunerLayer):
self.lora_magnitude_vector = torch.nn.ModuleDict() # for DoRA
self._caches: dict[str, Any] = {}
self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
# flag to enable/disable casting of input to weight dtype during forward call
self.cast_input_dtype_enabled: bool = True
self.kwargs = kwargs
base_layer = self.get_base_layer()
@ -492,6 +494,19 @@ class LoraLayer(BaseTunerLayer):
return result
def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""
Whether to cast the dtype of the input to the forward method.
Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting
layer.cast_input_dtype=False, this can be disabled if necessary.
Enabling or disabling can be managed via the peft.helpers.disable_lora_input_dtype_casting context manager.
"""
if (not self.cast_input_dtype_enabled) or (x.dtype == dtype):
return x
return x.to(dtype=dtype)
# Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# and modified to work with PyTorch FSDP
@ -703,7 +718,7 @@ class Linear(nn.Module, LoraLayer):
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling
@ -1268,7 +1283,7 @@ class _ConvNd(nn.Module, LoraLayer):
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling

View File

@ -205,7 +205,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
lora_B = self.lora_B[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
x = x.to(lora_A.weight.dtype)
x = self._cast_input_dtype(x, lora_A.weight.dtype)
if not self.use_dora[active_adapter]:
result = result + lora_B(lora_A(dropout(x))) * scaling

View File

@ -16,11 +16,13 @@
import pytest
import torch
from diffusers import StableDiffusionPipeline
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from peft.helpers import check_if_peft_model, rescale_adapter_scale
from peft.helpers import check_if_peft_model, disable_input_dtype_casting, rescale_adapter_scale
from peft.tuners.lora.layer import LoraLayer
from peft.utils import infer_device
class TestCheckIsPeftModel:
@ -369,3 +371,102 @@ class TestScalingAdapters:
logits_merged_scaling = model(**inputs).logits
assert torch.allclose(logits_merged_scaling, logits_unmerged_scaling, atol=1e-4, rtol=1e-4)
class TestDisableInputDtypeCasting:
"""Test the context manager `disable_input_dtype_casting` that temporarily disables input dtype casting
in the model.
The test works as follows:
We create a simple MLP and convert it to a PeftModel. The model dtype is set to float16. Then a pre-foward hook is
added that casts the model parameters to float32. Moreover, a post-forward hook is added that casts the weights
back to float16. The input dtype is float32.
Without the disable_input_dtype_casting context, what would happen is that PEFT detects that the input dtype is
float32 but the weight dtype is float16, so it casts the input to float16. Then the pre-forward hook casts the
weight to float32, which results in a RuntimeError.
With the disable_input_dtype_casting context, the input dtype is left as float32 and there is no error. We also add
a hook to record the dtype of the result from the LoraLayer to ensure that it is indeed float32.
"""
device = infer_device()
dtype_record = []
@torch.no_grad()
def cast_params_to_fp32_pre_hook(self, module, input):
for param in module.parameters(recurse=False):
param.data = param.data.float()
return input
@torch.no_grad()
def cast_params_to_fp16_hook(self, module, input, output):
for param in module.parameters(recurse=False):
param.data = param.data.half()
return output
def record_dtype_hook(self, module, input, output):
self.dtype_record.append(output[0].dtype)
@pytest.fixture
def inputs(self):
return torch.randn(4, 10, device=self.device, dtype=torch.float32)
@pytest.fixture
def base_model(self):
class MLP(nn.Module):
def __init__(self, bias=True):
super().__init__()
self.lin0 = nn.Linear(10, 20, bias=bias)
self.lin1 = nn.Linear(20, 2, bias=bias)
self.sm = nn.LogSoftmax(dim=-1)
def forward(self, X):
X = self.lin0(X)
X = self.lin1(X)
X = self.sm(X)
return X
return MLP()
@pytest.fixture
def model(self, base_model):
config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(base_model, config).to(device=self.device, dtype=torch.float16)
# Register hooks on the submodule that holds parameters
for module in model.modules():
if sum(p.numel() for p in module.parameters()) > 0:
module.register_forward_pre_hook(self.cast_params_to_fp32_pre_hook)
module.register_forward_hook(self.cast_params_to_fp16_hook)
if isinstance(module, LoraLayer):
module.register_forward_hook(self.record_dtype_hook)
return model
def test_disable_input_dtype_casting_active(self, model, inputs):
self.dtype_record.clear()
with disable_input_dtype_casting(model, active=True):
model(inputs)
assert self.dtype_record == [torch.float32]
def test_no_disable_input_dtype_casting(self, model, inputs):
msg = r"expected m.*1 and m.*2 to have the same dtype"
with pytest.raises(RuntimeError, match=msg):
model(inputs)
def test_disable_input_dtype_casting_inactive(self, model, inputs):
msg = r"expected m.*1 and m.*2 to have the same dtype"
with pytest.raises(RuntimeError, match=msg):
with disable_input_dtype_casting(model, active=False):
model(inputs)
def test_disable_input_dtype_casting_inactive_after_existing_context(self, model, inputs):
# this is to ensure that when the context is left, we return to the previous behavior
with disable_input_dtype_casting(model, active=True):
model(inputs)
# after the context exited, we're back to the error
msg = r"expected m.*1 and m.*2 to have the same dtype"
with pytest.raises(RuntimeError, match=msg):
model(inputs)