mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Misc] Embedding model support LoRA (#14935)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@ -30,6 +30,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||
is_regex_target_modules,
|
||||
parse_fine_tuned_lora_name, replace_submodule)
|
||||
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
||||
from vllm.model_executor.models.interfaces import is_pooling_model
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
|
||||
from vllm.utils import is_pin_memory_available
|
||||
@ -104,6 +105,9 @@ class LoRAModel(AdapterModel):
|
||||
"""Get LoRA for a given module by name"""
|
||||
return self.loras.get(module_name, None)
|
||||
|
||||
def check_lora_name(self, lora_name: str) -> bool:
|
||||
return lora_name in self.loras
|
||||
|
||||
# (yard1): TODO see if we can derive target_embedding_padding automatically
|
||||
@classmethod
|
||||
def from_lora_tensors(
|
||||
@ -335,6 +339,7 @@ class LoRAModelManager(AdapterModelManager):
|
||||
# Used for long context lora.
|
||||
self.scaling_factor_to_offset: Dict[float, int] = {}
|
||||
super().__init__(model)
|
||||
|
||||
self.supported_lora_modules = get_supported_lora_modules(self.model)
|
||||
assert self.supported_lora_modules, "No supported LoRA modules found in"
|
||||
f"{self.model.__class__.__name__}."
|
||||
@ -350,6 +355,7 @@ class LoRAModelManager(AdapterModelManager):
|
||||
# In case the model only supports LoRA for
|
||||
# text modules (e.g. ChatGLM)
|
||||
and hasattr(self.model, "get_mm_mapping"))
|
||||
self.is_pooling_model = is_pooling_model(self.model)
|
||||
self.packed_modules: Dict[str, List[str]] = {}
|
||||
self.modules: Dict[str, BaseLayerWithLoRA] = {}
|
||||
# Dict instead of a Set for compatibility with LRUCache.
|
||||
@ -389,7 +395,7 @@ class LoRAModelManager(AdapterModelManager):
|
||||
lora_model.id, index)
|
||||
self.lora_index_to_id[index] = lora_model.id
|
||||
for module_name, module in self.modules.items():
|
||||
module_lora = lora_model.get_lora(module_name)
|
||||
module_lora = self._get_lora_layer_weights(lora_model, module_name)
|
||||
if module_lora:
|
||||
module_lora.optimize()
|
||||
# Bias is not explicitly enabled with the flag enable_lora_bias.
|
||||
@ -626,7 +632,7 @@ class LoRAModelManager(AdapterModelManager):
|
||||
replaced_module: Set[str] = set()
|
||||
has_replacement = False
|
||||
for r in new_module_names:
|
||||
lora = lora_model.get_lora(r)
|
||||
lora = self._get_lora_layer_weights(lora_model, r)
|
||||
replacement_loras.append(lora)
|
||||
if lora:
|
||||
has_replacement = True
|
||||
@ -637,12 +643,34 @@ class LoRAModelManager(AdapterModelManager):
|
||||
if replacement_loras[i]:
|
||||
continue
|
||||
replacement_loras[i] = None
|
||||
# HACK Temporary solution for the pool model.
|
||||
if self.is_pooling_model and not lora_model.check_lora_name(
|
||||
module_name):
|
||||
replaced_module_name = module_name.replace("model.", "")
|
||||
if lora_model.check_lora_name(module_name):
|
||||
module_name = replaced_module_name
|
||||
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
|
||||
replacement_loras)
|
||||
# Remove the modules that have been replaced.
|
||||
for module in replaced_module:
|
||||
lora_model.loras.pop(module, None)
|
||||
|
||||
def _get_lora_layer_weights(
|
||||
self, lora_model: LoRAModel,
|
||||
module_name: str) -> Optional[LoRALayerWeights]:
|
||||
org_module_name = module_name
|
||||
if self.is_pooling_model and not lora_model.check_lora_name(
|
||||
module_name):
|
||||
# If it's a pool model, and the layer name is not found,
|
||||
# remove the prefix 'model.' and search again.
|
||||
module_name = module_name.replace("model.", "")
|
||||
if lora_model.check_lora_name(module_name):
|
||||
org_module_name = module_name
|
||||
logger.info_once(
|
||||
"For the pool model, successfully loaded the LoRA weights "
|
||||
"after removing the prefix 'model.'.")
|
||||
return lora_model.get_lora(org_module_name)
|
||||
|
||||
def deactivate_adapter(self, adapter_id: int) -> bool:
|
||||
return deactivate_adapter(adapter_id, self._active_adapters,
|
||||
self._deactivate_adapter)
|
||||
|
Reference in New Issue
Block a user