FIX Account for rsLoRA scaling in set_scale (#2775)

Tanuj Rai
2025-09-16 15:00:29 +05:30
committed by GitHub
parent 1806c1651a
commit 20a9829f76
2 changed files with 57 additions and 3 deletions
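
Summary: before this change, LoraLayer.set_scale and unscale_layer(scale=None) always recomputed the scaling as scale * lora_alpha / r, which silently dropped the rank-stabilized LoRA (rsLoRA) factor of 1 / sqrt(r) that is applied when an adapter is created with use_rslora=True. The diff below records use_rslora per adapter and branches on it wherever the scaling is recomputed. A minimal sketch of the corrected formula (compute_scaling is an illustrative helper, not part of PEFT):

import math

def compute_scaling(lora_alpha: float, r: int, use_rslora: bool, scale: float = 1.0) -> float:
    # `scale` mirrors the multiplier passed to set_scale; 1.0 gives the default scaling
    if use_rslora:
        # rsLoRA divides by sqrt(r) instead of r to keep updates stable at higher ranks
        return scale * lora_alpha / math.sqrt(r)
    return scale * lora_alpha / r

print(compute_scaling(16, 8, use_rslora=False))  # 2.0
print(compute_scaling(16, 8, use_rslora=True))   # 5.656854...  (16 / sqrt(8))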

View File

@@ -113,6 +113,7 @@ class LoraLayer(BaseTunerLayer):
self._disable_adapters = False
self.merged_adapters = []
self.use_dora: dict[str, bool] = {} # not actively used anymore after #2443, keep it for BC
self.use_rslora: dict[str, bool] = {}
self.lora_bias: dict[str, bool] = {}
self.lora_magnitude_vector = torch.nn.ModuleDict() # for DoRA
self._caches: dict[str, Any] = {}
@@ -255,6 +256,8 @@ class LoraLayer(BaseTunerLayer):
else:
self.scaling[adapter_name] = lora_alpha / r
self.use_rslora[adapter_name] = use_rslora
self.use_dora[adapter_name] = use_dora
# for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
@@ -528,7 +531,10 @@ class LoraLayer(BaseTunerLayer):
if adapter not in self.scaling:
# Ignore the case where the adapter is not in the layer
return
self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]
if self.use_rslora.get(adapter, False):
self.scaling[adapter] = scale * self.lora_alpha[adapter] / math.sqrt(self.r[adapter])
else:
self.scaling[adapter] = scale * self.lora_alpha[adapter] / self.r[adapter]
def scale_layer(self, scale: float | int) -> None:
"""Multiply the current scale of all active adapters by the provided factor"""
@@ -553,9 +559,12 @@ class LoraLayer(BaseTunerLayer):
continue
if scale is None:
self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter]
if self.use_rslora.get(active_adapter, False):
self.scaling[active_adapter] = self.lora_alpha[active_adapter] / math.sqrt(self.r[active_adapter])
else:
self.scaling[active_adapter] = self.lora_alpha[active_adapter] / self.r[active_adapter]
else:
self.scaling[active_adapter] /= scale
self.scaling[active_adapter] = self.scaling[active_adapter] / scale
def _check_forward_args(self, x, *args, **kwargs):
"""Check if the arguments are compatible with the configs and state of the model"""
@@ -960,6 +969,8 @@ class Embedding(nn.Module, LoraLayer):
else:
self.scaling[adapter_name] = lora_alpha / r
self.use_rslora[adapter_name] = use_rslora
self.use_dora[adapter_name] = use_dora
if init_lora_weights == "loftq":
@@ -1260,6 +1271,8 @@ class _ConvNd(nn.Module, LoraLayer):
else:
self.scaling[adapter_name] = lora_alpha / r
self.use_rslora[adapter_name] = use_rslora
self.use_dora[adapter_name] = use_dora
if init_lora_weights == "loftq":
@@ -2033,6 +2046,8 @@ class ParamWrapper(nn.Module, LoraLayer):
else:
self.scaling[adapter_name] = lora_alpha / r
self.use_rslora[adapter_name] = use_rslora
self.use_dora[adapter_name] = use_dora
# for inits that require access to the base weight, use gather_param_ctx so that the weight is gathered when using DeepSpeed
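
Taken together, the layer changes can be exercised end to end on a small custom module; a rough usage sketch, assuming a PEFT build that includes this fix (the MLP class below is a stand-in for illustration, not code from the repository):

import math
import torch
from peft import LoraConfig, get_peft_model
from peft.tuners.lora.layer import LoraLayer

class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.lin(x)

config = LoraConfig(r=8, lora_alpha=16, use_rslora=True, target_modules=["lin"])
model = get_peft_model(MLP(), config)

for module in model.modules():
    if isinstance(module, LoraLayer):
        assert math.isclose(module.scaling["default"], 16 / math.sqrt(8))
        module.set_scale("default", 3)  # with this fix: 3 * 16 / sqrt(8), not 3 * 16 / 8
        assert math.isclose(module.scaling["default"], 3 * 16 / math.sqrt(8))
        module.unscale_layer()          # scale=None restores 16 / sqrt(8)
        assert math.isclose(module.scaling["default"], 16 / math.sqrt(8))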

View File

@@ -14,6 +14,7 @@
import copy
import itertools
import math
import platform
import re
import warnings
@@ -3891,6 +3892,44 @@ class TestScaling:
expected = [2.0] * n_layers
assert scalings == expected
def test_scaling_with_rslora(self, model):
n_layers = 5
rank, lora_alpha = 8, 16
config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
use_rslora=True,
target_modules=["k_proj"],
)
model = get_peft_model(model, config)
scalings = self.get_scalings(model)
expected = [lora_alpha / math.sqrt(rank)] * n_layers
assert scalings == expected
# double
self.scale_layer(model, 2)
scalings = self.get_scalings(model)
expected = [2 * lora_alpha / math.sqrt(rank)] * n_layers
assert scalings == expected
# back to original
self.unscale_layer(model, None)
scalings = self.get_scalings(model)
expected = [lora_alpha / math.sqrt(rank)] * n_layers
assert scalings == expected
# triple
self.set_scale(model, "default", 3)
scalings = self.get_scalings(model)
expected = [3 * lora_alpha / math.sqrt(rank)] * n_layers
assert scalings == expected
# back to original
self.unscale_layer(model, 3)
scalings = self.get_scalings(model)
expected = [lora_alpha / math.sqrt(rank)] * n_layers
assert scalings == expected
def test_scaling_rank_pattern_alpha_pattern(self, model):
# layer 0: 8 / 8
# layer 1: 8 / 16
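
For reference, the get_scalings, scale_layer, unscale_layer, and set_scale helpers called by TestScaling are defined earlier in the test class and are not part of this diff; a hypothetical reconstruction of what they do (the real helpers may differ in detail):

from peft.tuners.lora.layer import LoraLayer

def get_scalings(model, adapter_name="default"):
    # collect the current scaling of every LoRA layer, in module order
    return [m.scaling[adapter_name] for m in model.modules() if isinstance(m, LoraLayer)]

def scale_layer(model, scale):
    for m in model.modules():
        if isinstance(m, LoraLayer):
            m.scale_layer(scale)

def unscale_layer(model, scale):
    for m in model.modules():
        if isinstance(m, LoraLayer):
            m.unscale_layer(scale)

def set_scale(model, adapter_name, scale):
    for m in model.modules():
        if isinstance(m, LoraLayer):
            m.set_scale(adapter_name, scale)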