FIX Handle embed scale for trainable tokens, LoRA (#2825)
Resolves #2809. Some models, such as Gemma3, apply a scalar to the embedding output. This scale needs to be taken into account when using trainable tokens or LoRA applied to the embedding layer.
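To see why this matters, here is a minimal sketch (the ScaledEmbedding module below is illustrative, not the actual Gemma3 implementation): because the base layer multiplies its lookup output by embed_scale, any adapter delta added on top of base_layer(x) must be multiplied by the same factor, otherwise the unmerged forward pass diverges from the result of merging the delta into the weights.

import torch
import torch.nn as nn

class ScaledEmbedding(nn.Embedding):
    """Toy stand-in for Gemma3TextScaledWordEmbedding: scales the lookup output."""
    def __init__(self, num_embeddings, embedding_dim, embed_scale=16.0):
        super().__init__(num_embeddings, embedding_dim)
        self.register_buffer("embed_scale", torch.tensor(embed_scale))

    def forward(self, x):
        return super().forward(x) * self.embed_scale

base = ScaledEmbedding(10, 4)
delta = torch.randn(10, 4) * 0.01  # stands in for a LoRA delta on the weight
x = torch.arange(10)

# Merged view: the delta is folded into the weight, so it gets scaled too.
merged = (base.weight + delta)[x] * base.embed_scale

# The unmerged forward must therefore scale the adapter contribution as well;
# adding the raw delta would diverge from the merged result.
unmerged = base(x) + delta[x] * base.embed_scale
assert torch.allclose(merged, unmerged, atol=1e-6)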
@@ -22,7 +22,6 @@ from pathlib import Path
 import numpy as np
 import torch
 import torch.utils.checkpoint
 from accelerate import Accelerator
 from diffusers import DDIMScheduler
 from diffusers.utils import check_min_version
@@ -1035,6 +1035,10 @@ class Embedding(nn.Module, LoraLayer):
         # extra argument that allows mixing different adapters in the same batch at inference time.
         result = self.base_layer(x, *args, **kwargs)
 
+        # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
+        # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
+        embed_scale = self._get_embed_scale()
+
         unique_adapters = set(adapter_names)
         sub_batch_indices_list = []
         for adapter in unique_adapters:
@@ -1054,7 +1058,13 @@ class Embedding(nn.Module, LoraLayer):
             # layer output
             sub_batch = x[sub_batch_indices_list[i]]
             after_A = self._embed(sub_batch, embedding_A)
-            result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling
+            adapter_output = (after_A @ embedding_B) * scaling
+
+            # Apply embed_scale to match the base layer's scaling
+            if embed_scale is not None:
+                adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)
+
+            result[sub_batch_indices_list[i]] += adapter_output
 
         return result
 
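As a rough illustration of the sub-batch routing above (a simplified sketch with toy shapes, not the actual PEFT internals): rows of the batch are grouped by adapter name, each group's delta is computed on its own slice, and the shared embed_scale is applied to every slice so each sub-batch stays consistent with the already-scaled base output.

import torch

adapter_names = ["default", "adapter2", "default", "adapter2"]
x = torch.randint(0, 10, (4, 6))   # toy input ids
result = torch.zeros(4, 6, 8)      # stands in for base_layer(x)
embed_scale = torch.tensor(16.0)

# per-adapter toy LoRA deltas over a vocab of 10, embedding dim 8
deltas = {name: torch.randn(10, 8) * 0.01 for name in set(adapter_names)}

for adapter in set(adapter_names):
    idx = [i for i, name in enumerate(adapter_names) if name == adapter]
    sub_batch = x[idx]                              # rows belonging to this adapter
    adapter_output = deltas[adapter][sub_batch]     # stands in for (after_A @ embedding_B) * scaling
    result[idx] += adapter_output * embed_scale     # scale to match the base layer's output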
@@ -1086,6 +1096,11 @@ class Embedding(nn.Module, LoraLayer):
         else:
             result = self.base_layer(x, *args, **kwargs)
             torch_result_dtype = result.dtype
+
+            # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
+            # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too.
+            embed_scale = self._get_embed_scale()
+
             for active_adapter in self.active_adapters:
                 if active_adapter not in self.lora_embedding_A:
                     continue
@@ -1095,7 +1110,13 @@ class Embedding(nn.Module, LoraLayer):
                     embedding_B = self.lora_embedding_B[active_adapter].T
                     scaling = self.scaling[active_adapter]
                     after_A = self._embed(x, embedding_A)
-                    result = result + (after_A @ embedding_B) * scaling
+                    adapter_output = (after_A @ embedding_B) * scaling
+
+                    # Apply embed_scale to match the base layer's scaling
+                    if embed_scale is not None:
+                        adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)
+
+                    result = result + adapter_output
                 else:
                     result = self.lora_variant[active_adapter].forward(
                         self,
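Condensed into a self-contained sketch (toy tensors; only the embed_scale handling mirrors the diff, the rest is simplified), the fix keeps the unmerged forward consistent with merging the LoRA delta into the embedding weight:

import torch
import torch.nn.functional as F

vocab, dim, r = 10, 8, 4
weight = torch.randn(vocab, dim)        # base embedding weight
lora_A = torch.randn(vocab, r) * 0.1    # transposed view, as used by _embed
lora_B = torch.randn(r, dim) * 0.1
scaling = 2.0
embed_scale = torch.tensor(16.0)        # as returned by _get_embed_scale()

x = torch.arange(vocab)
result = F.embedding(x, weight) * embed_scale  # base layer already applies the scale

after_A = F.embedding(x, lora_A)                                        # self._embed(x, embedding_A)
adapter_output = (after_A @ lora_B) * scaling
adapter_output = adapter_output * embed_scale.to(adapter_output.dtype)  # the fix
result = result + adapter_output

# Consistency check against merging the delta into the weight:
merged = F.embedding(x, weight + (lora_A @ lora_B) * scaling) * embed_scale
assert torch.allclose(result, merged, atol=1e-5)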
@@ -232,6 +232,11 @@ class TrainableTokensLayer(nn.Module, BaseTunerLayer):
                 scale_grad_by_freq=self.base_layer.scale_grad_by_freq,
                 sparse=self.base_layer.sparse,
             )
+            # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method.
+            # Since we're using F.embedding directly, we need to apply this scaling manually.
+            embed_scale = self._get_embed_scale()
+            if embed_scale is not None:
+                result = result * embed_scale.to(result.dtype)
         elif isinstance(self.base_layer, torch.nn.Linear):
             # Probably a tied adapter that wraps an LM head.
             result = F.linear(
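A small sketch of why manual scaling is needed on this path (simplified; the real layer replaces rows differently): TrainableTokens redoes the lookup with F.embedding on a weight whose selected rows were overridden, which bypasses the base module's forward and therefore its embed_scale.

import torch
import torch.nn.functional as F

vocab, dim = 10, 8
base_weight = torch.randn(vocab, dim)
embed_scale = torch.tensor(16.0)

token_indices = torch.tensor([0, 1, 3])           # rows overridden by the adapter
trainable_values = torch.randn(len(token_indices), dim)

merged_weight = base_weight.clone()
merged_weight[token_indices] = trainable_values

x = torch.arange(vocab)
result = F.embedding(x, merged_weight)            # no scaling applied here
result = result * embed_scale.to(result.dtype)    # the fix: match the base layer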
@@ -50,6 +50,7 @@ from peft.utils.other import (
     set_additional_trainable_modules,
 )
 from peft.utils.peft_types import PeftType, TaskType
+from peft.utils.warning import PeftWarning
 
 from ..config import PeftConfig
 from ..utils import _get_submodules
@@ -1214,6 +1215,43 @@ class BaseTunerLayer(ABC):
             base_layer = base_layer.base_layer
         return base_layer
 
+    def _get_embed_scale(self):
+        """
+        Extract embed_scale from the base layer if present and valid.
+
+        Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward
+        method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be a
+        scalar. Its shape is validated accordingly.
+
+        Returns:
+            torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise.
+        """
+        base_layer = self.get_base_layer()
+        if not hasattr(base_layer, "embed_scale"):
+            return None
+
+        embed_scale = base_layer.embed_scale
+
+        # Convert scalar values to tensors
+        if isinstance(embed_scale, (int, float)):
+            return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype)
+
+        # Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting
+        if isinstance(embed_scale, torch.Tensor):
+            if embed_scale.numel() == 1:
+                return embed_scale
+            else:
+                # Warn but don't fail - this maintains backward compatibility
+                warnings.warn(
+                    f"Found embed_scale attribute with shape {embed_scale.shape}, expected a scalar. "
+                    "Embedding scaling will not be applied. If this is unexpected, please open an issue at "
+                    "https://github.com/huggingface/peft/issues",
+                    PeftWarning,
+                )
+                return None
+
+        return None
+
     @property
     def weight(self) -> torch.Tensor:
         # This is required for some transformers code, e.g. for T5, weight is accessed as:
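To make the helper's contract concrete, a hypothetical illustration of what _get_embed_scale would return for various base layers (toy modules, not from the commit):

import torch
import torch.nn as nn

plain = nn.Embedding(10, 8)
# no embed_scale attribute -> _get_embed_scale returns None

float_scaled = nn.Embedding(10, 8)
float_scaled.embed_scale = 16.0
# python float -> converted to a tensor on the weight's device and dtype

tensor_scaled = nn.Embedding(10, 8)
tensor_scaled.embed_scale = torch.tensor(16.0)
# 0-d tensor, numel() == 1 -> returned as-is

bad = nn.Embedding(10, 8)
bad.embed_scale = torch.ones(8)
# numel() != 1 -> PeftWarning is emitted and None is returned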
@@ -796,6 +796,66 @@ class TestDecoderModels(PeftCommonTester):
         else:
             assert not contains_embedding
 
+    def test_lora_embed_scale_is_applied(self):
+        """Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3)."""
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device)
+            orig_embedding = base_model.get_input_embeddings()
+
+            peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
+            peft_model = get_peft_model(base_model, peft_config)
+
+            x = torch.arange(10).to(self.torch_device)
+            peft_embedding = peft_model.base_model.model.get_input_embeddings()
+            embedding_output = peft_embedding(x)
+            max_embedding_output = embedding_output.abs().max(0)[0]
+            assert (max_embedding_output < 100.0).all()
+            peft_model.merge_adapter()
+            embedding_merged = peft_embedding(x)
+            assert torch.allclose(embedding_output, embedding_merged)
+            peft_model.unmerge_adapter()
+
+            # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
+            # value
+            orig_embedding.embed_scale.fill_(10000.0)
+            max_embedding_output = peft_embedding(x).abs().max(0)[0]
+            assert (max_embedding_output > 100.0).all()
+
+            # set embed_scale to zero, then check that the embedding output is also zero
+            orig_embedding.embed_scale.fill_(0)
+            embedding_output = peft_embedding(x)
+            assert (embedding_output == 0.0).all()
+
+    def test_lora_embed_scale_is_applied_mixed_batch(self):
+        """Test that LoRA correctly handles embeddings with scaling in mixed batch mode."""
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            orig_embedding = base_model.get_input_embeddings()
+
+            peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False)
+            peft_model = get_peft_model(base_model, peft_config)
+            peft_model.add_adapter("adapter2", peft_config)
+
+            # sanity check: with the default embed_scale, the embedding output should be reasonably sized
+            peft_embedding = peft_model.base_model.model.get_input_embeddings()
+            input_ids = torch.arange(10).unsqueeze(0).repeat(2, 1)
+            adapter_names = ["default", "adapter2"]
+            max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
+            assert max_embedding_output < 100.0
+
+            # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
+            # value
+            orig_embedding.embed_scale.fill_(10000.0)
+            max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max()
+            assert max_embedding_output > 100.0
+
+            # set embed_scale to zero, then check that the embedding output is also zero
+            orig_embedding.embed_scale.fill_(0)
+            embedding_output = peft_embedding(input_ids, adapter_names=adapter_names)
+            assert (embedding_output == 0.0).all()
+
     @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
     def test_set_requires_grad_prompt_learning_raises(self, config_cls, config_kwargs):
         # Test that for prompt learning, calling set_requires_grad raises an error with an appropriate error message.
@@ -918,3 +918,63 @@ class TestTrainableTokens:
             assert contains_embedding
         else:
             assert not contains_embedding
+
+    def test_embed_scale_is_applied(self):
+        """Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3)."""
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            orig_embedding = base_model.get_input_embeddings()
+
+            peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3])
+            peft_model = get_peft_model(base_model, peft_config)
+
+            # sanity check: with the default embed_scale, the embedding output should be reasonably sized
+            peft_embedding = peft_model.base_model.model.get_input_embeddings()
+            max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
+            assert (max_embedding_output < 100.0).all()
+
+            # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
+            # value
+            orig_embedding.embed_scale.fill_(10000.0)
+            max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0]
+            assert (max_embedding_output > 100.0).all()
+
+            # set embed_scale to zero, then check that the embedding output is also zero
+            orig_embedding.embed_scale.fill_(0)
+            embedding_output = peft_embedding(torch.arange(10))
+            assert (embedding_output == 0.0).all()
+
+    def test_scaled_embedding_with_lora(self):
+        """
+        Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously.
+        """
+        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
+        with hub_online_once(model_id):
+            base_model = AutoModelForCausalLM.from_pretrained(model_id)
+            orig_embedding = base_model.get_input_embeddings()
+
+            # Apply both TrainableTokens and LoRA to the same model
+            peft_config = LoraConfig(target_modules=["q_proj"], trainable_token_indices={"embed_tokens": [0, 1, 3]})
+            peft_model = get_peft_model(base_model, peft_config)
+
+            x = torch.arange(10)
+            peft_embedding = peft_model.base_model.model.get_input_embeddings()
+            embedding_output = peft_embedding(x)
+            max_embedding_output = embedding_output.abs().max(0)[0]
+            assert (max_embedding_output < 100.0).all()
+            peft_model.merge_adapter()
+            embedding_merged = peft_embedding(x)
+            assert torch.allclose(embedding_output, embedding_merged)
+            peft_model.unmerge_adapter()
+
+            # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high
+            # value
+            orig_embedding.embed_scale.fill_(10000.0)
+            max_embedding_output = peft_embedding(x).abs().max(0)[0]
+            assert (max_embedding_output > 100.0).all()
+
+            # set embed_scale to zero, then check that the embedding output is also zero
+            orig_embedding.embed_scale.fill_(0)
+            embedding_output = peft_embedding(x)
+            assert (embedding_output == 0.0).all()