[Misc] Embedding model support LoRA (#14935)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

Author: Jee Jee Li
Date: 2025-03-18 20:07:00 +08:00
Committed by: GitHub
Parent: f863ffc965
Commit: db7c8ca910
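
Not part of the diff below: a minimal offline-inference sketch of the usage this change is meant to enable, i.e. running an embedding (pooling) model with a LoRA adapter. The model name, adapter name, adapter path, and max_lora_rank value are placeholders chosen for illustration, not taken from this commit.

    from vllm import LLM
    from vllm.lora.request import LoRARequest

    # Placeholder embedding model and adapter; substitute your own.
    llm = LLM(
        model="BAAI/bge-multilingual-gemma2",
        task="embed",
        enable_lora=True,
        max_lora_rank=32,
    )

    outputs = llm.embed(
        ["What is the capital of France?"],
        lora_request=LoRARequest("my_embed_adapter", 1, "/path/to/lora_adapter"),
    )
    print(len(outputs[0].outputs.embedding))  # embedding dimension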


@@ -30,6 +30,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                             is_regex_target_modules,
                             parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.utils import is_pin_memory_available
@@ -104,6 +105,9 @@ class LoRAModel(AdapterModel):
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
@@ -335,6 +339,7 @@ class LoRAModelManager(AdapterModelManager):
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        self.supported_lora_modules = get_supported_lora_modules(self.model)
        assert self.supported_lora_modules, (
            "No supported LoRA modules found in "
            f"{self.model.__class__.__name__}.")
@@ -350,6 +355,7 @@ class LoRAModelManager(AdapterModelManager):
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.is_pooling_model = is_pooling_model(self.model)
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, BaseLayerWithLoRA] = {}
        # Dict instead of a Set for compatibility with LRUCache.
@@ -389,7 +395,7 @@ class LoRAModelManager(AdapterModelManager):
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
                # Bias is not explicitly enabled with the flag enable_lora_bias.
@@ -626,7 +632,7 @@ class LoRAModelManager(AdapterModelManager):
            replaced_module: Set[str] = set()
            has_replacement = False
            for r in new_module_names:
                lora = lora_model.get_lora(r)
                lora = self._get_lora_layer_weights(lora_model, r)
                replacement_loras.append(lora)
                if lora:
                    has_replacement = True
@@ -637,12 +643,34 @@ class LoRAModelManager(AdapterModelManager):
                if replacement_loras[i]:
                    continue
                replacement_loras[i] = None
            # HACK: temporary solution for pooling models.
            if self.is_pooling_model and not lora_model.check_lora_name(
                    module_name):
                replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(replaced_module_name):
                    module_name = replaced_module_name
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)
            # Remove the modules that have been replaced.
            for module in replaced_module:
                lora_model.loras.pop(module, None)

    def _get_lora_layer_weights(
            self, lora_model: LoRAModel,
            module_name: str) -> Optional[LoRALayerWeights]:
        org_module_name = module_name
        if self.is_pooling_model and not lora_model.check_lora_name(
                module_name):
            # If it's a pooling model and the layer name is not found,
            # remove the prefix 'model.' and search again.
            module_name = module_name.replace("model.", "")
            if lora_model.check_lora_name(module_name):
                org_module_name = module_name
                logger.info_once(
                    "For the pooling model, successfully loaded the LoRA "
                    "weights after removing the prefix 'model.'.")
        return lora_model.get_lora(org_module_name)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)
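
The _get_lora_layer_weights helper added above deals with a naming mismatch: for pooling models the runtime module names carry a leading 'model.' prefix that the keys stored in a LoRA checkpoint may lack, so a failed lookup is retried with the prefix stripped. A standalone sketch of that lookup idea, using hypothetical names rather than the vLLM API:

    from typing import Dict, Optional

    def lookup_lora_weights(loras: Dict[str, str],
                            module_name: str) -> Optional[str]:
        # Direct hit: the adapter uses the same module naming as the runtime model.
        if module_name in loras:
            return loras[module_name]
        # Pooling-model fallback: drop the leading "model." wrapper and retry.
        return loras.get(module_name.replace("model.", ""))

    # The wrapped pooling model exposes "model.layers.0.self_attn.q_proj",
    # while the adapter stores its weights under "layers.0.self_attn.q_proj".
    adapter = {"layers.0.self_attn.q_proj": "q_proj LoRA weights"}
    print(lookup_lora_weights(adapter, "model.layers.0.self_attn.q_proj"))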