Mirror of https://github.com/huggingface/peft.git
ENH Allow disabling input dtype casting for LoRA (#2353)
Provides the disable_input_dtype_casting context manager to prevent the input dtype from being cast during the forward call of a PEFT layer. Normally, the dtype of the weight and input need to match, which is why the dtype is cast. However, in certain circumstances, this is handled by forward hooks, e.g. when using layerwise casting in diffusers. In that case, PEFT casting the dtype interferes with the layerwise casting, which is why the option to disable it is given. Right now, this only supports LoRA. LoKr and LoHa don't cast the input dtype anyway. Therefore, the PEFT methods most relevant for diffusers are covered.
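For illustration, a minimal usage sketch of the new context manager (assuming `model` is an already loaded LoRA-adapted PeftModel; the float32 input stands in for a setup where external forward hooks, such as diffusers layerwise casting, manage the weight dtype):

import torch
from peft.helpers import disable_input_dtype_casting

# Inside the context, LoRA layers leave the input dtype untouched instead of
# casting it to the (e.g. float16) dtype of the LoRA weights.
x = torch.randn(4, 10, dtype=torch.float32)
with disable_input_dtype_casting(model, active=True):
    output = model(x)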
@@ -14,4 +14,9 @@ A collection of helper functions for PEFT.
 ## Temporarily Rescaling Adapter Scale in LoraLayer Modules
 
 [[autodoc]] helpers.rescale_adapter_scale
     - all
+
+## Context manager to disable input dtype casting in the `forward` method of LoRA layers
+
+[[autodoc]] helpers.disable_input_dtype_casting
+    - all
@@ -18,8 +18,10 @@ from copy import deepcopy
 from functools import update_wrapper
 from types import MethodType
 
+from torch import nn
+
 from .peft_model import PeftConfig, PeftModel
-from .tuners.lora.layer import LoraLayer
+from .tuners.lora import LoraLayer
 
 
 def update_forward_signature(model: PeftModel) -> None:
@@ -209,3 +211,42 @@ def rescale_adapter_scale(model, multiplier):
         # restore original scaling values after exiting the context
         for module, scaling in original_scaling.items():
             module.scaling = scaling
+
+
+@contextmanager
+def disable_input_dtype_casting(model: nn.Module, active: bool = True):
+    """
+    Context manager that disables input dtype casting to the dtype of the weight.
+
+    Currently specifically works for LoRA.
+
+    Parameters:
+        model (nn.Module):
+            The model containing PEFT modules whose input dtype casting is to be adjusted.
+        active (bool):
+            Whether the context manager is active (default) or inactive.
+
+    """
+    # Additional info: Normally, the dtype of the weight and input need to match, which is why the dtype is cast.
+    # However, in certain circumstances, this is handled by forward hooks, e.g. when using layerwise casting in
+    # diffusers. In that case, PEFT casting the dtype interferes with the layerwise casting, which is why the option to
+    # disable it is given.
+    if not active:
+        yield
+        return
+
+    original_values = {}
+    for name, module in model.named_modules():
+        if not isinstance(module, LoraLayer):
+            continue
+        original_values[name] = module.cast_input_dtype_enabled
+        module.cast_input_dtype_enabled = False
+
+    try:
+        yield
+    finally:
+        for name, module in model.named_modules():
+            if not isinstance(module, LoraLayer):
+                continue
+            if name in original_values:
+                module.cast_input_dtype_enabled = original_values[name]
@@ -129,9 +129,7 @@ if is_bnb_4bit_available():
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
-                   compute_dtype = lora_A.dtype
-                   if x.dtype != compute_dtype:
-                       x = x.to(compute_dtype)
+               x = self._cast_input_dtype(x, lora_A.dtype)
 
                output = dropout(x) @ (lora_A * lora_E).T @ lora_B.T
                if requires_conversion:
@@ -55,8 +55,7 @@ class SVDQuantLinear(torch.nn.Module, AdaLoraLayer):
            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
-               if x.dtype != torch.float32:
-                   x = x.float()
+           x = self._cast_input_dtype(x, torch.float32)
 
            output = (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
            # TODO: here, the dtype conversion is applied on the *whole expression*,
@@ -180,7 +180,7 @@ class SVDLinear(nn.Module, AdaLoraLayer):
            scaling = self.scaling[active_adapter]
            ranknum = self.ranknum[active_adapter] + 1e-5
 
-           x = x.to(lora_A.dtype)
+           x = self._cast_input_dtype(x, lora_A.dtype)
            result += (dropout(x) @ (lora_A * lora_E).T @ lora_B.T) * scaling / ranknum
 
        return result
@@ -75,7 +75,7 @@ class AqlmLoraLinear(torch.nn.Module, LoraLayer):
            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
-               x = x.to(lora_A.weight.dtype)
+           x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
            output = lora_B(lora_A(dropout(x)))
            if requires_conversion:
@@ -75,7 +75,7 @@ class AwqLoraLinear(torch.nn.Module, LoraLayer):
            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
-               x = x.to(lora_A.weight.dtype)
+           x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
            output = lora_B(lora_A(dropout(x)))
            if requires_conversion:
@@ -204,9 +204,7 @@ if is_bnb_available():
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
-                   compute_dtype = lora_A.weight.dtype
-                   if x.dtype != compute_dtype:
-                       x = x.to(compute_dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
                # layer output
@@ -243,9 +241,7 @@ if is_bnb_available():
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
-                   compute_dtype = lora_A.weight.dtype
-                   if x.dtype != compute_dtype:
-                       x = x.to(compute_dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                if not self.use_dora[active_adapter]:
                    output = lora_B(lora_A(dropout(x))) * scaling
@@ -470,7 +466,7 @@ if is_bnb_4bit_available():
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
-                   x = x.to(lora_A.weight.dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
                # layer output
@@ -514,7 +510,7 @@ if is_bnb_4bit_available():
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
-                   x = x.to(lora_A.weight.dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                if not self.use_dora[active_adapter]:
                    output = lora_B(lora_A(dropout(x))) * scaling
@@ -76,7 +76,7 @@ if is_eetq_available():
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
-                   x = x.to(lora_A.weight.dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                output = lora_B(lora_A(dropout(x)))
                if requires_conversion:
@@ -75,7 +75,7 @@ class QuantLinear(torch.nn.Module, LoraLayer):
            requires_conversion = not torch.is_autocast_enabled()
            if requires_conversion:
                expected_dtype = result.dtype
-               x = x.to(lora_A.weight.dtype)
+           x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
            output = lora_B(lora_A(dropout(x)))
            if requires_conversion:
@@ -178,9 +178,7 @@ if is_hqq_available():
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
-                   compute_dtype = lora_A.weight.dtype
-                   if x.dtype != compute_dtype:
-                       x = x.to(compute_dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                # getting the sub-batch, passing it to LoRA layers and updating the corresponding indices of the linear
                # layer output
@@ -218,9 +216,7 @@ if is_hqq_available():
                requires_conversion = not torch.is_autocast_enabled()
                if requires_conversion:
                    expected_dtype = result.dtype
-                   compute_dtype = lora_A.weight.dtype
-                   if x.dtype != compute_dtype:
-                       x = x.to(compute_dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                if not self.use_dora[active_adapter]:
                    result = result + lora_B(lora_A(dropout(x))) * scaling
@@ -62,6 +62,8 @@ class LoraLayer(BaseTunerLayer):
        self.lora_magnitude_vector = torch.nn.ModuleDict()  # for DoRA
        self._caches: dict[str, Any] = {}
        self.ephemeral_gpu_offload: bool = ephemeral_gpu_offload
+       # flag to enable/disable casting of input to weight dtype during forward call
+       self.cast_input_dtype_enabled: bool = True
        self.kwargs = kwargs
 
        base_layer = self.get_base_layer()
@@ -492,6 +494,19 @@ class LoraLayer(BaseTunerLayer):
 
        return result
 
+   def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+       """
+       Cast the dtype of the input of the forward method to the given dtype, if enabled.
+
+       Usually, we want to enable this to align the input dtype with the dtype of the weight, but by setting
+       layer.cast_input_dtype_enabled=False, this can be disabled if necessary.
+
+       Enabling or disabling can be managed via the peft.helpers.disable_input_dtype_casting context manager.
+       """
+       if (not self.cast_input_dtype_enabled) or (x.dtype == dtype):
+           return x
+       return x.to(dtype=dtype)
+
 
 # Below code is based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
 # and modified to work with PyTorch FSDP
@@ -703,7 +718,7 @@ class Linear(nn.Module, LoraLayer):
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
-               x = x.to(lora_A.weight.dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                if not self.use_dora[active_adapter]:
                    result = result + lora_B(lora_A(dropout(x))) * scaling
@@ -1268,7 +1283,7 @@ class _ConvNd(nn.Module, LoraLayer):
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
-               x = x.to(lora_A.weight.dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                if not self.use_dora[active_adapter]:
                    result = result + lora_B(lora_A(dropout(x))) * scaling
@@ -205,7 +205,7 @@ class LoraParallelLinear(nn.Module, LoraLayer):
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
-               x = x.to(lora_A.weight.dtype)
+               x = self._cast_input_dtype(x, lora_A.weight.dtype)
 
                if not self.use_dora[active_adapter]:
                    result = result + lora_B(lora_A(dropout(x))) * scaling
@@ -16,11 +16,13 @@
 import pytest
 import torch
 from diffusers import StableDiffusionPipeline
+from torch import nn
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from peft import LoraConfig, get_peft_model
-from peft.helpers import check_if_peft_model, rescale_adapter_scale
+from peft.helpers import check_if_peft_model, disable_input_dtype_casting, rescale_adapter_scale
+from peft.tuners.lora.layer import LoraLayer
 from peft.utils import infer_device
 
 
 class TestCheckIsPeftModel:
@@ -369,3 +371,102 @@ class TestScalingAdapters:
        logits_merged_scaling = model(**inputs).logits
 
        assert torch.allclose(logits_merged_scaling, logits_unmerged_scaling, atol=1e-4, rtol=1e-4)
+
+
+class TestDisableInputDtypeCasting:
+    """Test the context manager `disable_input_dtype_casting` that temporarily disables input dtype casting
+    in the model.
+
+    The test works as follows:
+
+    We create a simple MLP and convert it to a PeftModel. The model dtype is set to float16. Then a pre-forward hook
+    is added that casts the model parameters to float32. Moreover, a post-forward hook is added that casts the
+    weights back to float16. The input dtype is float32.
+
+    Without the disable_input_dtype_casting context, what would happen is that PEFT detects that the input dtype is
+    float32 but the weight dtype is float16, so it casts the input to float16. Then the pre-forward hook casts the
+    weight to float32, which results in a RuntimeError.
+
+    With the disable_input_dtype_casting context, the input dtype is left as float32 and there is no error. We also
+    add a hook to record the dtype of the result from the LoraLayer to ensure that it is indeed float32.
+
+    """
+
+    device = infer_device()
+    dtype_record = []
+
+    @torch.no_grad()
+    def cast_params_to_fp32_pre_hook(self, module, input):
+        for param in module.parameters(recurse=False):
+            param.data = param.data.float()
+        return input
+
+    @torch.no_grad()
+    def cast_params_to_fp16_hook(self, module, input, output):
+        for param in module.parameters(recurse=False):
+            param.data = param.data.half()
+        return output
+
+    def record_dtype_hook(self, module, input, output):
+        self.dtype_record.append(output[0].dtype)
+
+    @pytest.fixture
+    def inputs(self):
+        return torch.randn(4, 10, device=self.device, dtype=torch.float32)
+
+    @pytest.fixture
+    def base_model(self):
+        class MLP(nn.Module):
+            def __init__(self, bias=True):
+                super().__init__()
+                self.lin0 = nn.Linear(10, 20, bias=bias)
+                self.lin1 = nn.Linear(20, 2, bias=bias)
+                self.sm = nn.LogSoftmax(dim=-1)
+
+            def forward(self, X):
+                X = self.lin0(X)
+                X = self.lin1(X)
+                X = self.sm(X)
+                return X
+
+        return MLP()
+
+    @pytest.fixture
+    def model(self, base_model):
+        config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        model = get_peft_model(base_model, config).to(device=self.device, dtype=torch.float16)
+        # Register hooks on the submodules that hold parameters
+        for module in model.modules():
+            if sum(p.numel() for p in module.parameters()) > 0:
+                module.register_forward_pre_hook(self.cast_params_to_fp32_pre_hook)
+                module.register_forward_hook(self.cast_params_to_fp16_hook)
+            if isinstance(module, LoraLayer):
+                module.register_forward_hook(self.record_dtype_hook)
+        return model
+
+    def test_disable_input_dtype_casting_active(self, model, inputs):
+        self.dtype_record.clear()
+        with disable_input_dtype_casting(model, active=True):
+            model(inputs)
+        assert self.dtype_record == [torch.float32]
+
+    def test_no_disable_input_dtype_casting(self, model, inputs):
+        msg = r"expected m.*1 and m.*2 to have the same dtype"
+        with pytest.raises(RuntimeError, match=msg):
+            model(inputs)
+
+    def test_disable_input_dtype_casting_inactive(self, model, inputs):
+        msg = r"expected m.*1 and m.*2 to have the same dtype"
+        with pytest.raises(RuntimeError, match=msg):
+            with disable_input_dtype_casting(model, active=False):
+                model(inputs)
+
+    def test_disable_input_dtype_casting_inactive_after_existing_context(self, model, inputs):
+        # this is to ensure that when the context is left, we return to the previous behavior
+        with disable_input_dtype_casting(model, active=True):
+            model(inputs)
+
+        # after the context exited, we're back to the error
+        msg = r"expected m.*1 and m.*2 to have the same dtype"
+        with pytest.raises(RuntimeError, match=msg):
+            model(inputs)