Mirror of https://github.com/huggingface/peft.git, synced 2025-10-20 15:33:48 +08:00
FIX Account for rsLoRA scaling in set_scale (#2775)
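In LoRA the adapter update is multiplied by scaling = lora_alpha / r, while rank-stabilized LoRA (rsLoRA, enabled via use_rslora=True) uses lora_alpha / sqrt(r). Before this fix, set_scale and unscale_layer always rebuilt the baseline as lora_alpha / r, silently dropping the rsLoRA factor. A minimal sketch of the two conventions (illustrative only; the actual logic lives in the LoraLayer hunks below):

import math

def lora_scaling(lora_alpha: float, r: int, use_rslora: bool = False) -> float:
    """Multiplier applied to the low-rank update B @ A."""
    if use_rslora:
        return lora_alpha / math.sqrt(r)  # rank-stabilized LoRA
    return lora_alpha / r                 # classic LoRA

assert lora_scaling(16, 8) == 2.0
assert abs(lora_scaling(16, 8, use_rslora=True) - 16 / math.sqrt(8)) < 1e-12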
@@ -113,6 +113,7 @@ class LoraLayer(BaseTunerLayer):
         self._disable_adapters = False
         self.merged_adapters = []
         self.use_dora: dict[str, bool] = {}  # not actively used anymore after #2443, keep it for BC
+        self.use_rslora: dict[str, bool] = {}
         self.lora_bias: dict[str, bool] = {}
         self.lora_magnitude_vector = torch.nn.ModuleDict()  # for DoRA
         self._caches: dict[str, Any] = {}
@@ -255,6 +256,8 @@ class LoraLayer(BaseTunerLayer):
         else:
             self.scaling[adapter_name] = lora_alpha / r

+        self.use_rslora[adapter_name] = use_rslora
+
         self.use_dora[adapter_name] = use_dora

         # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
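The use_rslora flag is recorded per adapter (here and again in the Embedding, _ConvNd, and ParamWrapper hunks below) because the convention cannot be recovered from self.scaling once it has been multiplied by an arbitrary factor. A hedged sketch of that bookkeeping, with hypothetical names standing in for the layer attributes:

import math

# Hypothetical stand-ins for the per-adapter state kept on a LoRA layer.
scaling: dict[str, float] = {}
use_rslora: dict[str, bool] = {}

def register(name: str, lora_alpha: float, r: int, rslora: bool) -> None:
    use_rslora[name] = rslora
    scaling[name] = lora_alpha / math.sqrt(r) if rslora else lora_alpha / r

def reset(name: str, lora_alpha: float, r: int) -> None:
    # Without the stored boolean, an rsLoRA adapter would be reset to lora_alpha / r here.
    scaling[name] = lora_alpha / math.sqrt(r) if use_rslora[name] else lora_alpha / r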
@@ -528,7 +531,10 @@ class LoraLayer(BaseTunerLayer):
         if adapter not in self.scaling:
             # Ignore the case where the adapter is not in the layer
             return
-        self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]
+        if self.use_rslora.get(adapter, False):
+            self.scaling[adapter] = scale * self.lora_alpha[adapter] / math.sqrt(self.r[adapter])
+        else:
+            self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]

     def scale_layer(self, scale: float | int) -> None:
         """Multiply the current scale of all active adapters by the provided factor"""
@@ -553,9 +559,12 @@ class LoraLayer(BaseTunerLayer):
                 continue

             if scale is None:
-                self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter]
+                if self.use_rslora.get(active_adapter, False):
+                    self.scaling[active_adapter] = self.lora_alpha[active_adapter] / math.sqrt(self.r[active_adapter])
+                else:
+                    self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter]
             else:
-                self.scaling[active_adapter] /= scale
+                self.scaling[active_adapter] = self.scaling[active_adapter] / scale

     def _check_forward_args(self, x, *args, **kwargs):
         """Check if the arguments are compatible with the configs and state of the model"""
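Worked numbers for the branches above, using the same alpha and rank as the new test further down (assumed values, shown only to make the before/after behavior concrete):

import math

alpha, r = 16, 8
classic = alpha / r            # 2.0   -> what the old code always restored
rslora = alpha / math.sqrt(r)  # ~5.657 -> correct baseline when use_rslora=True

# Before the fix, set_scale(adapter, 1.0) or unscale_layer(None) on an rsLoRA adapter
# therefore shrank scaling from ~5.657 to 2.0 instead of restoring the baseline.
print(classic, round(rslora, 3))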
@@ -960,6 +969,8 @@ class Embedding(nn.Module, LoraLayer):
         else:
             self.scaling[adapter_name] = lora_alpha / r

+        self.use_rslora[adapter_name] = use_rslora
+
         self.use_dora[adapter_name] = use_dora

         if init_lora_weights == "loftq":
@@ -1260,6 +1271,8 @@ class _ConvNd(nn.Module, LoraLayer):
         else:
             self.scaling[adapter_name] = lora_alpha / r

+        self.use_rslora[adapter_name] = use_rslora
+
         self.use_dora[adapter_name] = use_dora

         if init_lora_weights == "loftq":
@@ -2033,6 +2046,8 @@ class ParamWrapper(nn.Module, LoraLayer):
         else:
             self.scaling[adapter_name] = lora_alpha / r

+        self.use_rslora[adapter_name] = use_rslora
+
         self.use_dora[adapter_name] = use_dora

         # for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
@@ -14,6 +14,7 @@
 import copy
 import itertools
+import math
 import platform
 import re
 import warnings
@@ -3891,6 +3892,44 @@ class TestScaling:
         expected = [2.0] * n_layers
         assert scalings == expected

+    def test_scaling_with_rslora(self, model):
+        n_layers = 5
+        rank, lora_alpha = 8, 16
+        config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            use_rslora=True,
+            target_modules=["k_proj"],
+        )
+        model = get_peft_model(model, config)
+        scalings = self.get_scalings(model)
+        expected = [lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
+        # double
+        self.scale_layer(model, 2)
+        scalings = self.get_scalings(model)
+        expected = [2 * lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
+        # back to original
+        self.unscale_layer(model, None)
+        scalings = self.get_scalings(model)
+        expected = [lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
+        # triple
+        self.set_scale(model, "default", 3)
+        scalings = self.get_scalings(model)
+        expected = [3 * lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
+        # back to original
+        self.unscale_layer(model, 3)
+        scalings = self.get_scalings(model)
+        expected = [lora_alpha / math.sqrt(rank)] * n_layers
+        assert scalings == expected
+
     def test_scaling_rank_pattern_alpha_pattern(self, model):
         # layer 0: 8 / 8
         # layer 1: 8 / 16
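A hedged end-to-end check along the lines of the new test; the base model name and the module scan below are assumptions for illustration, not part of this PR:

import math
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # any model with k_proj modules
config = LoraConfig(r=8, lora_alpha=16, use_rslora=True, target_modules=["k_proj"])
model = get_peft_model(base, config)

for module in model.modules():
    scaling = getattr(module, "scaling", None)
    if isinstance(scaling, dict) and "default" in scaling:
        # With this fix, rescaling round-trips back to the rsLoRA baseline alpha / sqrt(r).
        module.set_scale("default", 3)
        assert math.isclose(scaling["default"], 3 * 16 / math.sqrt(8))
        module.set_scale("default", 1)
        assert math.isclose(scaling["default"], 16 / math.sqrt(8))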