diff --git a/examples/boft_controlnet/test_controlnet.py b/examples/boft_controlnet/test_controlnet.py index 2080deb0..9624b7c3 100644 --- a/examples/boft_controlnet/test_controlnet.py +++ b/examples/boft_controlnet/test_controlnet.py @@ -22,7 +22,6 @@ from pathlib import Path import numpy as np import torch -import torch.utils.checkpoint from accelerate import Accelerator from diffusers import DDIMScheduler from diffusers.utils import check_min_version diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index b01e87e6..a338fac0 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -1035,6 +1035,10 @@ class Embedding(nn.Module, LoraLayer): # extra argument that allows mixing different adapters in the same batch at inference time. result = self.base_layer(x, *args, **kwargs) + # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method. + # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too. + embed_scale = self._get_embed_scale() + unique_adapters = set(adapter_names) sub_batch_indices_list = [] for adapter in unique_adapters: @@ -1054,7 +1058,13 @@ class Embedding(nn.Module, LoraLayer): # layer output sub_batch = x[sub_batch_indices_list[i]] after_A = self._embed(sub_batch, embedding_A) - result[sub_batch_indices_list[i]] += (after_A @ embedding_B) * scaling + adapter_output = (after_A @ embedding_B) * scaling + + # Apply embed_scale to match the base layer's scaling + if embed_scale is not None: + adapter_output = adapter_output * embed_scale.to(adapter_output.dtype) + + result[sub_batch_indices_list[i]] += adapter_output return result @@ -1086,6 +1096,11 @@ class Embedding(nn.Module, LoraLayer): else: result = self.base_layer(x, *args, **kwargs) torch_result_dtype = result.dtype + + # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method. + # Since base_layer(x) already includes this scaling, we need to apply it to LoRA contributions too. + embed_scale = self._get_embed_scale() + for active_adapter in self.active_adapters: if active_adapter not in self.lora_embedding_A: continue @@ -1095,7 +1110,13 @@ class Embedding(nn.Module, LoraLayer): embedding_B = self.lora_embedding_B[active_adapter].T scaling = self.scaling[active_adapter] after_A = self._embed(x, embedding_A) - result = result + (after_A @ embedding_B) * scaling + adapter_output = (after_A @ embedding_B) * scaling + + # Apply embed_scale to match the base layer's scaling + if embed_scale is not None: + adapter_output = adapter_output * embed_scale.to(adapter_output.dtype) + + result = result + adapter_output else: result = self.lora_variant[active_adapter].forward( self, diff --git a/src/peft/tuners/trainable_tokens/layer.py b/src/peft/tuners/trainable_tokens/layer.py index 0f354622..da955a68 100644 --- a/src/peft/tuners/trainable_tokens/layer.py +++ b/src/peft/tuners/trainable_tokens/layer.py @@ -232,6 +232,11 @@ class TrainableTokensLayer(nn.Module, BaseTunerLayer): scale_grad_by_freq=self.base_layer.scale_grad_by_freq, sparse=self.base_layer.sparse, ) + # Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling in their forward method. + # Since we're using F.embedding directly, we need to apply this scaling manually. + embed_scale = self._get_embed_scale() + if embed_scale is not None: + result = result * embed_scale.to(result.dtype) elif isinstance(self.base_layer, torch.nn.Linear): # Probably a tied adapter that wraps an LM head. result = F.linear( diff --git a/src/peft/tuners/tuners_utils.py b/src/peft/tuners/tuners_utils.py index 66903296..f41d11d2 100644 --- a/src/peft/tuners/tuners_utils.py +++ b/src/peft/tuners/tuners_utils.py @@ -50,6 +50,7 @@ from peft.utils.other import ( set_additional_trainable_modules, ) from peft.utils.peft_types import PeftType, TaskType +from peft.utils.warning import PeftWarning from ..config import PeftConfig from ..utils import _get_submodules @@ -1214,6 +1215,43 @@ class BaseTunerLayer(ABC): base_layer = base_layer.base_layer return base_layer + def _get_embed_scale(self): + """ + Extract embed_scale from base layer if present and valid. + + Some embedding layers (e.g., Gemma3TextScaledWordEmbedding) apply scaling to embeddings in their forward + method. This method checks for the presence of an `embed_scale` attribute. If it exists, it is assumed to be a + scalar. Its shape is validated accordingly. + + Returns: + torch.Tensor or None: The embed_scale tensor if found and valid, None otherwise. + """ + base_layer = self.get_base_layer() + if not hasattr(base_layer, "embed_scale"): + return None + + embed_scale = base_layer.embed_scale + + # Convert scalar values to tensors + if isinstance(embed_scale, (int, float)): + return torch.tensor(embed_scale, device=base_layer.weight.device, dtype=base_layer.weight.dtype) + + # Validate tensor shape - must be scalar (0-d) or 1-element tensor for proper broadcasting + if isinstance(embed_scale, torch.Tensor): + if embed_scale.numel() == 1: + return embed_scale + else: + # Log warning but don't fail - this maintains backward compatibility + warnings.warn( + f"Found embed_scale attribute with shape {embed_scale.shape}, expected scalar. " + "Embedding scaling will not be applied. If this is unexpected, please open an issue at " + "https://github.com/huggingface/peft/issues", + PeftWarning, + ) + return None + + return None + @property def weight(self) -> torch.Tensor: # This is required for some transformers code, e.g. for T5, weight is accessed as: diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index c0e37101..ed7ff2b9 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -796,6 +796,66 @@ class TestDecoderModels(PeftCommonTester): else: assert not contains_embedding + def test_lora_embed_scale_is_applied(self): + """Test that LoRA correctly handles embeddings with scaling (e.g., Gemma3).""" + model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" + with hub_online_once(model_id): + base_model = AutoModelForCausalLM.from_pretrained(model_id).to(self.torch_device) + orig_embedding = base_model.get_input_embeddings() + + peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False) + peft_model = get_peft_model(base_model, peft_config) + + x = torch.arange(10).to(self.torch_device) + peft_embedding = peft_model.base_model.model.get_input_embeddings() + embedding_output = peft_embedding(x) + max_embedding_output = embedding_output.abs().max(0)[0] + assert (max_embedding_output < 100.0).all() + peft_model.merge_adapter() + embedding_merged = peft_embedding(x) + assert torch.allclose(embedding_output, embedding_merged) + peft_model.unmerge_adapter() + + # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high + # value + orig_embedding.embed_scale.fill_(10000.0) + max_embedding_output = peft_embedding(x).abs().max(0)[0] + assert (max_embedding_output > 100.0).all() + + # set embed_scale to zero, then check that the embedding output is also zero + orig_embedding.embed_scale.fill_(0) + embedding_output = peft_embedding(x) + assert (embedding_output == 0.0).all() + + def test_lora_embed_scale_is_applied_mixed_batch(self): + """Test that LoRA correctly handles embeddings with scaling in mixed batch mode.""" + model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" + with hub_online_once(model_id): + base_model = AutoModelForCausalLM.from_pretrained(model_id) + orig_embedding = base_model.get_input_embeddings() + + peft_config = LoraConfig(target_modules=["embed_tokens"], init_lora_weights=False) + peft_model = get_peft_model(base_model, peft_config) + peft_model.add_adapter("adapter2", peft_config) + + # sanity check: with the default embed_scale, the embedding output should be reasonably sized + peft_embedding = peft_model.base_model.model.get_input_embeddings() + input_ids = torch.arange(10).unsqueeze(0).repeat(2, 1) + adapter_names = ["default", "adapter2"] + max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max() + assert max_embedding_output < 100.0 + + # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high + # value + orig_embedding.embed_scale.fill_(10000.0) + max_embedding_output = peft_embedding(input_ids, adapter_names=adapter_names).abs().max() + assert max_embedding_output > 100.0 + + # set embed_scale to zero, then check that the embedding output is also zero + orig_embedding.embed_scale.fill_(0) + embedding_output = peft_embedding(input_ids, adapter_names=adapter_names) + assert (embedding_output == 0.0).all() + @pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS) def test_set_requires_grad_prompt_learning_raises(self, config_cls, config_kwargs): # Test that for prompt learning, calling set_requires_grad raises an error with an appropriate error message. diff --git a/tests/test_trainable_tokens.py b/tests/test_trainable_tokens.py index 38b32b06..a642fe54 100644 --- a/tests/test_trainable_tokens.py +++ b/tests/test_trainable_tokens.py @@ -918,3 +918,63 @@ class TestTrainableTokens: assert contains_embedding else: assert not contains_embedding + + def test_embed_scale_is_applied(self): + """Test that TrainableTokens correctly handles embeddings with scaling (e.g., Gemma3).""" + model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" + with hub_online_once(model_id): + base_model = AutoModelForCausalLM.from_pretrained(model_id) + orig_embedding = base_model.get_input_embeddings() + + peft_config = TrainableTokensConfig(target_modules=["embed_tokens"], token_indices=[0, 1, 3]) + peft_model = get_peft_model(base_model, peft_config) + + # sanity check: with the default embed_scale, the embedding output should be reasonably sized + peft_embedding = peft_model.base_model.model.get_input_embeddings() + max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + assert (max_embedding_output < 100.0).all() + + # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high + # value + orig_embedding.embed_scale.fill_(10000.0) + max_embedding_output = peft_embedding(torch.arange(10)).abs().max(0)[0] + assert (max_embedding_output > 100.0).all() + + # set embed_scale to zero, then check that the embedding output is also zero + orig_embedding.embed_scale.fill_(0) + embedding_output = peft_embedding(torch.arange(10)) + assert (embedding_output == 0.0).all() + + def test_scaled_embedding_with_lora(self): + """ + Test that TrainableTokens works with LoRA on scaled embeddings when both are active simultaneously. + """ + model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM" + with hub_online_once(model_id): + base_model = AutoModelForCausalLM.from_pretrained(model_id) + orig_embedding = base_model.get_input_embeddings() + + # Apply both TrainableTokens and LoRA to the same model + peft_config = LoraConfig(target_modules=["q_proj"], trainable_token_indices={"embed_tokens": [0, 1, 3]}) + peft_model = get_peft_model(base_model, peft_config) + + x = torch.arange(10) + peft_embedding = peft_model.base_model.model.get_input_embeddings() + embedding_output = peft_embedding(x) + max_embedding_output = embedding_output.abs().max(0)[0] + assert (max_embedding_output < 100.0).all() + peft_model.merge_adapter() + embedding_merged = peft_embedding(x) + assert torch.allclose(embedding_output, embedding_merged) + peft_model.unmerge_adapter() + + # set embed_scale to an absurdly high value, then check that the embedding output is also scaled to a high + # value + orig_embedding.embed_scale.fill_(10000.0) + max_embedding_output = peft_embedding(x).abs().max(0)[0] + assert (max_embedding_output > 100.0).all() + + # set embed_scale to zero, then check that the embedding output is also zero + orig_embedding.embed_scale.fill_(0) + embedding_output = peft_embedding(x) + assert (embedding_output == 0.0).all()