Compare commits


3 Commits

SHA1 Message Date
f8aca0a0c2 ENH Merging LoRAs supports negative weights (#2811)
Until now, add_weighted_adapter implicitly assumed that all weights
are positive. This PR allows negative weights to be passed.

---------

Co-authored-by: Valentin Teutschbein <valentin.teutschbein@student.hpi.uni-potsdam.de>
2025-10-09 13:53:08 +02:00
e9f5707e3f FIX X-LoRA scaling storage and per token normalization (#2793) 2025-10-09 13:36:54 +02:00
2c29cf7936 ENH Add sample vocab init to PromptEmbedding (#2815) 2025-10-09 12:21:40 +02:00
9 changed files with 201 additions and 6 deletions

View File

@@ -540,7 +540,8 @@ class LoraModel(BaseTuner):
adapters (`list`):
List of adapter names to be merged.
weights (`list`):
- List of weights for each adapter.
+ List of weights for each adapter. Weights can be positive or negative, allowing for both addition and
+ subtraction of adapter effects.
adapter_name (`str`):
Name of the new adapter.
combination_type (`str`):
@@ -742,7 +743,10 @@ class LoraModel(BaseTuner):
current_adapter_lora_B = target.lora_embedding_B[adapter]
else:
continue
- valid_weights.append(math.sqrt(weight * target.scaling[adapter]))
+ # Support negative weights: take absolute value for sqrt, then apply sign
+ weight_with_scaling = weight * target.scaling[adapter]
+ sign = 1 if weight_with_scaling >= 0 else -1
+ valid_weights.append(sign * math.sqrt(abs(weight_with_scaling)))
lora_A_deltas.append(current_adapter_lora_A.data)
lora_B_deltas.append(current_adapter_lora_B.data)
valid_weights = torch.tensor(valid_weights).to(lora_A_deltas[0].device)
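
A minimal usage sketch of the change above, assuming a PEFT build that includes this PR; the toy MLP module, adapter names, and weights are illustrative and not part of the diff. Passing a negative weight to `add_weighted_adapter` subtracts that adapter's effect instead of adding it, as the updated docstring describes.

```python
import torch
from torch import nn

from peft import LoraConfig, get_peft_model


class MLP(nn.Module):
    # Toy model with a single LoRA-targeted linear layer.
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 10)

    def forward(self, x):
        return self.lin0(x)


config = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
model = get_peft_model(MLP(), config, adapter_name="adapter1")
model.add_adapter("adapter2", config)

# Negative weights are now accepted; adapter2's effect is subtracted here.
model.add_weighted_adapter(
    adapters=["adapter1", "adapter2"],
    weights=[1.0, -0.5],
    adapter_name="combined",
    combination_type="linear",
)
model.set_adapter("combined")
output = model(torch.randn(2, 10))
```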

View File

@@ -22,6 +22,7 @@ from peft.utils import PeftType
class PromptTuningInit(str, enum.Enum):
TEXT = "TEXT"
SAMPLE_VOCAB = "SAMPLE_VOCAB"
RANDOM = "RANDOM"
@@ -31,7 +32,10 @@ class PromptTuningConfig(PromptLearningConfig):
This is the configuration class to store the configuration of a [`PromptEmbedding`].
Args:
- prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding.
+ prompt_tuning_init (Union[[`PromptTuningInit`], `str`]):
+ The initialization of the prompt embedding. `TEXT` will initialize with your text. `SAMPLE_VOCAB` will
+ initialize with randomly sampled tokens from the model's vocabulary. `RANDOM` will initialize with randomly
+ sampled continuous, soft tokens (warning: sampled soft tokens may fall outside of embedding manifold)
prompt_tuning_init_text (`str`, *optional*):
The text to initialize the prompt embedding. Only used if `prompt_tuning_init` is `TEXT`.
tokenizer_name_or_path (`str`, *optional*):
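
For reference, a minimal configuration sketch for the new `SAMPLE_VOCAB` option; the "gpt2" checkpoint and token count are arbitrary examples, and a PEFT build containing this PR is assumed. The prompt embedding is initialized from randomly sampled rows of the model's input embedding matrix.

```python
from transformers import AutoModelForCausalLM

from peft import PromptTuningConfig, PromptTuningInit, TaskType, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.SAMPLE_VOCAB,  # new option added in this PR
    num_virtual_tokens=20,
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()
```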

View File

@@ -64,7 +64,18 @@ class PromptEmbedding(torch.nn.Module):
total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules
self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim)
- if config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode:
+ if config.prompt_tuning_init == PromptTuningInit.SAMPLE_VOCAB and not config.inference_mode:
+ # Randomly sample tokens from the tokenizer's vocab
+ vocab_size = word_embeddings.num_embeddings
+ init_token_ids = torch.randint(0, vocab_size, (total_virtual_tokens,), dtype=torch.long).to(
+ word_embeddings.weight.device
+ )
+ with gather_params_ctx(word_embeddings.parameters()):
+ word_embedding_weights = word_embeddings(init_token_ids).detach().clone()
+ word_embedding_weights = word_embedding_weights.to(torch.float32)
+ self.embedding.weight = torch.nn.Parameter(word_embedding_weights)
+ elif config.prompt_tuning_init == PromptTuningInit.TEXT and not config.inference_mode:
from transformers import AutoTokenizer
tokenizer_kwargs = config.tokenizer_kwargs or {}
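
Conceptually, the new `SAMPLE_VOCAB` branch boils down to the following standalone sketch (toy sizes and plain `torch.nn.Embedding` objects rather than PEFT internals): sample random token ids, look up their embeddings, and use those vectors as the initial soft prompt.

```python
import torch

vocab_size, token_dim, total_virtual_tokens = 100, 16, 8
word_embeddings = torch.nn.Embedding(vocab_size, token_dim)
prompt_embedding = torch.nn.Embedding(total_virtual_tokens, token_dim)

# Sample token ids uniformly from the vocabulary and reuse their embedding
# vectors as the starting point of the soft prompt.
init_token_ids = torch.randint(0, vocab_size, (total_virtual_tokens,), dtype=torch.long)
init_weights = word_embeddings(init_token_ids).detach().clone().to(torch.float32)
prompt_embedding.weight = torch.nn.Parameter(init_weights)
```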

View File

@@ -73,10 +73,12 @@ class XLoraLayer:
xlora_scalings = xlora_scalings * mask.to(xlora_scalings.dtype)
# Apply per-token normalization to the xLoRA scaling factors using a softmax
if self.config.enable_softmax_topk:
nonzero_mask = xlora_scalings != 0
- softmax_res_nonzero = torch.softmax(xlora_scalings[nonzero_mask], dim=-1)
- xlora_scalings[nonzero_mask] = softmax_res_nonzero
+ full = xlora_scalings.masked_fill(~nonzero_mask, float("-inf"))
+ new_scalings = torch.softmax(full, dim=-1)
+ xlora_scalings = new_scalings.masked_fill(~nonzero_mask, 0.0)
return xlora_scalings
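
The effect of this fix can be checked in isolation with a toy tensor (shapes and values below are illustrative): entries pruned by top-k stay at zero, and the surviving scalings of each token are renormalized to sum to 1. The previous code applied the softmax to the flattened nonzero entries, which normalized across all tokens at once rather than per token.

```python
import torch

# One batch, two tokens, three adapters; zeros mark entries pruned by top-k.
xlora_scalings = torch.tensor([[[0.2, 0.0, 1.3],
                                [0.0, 0.5, 0.4]]])
nonzero_mask = xlora_scalings != 0
full = xlora_scalings.masked_fill(~nonzero_mask, float("-inf"))
new_scalings = torch.softmax(full, dim=-1).masked_fill(~nonzero_mask, 0.0)

# Each token's nonzero scalings now sum to 1.
assert torch.allclose(new_scalings.sum(dim=-1), torch.ones(1, 2))
```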

View File

@@ -368,6 +368,8 @@ class XLoraModel(BaseTuner):
self.lora_model.enable_adapter_layers()
xlora_scalings = self.internal_xlora_classifier(result=base_output, *args_real, **kwargs_real)
# Store computed scalings to fix get_latest_scalings() returning None
self.internal_xlora_scalings = xlora_scalings
# =========================== Real forward pass with calculated scalings ==================

View File

@@ -2916,6 +2916,118 @@ class TestPeftCustomModel(PeftCommonTester):
["default", "other"], weights=[1.0, 1.0], adapter_name="merged", combination_type="cat"
)
def test_add_weighted_adapter_negative_weight_negates_adapter(self):
# Test that weight=-1.0 properly negates an adapter
torch.manual_seed(42)
model = MLP()
config = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
model = get_peft_model(model, config, adapter_name="adapter1")
# Create merged adapter with weight=1.0
model.add_weighted_adapter(
adapters=["adapter1"],
weights=[1.0],
adapter_name="merged_positive",
combination_type="linear",
)
# Create merged adapter with weight=-1.0
model.add_weighted_adapter(
adapters=["adapter1"],
weights=[-1.0],
adapter_name="merged_negative",
combination_type="linear",
)
# Get the LoRA weights for comparison
for name, module in model.named_modules():
if hasattr(module, "lora_A") and "merged_positive" in module.lora_A:
pos_A = module.lora_A["merged_positive"].weight.data
neg_A = module.lora_A["merged_negative"].weight.data
pos_B = module.lora_B["merged_positive"].weight.data
neg_B = module.lora_B["merged_negative"].weight.data
# Check that negative adapter is negation of positive
# Since we apply sign to both A and B: sign * sqrt(|w|)
# For w=1: sqrt(1) = 1, for w=-1: -sqrt(1) = -1
assert torch.allclose(neg_A, -pos_A, atol=1e-6), "A matrices should be negated"
assert torch.allclose(neg_B, -pos_B, atol=1e-6), "B matrices should be negated"
def test_add_weighted_adapter_subtraction_with_negative_weights(self):
# Test that merging two identical adapters with weights [1.0, -1.0] results in approximately zero weights
model = MLP()
config = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
# Create two identical adapters by using the same seed
torch.manual_seed(42)
model = get_peft_model(model, config, adapter_name="adapter1")
torch.manual_seed(42)
model.add_adapter("adapter2", config)
# Merge with weights [1.0, -1.0] - should cancel out exactly
model.add_weighted_adapter(
adapters=["adapter1", "adapter2"],
weights=[1.0, -1.0],
adapter_name="cancelled",
combination_type="linear",
)
# Verify the merged adapter has weights of approximately 0
for name, module in model.named_modules():
if hasattr(module, "lora_A") and "cancelled" in module.lora_A:
cancelled_A = module.lora_A["cancelled"].weight.data
cancelled_B = module.lora_B["cancelled"].weight.data
# The weights should be approximately zero (they cancel out)
assert torch.allclose(cancelled_A, torch.zeros_like(cancelled_A), atol=1e-5), (
f"Cancelled A should be ~0, got max abs value {cancelled_A.abs().max()}"
)
assert torch.allclose(cancelled_B, torch.zeros_like(cancelled_B), atol=1e-5), (
f"Cancelled B should be ~0, got max abs value {cancelled_B.abs().max()}"
)
def test_add_weighted_adapter_negative_weight_with_different_scaling(self):
# Test negative weights with different scaling factors (lora_alpha)
# This edge case ensures negative weights work correctly with different scaling values
torch.manual_seed(42)
model = MLP()
# Create two configs with different lora_alpha (different scaling factors)
config1 = LoraConfig(
r=8,
lora_alpha=16, # scaling = 16/8 = 2
target_modules=["lin0"],
lora_dropout=0.0,
bias="none",
init_lora_weights=False,
)
config2 = LoraConfig(
r=8,
lora_alpha=32, # scaling = 32/8 = 4
target_modules=["lin0"],
lora_dropout=0.0,
bias="none",
init_lora_weights=False,
)
model = get_peft_model(model, config1, adapter_name="adapter1")
model.add_adapter("adapter2", config2)
# Merge with negative weight - should handle different scalings correctly
model.add_weighted_adapter(
adapters=["adapter1", "adapter2"],
weights=[0.5, -0.3],
adapter_name="merged_diff_scaling",
combination_type="linear",
)
# Verify the merged adapter can run forward pass
model.set_adapter("merged_diff_scaling")
dummy_input = torch.randn(2, 10)
output = model(dummy_input)
assert output is not None
def test_multiple_adapters_no_needless_copy_modules_to_save(self):
# See 2206
# The problem was that we keep a "global" modules_to_save on the model which contains all possible

View File

@@ -380,6 +380,18 @@ class TestDecoderModels(PeftCommonTester):
expected_call = call(model_id, trust_remote_code=True, foo="bar")
assert mock.call_args == expected_call
@pytest.mark.parametrize("model_id", PEFT_DECODER_MODELS_TO_TEST)
@pytest.mark.parametrize("config_cls,config_kwargs", ALL_CONFIGS)
def test_prompt_tuning_sample_vocab_prepare_for_training(self, model_id, config_cls, config_kwargs):
if config_cls != PromptTuningConfig:
pytest.skip(f"This test does not apply to {config_cls}")
config_kwargs = config_kwargs.copy()
config_kwargs["prompt_tuning_init"] = PromptTuningInit.SAMPLE_VOCAB
config_kwargs["tokenizer_name_or_path"] = model_id
self._test_prepare_for_training(model_id, config_cls, config_kwargs.copy())
def test_prompt_tuning_config_invalid_args(self):
# Raise an error when tokenizer_kwargs is used with prompt_tuning_init!='TEXT', because this argument has no
# function in that case

View File

@@ -23,6 +23,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftType, TaskType, XLoraConfig, get_peft_model
from peft.peft_model import PeftModel
from peft.tuners.xlora.layer import XLoraLayer
from peft.utils import infer_device
@@ -381,3 +382,45 @@ class TestXlora:
w1 = sd["base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_A.weight"]
assert torch.allclose(w0, w1)
def test_scalings_storage(self, tokenizer, model):
model.enable_scalings_logging()
inputs = tokenizer.encode("Python is a", add_special_tokens=False, return_tensors="pt")
outputs = model.generate(
input_ids=inputs.to(self.torch_device),
max_new_tokens=10,
)
latest_scalings = model.get_latest_scalings()
assert latest_scalings is not None, "get_latest_scalings() should not return None after generation"
assert isinstance(latest_scalings, torch.Tensor)
assert torch.isfinite(latest_scalings).all(), "Scalings should contain finite values"
def test_per_token_normalization_with_softmax_topk(self, tokenizer, model, monkeypatch):
model.internal_xlora_classifier.config.top_k_lora = 2
model.internal_xlora_classifier.config.enable_softmax = False
model.internal_xlora_classifier.config.enable_softmax_topk = True
captured_data = []
orig_get_maybe_topk_scalings = XLoraLayer.get_maybe_topk_scalings
def mock_get_maybe_topk_scalings(self, scalings):
result = orig_get_maybe_topk_scalings(self, scalings)
if getattr(model, "internal_xlora_scalings", None) is not None:
captured_data.append(result)
return result
monkeypatch.setattr(XLoraLayer, "get_maybe_topk_scalings", mock_get_maybe_topk_scalings)
model.enable_scalings_logging()
inputs = tokenizer.encode("Test per token normalization", add_special_tokens=False, return_tensors="pt")
outputs = model.generate(
input_ids=inputs.to(self.torch_device),
max_new_tokens=1,
)
for scaling in captured_data:
weight_sums = scaling.sum(dim=-1)
assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), (
"Per-token scaling weights are not normalized to sum to 1."
)

View File

@@ -1775,6 +1775,8 @@ class PeftCommonTester:
if "single" in adapter_name:
new_delta_weight = target.get_delta_weight(adapter_name)
weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0]
sign = 1 if weight_list[0] > 0 else -1
weighted_original_delta_weights = sign * weighted_original_delta_weights
assert torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4)
elif "svd" in adapter_name:
assert target.r[adapter_name] == 20
@@ -1831,6 +1833,7 @@
adapter_list = ["adapter1", "adapter_2", "adapter_3"]
weight_list = [0.5, 1.5, 1.5]
negative_weight_list = [-0.5, -0.8, -1.2]
# Initialize the config
config = config_cls(
base_model_name_or_path=model_id,
@@ -1847,8 +1850,10 @@
if isinstance(config, LoraConfig):
self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, weight_list)
self._test_weighted_combination_of_adapters_lora(model, config, adapter_list, negative_weight_list)
elif isinstance(config, IA3Config):
self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, weight_list)
self._test_weighted_combination_of_adapters_ia3(model, config, adapter_list, negative_weight_list)
else:
pytest.skip(f"Test not applicable for {config}")