Mirror of https://github.com/huggingface/peft.git, synced 2025-10-20 15:33:48 +08:00
@@ -153,6 +153,10 @@ class XLoraEmbeddingLayer(XLoraLayer):
         result = self.target.base_layer(x, *args, **kwargs)
 
+        # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
+        # Since base_layer(x) already includes this scaling, we need to apply it to X-LoRA contributions too.
+        embed_scale = self.target._get_embed_scale()
+
         # Ignore if disabled. We want to make sure this is always run.
         if not self.target.merged:
             for adapter_n, active_adapter in enumerate(self.target.active_adapters):
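For context, the scaled embedding layers this hunk targets multiply the looked-up vectors by a fixed factor inside their own forward(). A minimal, self-contained sketch of such a layer (hypothetical class name, not the transformers implementation of Gemma3TextScaledWordEmbedding) shows why base_layer(x) already carries the factor while the raw LoRA delta does not:

import torch
import torch.nn as nn

class ScaledEmbeddingSketch(nn.Embedding):
    """Toy stand-in for an embedding that scales its output, like Gemma3's text embedding."""

    def __init__(self, num_embeddings: int, embedding_dim: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim)
        # Stored as a buffer so it moves with the module but is not trained.
        self.register_buffer("embed_scale", torch.tensor(embed_scale))

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # The scaling happens here, so a caller that only sees the output receives
        # embed_scale * E[input_ids], never the raw table lookup.
        return super().forward(input_ids) * self.embed_scale

Because the X-LoRA delta is built directly from the lora_embedding_A/B matrices rather than by calling this forward(), it misses the factor unless the change above multiplies it back in.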
@@ -171,7 +175,14 @@ class XLoraEmbeddingLayer(XLoraLayer):
                 else:
                     after_A_mod = after_A
                     scaling_weight = 1
-                result += (after_A_mod @ embedding_B) * scaling * scaling_weight
+                adapter_output = (after_A_mod @ embedding_B) * scaling * scaling_weight
+
+                # Apply embed_scale to match the base layer's scaling
+                if embed_scale is not None:
+                    adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)
+
+                result += adapter_output
 
         return result
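Concretely, the base path computes embed_scale * E[x], so the adapter contribution has to carry the same factor for the sum to equal embed_scale * (E[x] + delta). A standalone numeric check of that identity (made-up shapes and values, purely illustrative):

import torch

torch.manual_seed(0)
embed_scale = torch.tensor(8.0)
base = torch.randn(4, 16)           # stands in for the base embedding lookup E[x]
delta = torch.randn(4, 16) * 0.01   # stands in for (after_A_mod @ embedding_B) * scaling * scaling_weight

patched = embed_scale * base + embed_scale * delta   # what the patched forward accumulates
expected = embed_scale * (base + delta)              # scaling applied once to the combined embedding
unpatched = embed_scale * base + delta               # pre-fix behaviour: the delta is under-scaled

assert torch.allclose(patched, expected)
assert not torch.allclose(unpatched, expected)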
@@ -26,6 +26,8 @@ from peft.peft_model import PeftModel
 from peft.tuners.xlora.layer import XLoraLayer
 from peft.utils import infer_device
 
+from .testing_utils import hub_online_once
+
 
 def flaky(num_tries: int):
     """Decorator for test functions that are flaky"""
@@ -424,3 +426,48 @@ class TestXlora:
         assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), (
             "Per-token scaling weights are not normalized to sum to 1."
         )
+
+    def test_xlora_embed_scale_is_applied(self, tmp_path):
+        """Test that X-LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        with hub_online_once(model_id):
+            # Create and save Gemma3-compatible LoRA adapters
+            adapters = {}
+            for i in range(2):
+                torch.manual_seed(i + 1)
+                lora_config = LoraConfig(
+                    task_type="CAUSAL_LM", init_lora_weights=False, target_modules=["embed_tokens"]
+                )
+                model = AutoModelForCausalLM.from_pretrained(model_id)
+                peft_model = get_peft_model(model, lora_config)
+                adapter_path = os.path.join(tmp_path, f"checkpoint-{i + 1}")
+                peft_model.save_pretrained(adapter_path)
+                adapters[str(i)] = adapter_path
+
+            # Load base model and test X-LoRA with embed_scale
+            base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
+            base_model.config.use_cache = False
+            orig_embedding = base_model.get_input_embeddings()
+
+            xlora_config = XLoraConfig(
+                task_type=TaskType.CAUSAL_LM,
+                hidden_size=base_model.config.hidden_size,
+                adapters=adapters,
+            )
+            xlora_model = get_peft_model(base_model, xlora_config)
+
+            x = torch.arange(10).to(self.torch_device)
+            xlora_embedding = xlora_model.base_model.model.get_input_embeddings()
+            max_embedding_output = xlora_embedding(x).abs().max(0)[0]
+            assert (max_embedding_output < 100.0).all()
+
+            # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
+            # value
+            orig_embedding.embed_scale.fill_(10000.0)
+            max_embedding_output = xlora_embedding(x).abs().max(0)[0]
+            assert (max_embedding_output > 100.0).all()
+
+            # set embed_scale to zero, then check that the embedding output is also zero
+            orig_embedding.embed_scale.fill_(0)
+            embedding_output = xlora_embedding(x)
+            assert (embedding_output == 0.0).all()
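The test mutates embed_scale in place on the original embedding module, which propagates through the X-LoRA wrapper because the wrapper holds a reference to that same module. The same kind of spot check can be made on any checkpoint to see whether its input embedding exposes an embed_scale buffer at all, and therefore whether this code path is exercised; a small hypothetical snippet, reusing the tiny Gemma3 test model:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM")
emb = model.get_input_embeddings()
# Gemma3's text embedding exposes the factor as `embed_scale`; a plain nn.Embedding does not.
print(type(emb).__name__, getattr(emb, "embed_scale", None))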