FIX X-LoRA scaling storage and per token normalization (#2793)

Author: Che-Xu (committed by GitHub)
Date: 2025-10-09 19:36:54 +08:00
Parent: 2c29cf7936
Commit: e9f5707e3f
3 changed files with 49 additions and 2 deletions

src/peft/tuners/xlora/layer.py

@@ -73,10 +73,12 @@ class XLoraLayer:
             xlora_scalings = xlora_scalings * mask.to(xlora_scalings.dtype)
 
+        # Apply per-token normalization to the xLoRA scaling factors using a softmax
         if self.config.enable_softmax_topk:
             nonzero_mask = xlora_scalings != 0
-            softmax_res_nonzero = torch.softmax(xlora_scalings[nonzero_mask], dim=-1)
-            xlora_scalings[nonzero_mask] = softmax_res_nonzero
+            full = xlora_scalings.masked_fill(~nonzero_mask, float("-inf"))
+            new_scalings = torch.softmax(full, dim=-1)
+            xlora_scalings = new_scalings.masked_fill(~nonzero_mask, 0.0)
 
         return xlora_scalings
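Note: the removed branch indexed the scalings with a boolean mask, which flattens the tensor, so the softmax normalized over every kept entry in the batch at once rather than per token. The masked_fill/softmax combination above normalizes along the adapter dimension for each token separately. A minimal sketch on toy values (not part of the commit) illustrating the difference:

    import torch

    # Toy scalings: (batch=1, seq_len=2, n_adapters=4); two entries per token were zeroed by top-k.
    scalings = torch.tensor([[[0.8, 0.0, 0.4, 0.0],
                              [0.0, 0.3, 0.0, 0.9]]])
    nonzero_mask = scalings != 0

    # Old behaviour: boolean indexing flattens, so the softmax runs over all four nonzero values jointly.
    old = scalings.clone()
    old[nonzero_mask] = torch.softmax(old[nonzero_mask], dim=-1)
    print(old.sum(dim=-1))  # roughly tensor([[0.49, 0.51]]) -- per-token sums are not 1

    # New behaviour: zeros -> -inf, softmax per token along the last dim, re-zero the masked slots.
    full = scalings.masked_fill(~nonzero_mask, float("-inf"))
    new = torch.softmax(full, dim=-1).masked_fill(~nonzero_mask, 0.0)
    print(new.sum(dim=-1))  # tensor([[1., 1.]]) -- each token's adapter weights sum to 1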

src/peft/tuners/xlora/model.py

@@ -368,6 +368,8 @@ class XLoraModel(BaseTuner):
                 self.lora_model.enable_adapter_layers()
 
             xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)
+            # Store computed scalings to fix get_latest_scalings() returning None
+            self.internal_xlora_scalings = xlora_scalings
 
             # =========================== Real forward pass with calculated scalings ==================
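With the scalings cached on the model, get_latest_scalings() can return the tensor computed during the most recent forward pass instead of None. A usage sketch (xlora_model and tokenizer are placeholder names for an already-built X-LoRA PeftModel and its tokenizer; not part of the commit):

    import torch

    xlora_model.enable_scalings_logging()
    device = next(xlora_model.parameters()).device
    inputs = tokenizer("Python is a", return_tensors="pt").to(device)
    xlora_model.generate(**inputs, max_new_tokens=5)

    latest = xlora_model.get_latest_scalings()  # was None before this fix
    assert latest is not None and torch.isfinite(latest).all()
    print(latest.shape)  # e.g. (batch, seq_len, n_layers, n_adapters)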

tests/test_xlora.py

@@ -23,6 +23,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import LoraConfig, PeftType, TaskType, XLoraConfig, get_peft_model
 from peft.peft_model import PeftModel
+from peft.tuners.xlora.layer import XLoraLayer
 from peft.utils import infer_device
@@ -381,3 +382,45 @@ class TestXlora:
         w1 = sd["base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_A.weight"]
         assert torch.allclose(w0, w1)
+
+    def test_scalings_storage(self, tokenizer, model):
+        model.enable_scalings_logging()
+        inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=10,
+        )
+        latest_scalings = model.get_latest_scalings()
+        assert latest_scalings is not None, "get_latest_scalings() should not return None after generation"
+        assert isinstance(latest_scalings, torch.Tensor)
+        assert torch.isfinite(latest_scalings).all(), "Scalings should contain finite values"
+
+    def test_per_token_normalization_with_softmax_topk(self, tokenizer, model, monkeypatch):
+        model.internal_xlora_classifier.config.top_k_lora = 2
+        model.internal_xlora_classifier.config.enable_softmax = False
+        model.internal_xlora_classifier.config.enable_softmax_topk = True
+
+        captured_data = []
+        orig_get_maybe_topk_scalings = XLoraLayer.get_maybe_topk_scalings
+
+        def mock_get_maybe_topk_scalings(self, scalings):
+            result = orig_get_maybe_topk_scalings(self, scalings)
+            if getattr(model, "internal_xlora_scalings", None) is not None:
+                captured_data.append(result)
+            return result
+
+        monkeypatch.setattr(XLoraLayer, "get_maybe_topk_scalings", mock_get_maybe_topk_scalings)
+
+        model.enable_scalings_logging()
+        inputs = tokenizer.encode("Test per token normalization", add_special_tokens=False, return_tensors="pt")
+        outputs = model.generate(
+            input_ids=inputs.to(self.torch_device),
+            max_new_tokens=1,
+        )
+
+        for scaling in captured_data:
+            weight_sums = scaling.sum(dim=-1)
+            assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), (
+                "Per-token scaling weights are not normalized to sum to 1."
+            )