FIX X-LoRA embed_scale support #2830 (#2831)

Sambhav Dixit
2025-10-14 19:24:15 +05:30
committed by GitHub
parent 9b8cf2a0c3
commit ec5a1b2ce6
2 changed files with 59 additions and 1 deletion

src/peft/tuners/xlora/layer.py

@@ -153,6 +153,10 @@ class XLoraEmbeddingLayer(XLoraLayer):
         result = self.target.base_layer(x, *args, **kwargs)
 
+        # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
+        # Since base_layer(x) already includes this scaling, we need to apply it to X-LoRA contributions too.
+        embed_scale = self.target._get_embed_scale()
+
         # Ignore if disabled. We want to make sure this is always run.
         if not self.target.merged:
             for adapter_n, active_adapter in enumerate(self.target.active_adapters):
@@ -171,7 +175,14 @@ class XLoraEmbeddingLayer(XLoraLayer):
                 else:
                     after_A_mod = after_A
                     scaling_weight = 1
-                result += (after_A_mod @ embedding_B) * scaling * scaling_weight
+                adapter_output = (after_A_mod @ embedding_B) * scaling * scaling_weight
+
+                # Apply embed_scale to match the base layer's scaling
+                if embed_scale is not None:
+                    adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)
+
+                result += adapter_output
 
         return result
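For context on why the adapter delta needs rescaling: embedding classes such as Gemma3TextScaledWordEmbedding multiply the looked-up vectors by a constant embed_scale inside their own forward, so a LoRA contribution computed outside the base layer arrives at a different magnitude unless it is multiplied by the same factor. The sketch below is illustrative only and not part of this commit; the ScaledEmbedding class, its scale value, and the lora_A/lora_B tensors are made-up stand-ins.

import torch
import torch.nn as nn

class ScaledEmbedding(nn.Embedding):
    # Toy stand-in for layers like Gemma3TextScaledWordEmbedding: the scale is
    # applied inside forward, so the base layer's output is already scaled.
    def __init__(self, num_embeddings, embedding_dim, embed_scale=8.0):
        super().__init__(num_embeddings, embedding_dim)
        self.register_buffer("embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids):
        return super().forward(input_ids) * self.embed_scale

base_layer = ScaledEmbedding(100, 16)
lora_A = torch.randn(100, 4) * 0.01  # toy LoRA embedding factors
lora_B = torch.randn(4, 16) * 0.01

ids = torch.arange(10)
result = base_layer(ids)                        # already includes embed_scale
after_A = nn.functional.embedding(ids, lora_A)  # adapter lookup, unscaled
adapter_output = after_A @ lora_B               # LoRA delta, unscaled

# Pre-fix behaviour adds the delta at its unscaled magnitude; the fix brings it
# onto the same scale as the base output before accumulating:
result = result + adapter_output * base_layer.embed_scale.to(adapter_output.dtype)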

tests/test_xlora.py

@@ -26,6 +26,8 @@ from peft.peft_model import PeftModel
 from peft.tuners.xlora.layer import XLoraLayer
 from peft.utils import infer_device
 
+from .testing_utils import hub_online_once
+
 
 def flaky(num_tries: int):
     """Decorator for test functions that are flaky"""
@@ -424,3 +426,48 @@ class TestXlora:
         assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), (
             "Per-token scaling weights are not normalized to sum to 1."
         )
+
+    def test_xlora_embed_scale_is_applied(self, tmp_path):
+        """Test that X-LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        with hub_online_once(model_id):
+            # Create and save Gemma3-compatible LoRA adapters
+            adapters = {}
+            for i in range(2):
+                torch.manual_seed(i + 1)
+                lora_config = LoraConfig(
+                    task_type="CAUSAL_LM", init_lora_weights=False, target_modules=["embed_tokens"]
+                )
+                model = AutoModelForCausalLM.from_pretrained(model_id)
+                peft_model = get_peft_model(model, lora_config)
+                adapter_path = os.path.join(tmp_path, f"checkpoint-{i + 1}")
+                peft_model.save_pretrained(adapter_path)
+                adapters[str(i)] = adapter_path
+
+            # Load base model and test X-LoRA with embed_scale
+            base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
+            base_model.config.use_cache = False
+            orig_embedding = base_model.get_input_embeddings()
+
+            xlora_config = XLoraConfig(
+                task_type=TaskType.CAUSAL_LM,
+                hidden_size=base_model.config.hidden_size,
+                adapters=adapters,
+            )
+            xlora_model = get_peft_model(base_model, xlora_config)
+
+            x = torch.arange(10).to(self.torch_device)
+            xlora_embedding = xlora_model.base_model.model.get_input_embeddings()
+            max_embedding_output = xlora_embedding(x).abs().max(0)[0]
+            assert (max_embedding_output < 100.0).all()
+
+            # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
+            # value
+            orig_embedding.embed_scale.fill_(10000.0)
+            max_embedding_output = xlora_embedding(x).abs().max(0)[0]
+            assert (max_embedding_output > 100.0).all()
+
+            # set embed_scale to zero, then check that the embedding output is also zero
+            orig_embedding.embed_scale.fill_(0)
+            embedding_output = xlora_embedding(x)
+            assert (embedding_output == 0.0).all()
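To exercise just this regression locally, a pytest keyword filter along the lines of `pytest -k test_xlora_embed_scale_is_applied` should select the new test; `-k` matches on the test name, so no further arguments are needed.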