mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Core][LoRA][1/N] Add LoRA for EncoderDecoderModelRunner (#15990)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@ -866,6 +866,11 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
|
||||
and len(packed_modules_list) == 3)
|
||||
|
||||
|
||||
#TODO: Implement this
|
||||
class QKVCrossParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
pass
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
|
||||
def __init__(self, base_layer: RowParallelLinear) -> None:
|
||||
|
@ -52,6 +52,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
@ -1181,6 +1182,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
super().__init__()
|
||||
config: MllamaConfig = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.hidden_size = config.text_config.hidden_size
|
||||
@ -1517,6 +1519,15 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
updated_params.add(name)
|
||||
return updated_params
|
||||
|
||||
def get_mm_mapping(self) -> MultiModelKeys:
|
||||
"""
|
||||
Get the module prefix in multimodal models
|
||||
"""
|
||||
return MultiModelKeys.from_string_field(
|
||||
language_model="language_model",
|
||||
connector="multi_modal_projector",
|
||||
tower_model="vision_model")
|
||||
|
||||
|
||||
def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
|
||||
for mask in sparse_mask:
|
||||
|
@ -16,6 +16,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.inputs import INPUT_REGISTRY, InputRegistry
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
|
||||
@ -34,6 +35,7 @@ from vllm.worker.model_runner_base import (
|
||||
from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
|
||||
|
||||
logger = init_logger(__name__)
|
||||
LORA_WARMUP_RANK = 8
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -160,7 +162,11 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
if num_steps > 1:
|
||||
raise ValueError("num_steps > 1 is not supported in "
|
||||
"EncoderDecoderModelRunner")
|
||||
|
||||
if self.lora_config:
|
||||
assert model_input.lora_requests is not None
|
||||
assert model_input.lora_mapping is not None
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
if (model_input.attn_metadata is not None
|
||||
and model_input.attn_metadata.prefill_metadata is None
|
||||
and model_input.attn_metadata.decode_metadata.use_cuda_graph):
|
||||
@ -268,6 +274,22 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
|
||||
max_num_seqs = self.scheduler_config.max_num_seqs
|
||||
|
||||
# This represents the maximum number of different requests
|
||||
# that will have unique loras, and therefore the max amount of
|
||||
# memory consumption. Create dummy lora request copies from the
|
||||
# lora request passed in, which contains a lora from the lora
|
||||
# warmup path.
|
||||
dummy_lora_requests: List[LoRARequest] = []
|
||||
dummy_lora_requests_per_seq: List[LoRARequest] = []
|
||||
if self.lora_config:
|
||||
dummy_lora_requests = self._add_dummy_loras(
|
||||
self.lora_config.max_loras)
|
||||
assert len(dummy_lora_requests) == self.lora_config.max_loras
|
||||
dummy_lora_requests_per_seq = [
|
||||
dummy_lora_requests[idx % len(dummy_lora_requests)]
|
||||
for idx in range(max_num_seqs)
|
||||
]
|
||||
|
||||
# Profile memory usage with max_num_sequences sequences and the total
|
||||
# number of tokens equal to max_num_batched_tokens.
|
||||
seqs: List[SequenceGroupMetadata] = []
|
||||
@ -315,6 +337,8 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
block_tables=None,
|
||||
encoder_seq_data=encoder_dummy_data.seq_data,
|
||||
cross_block_table=None,
|
||||
lora_request=dummy_lora_requests_per_seq[group_id]
|
||||
if dummy_lora_requests_per_seq else None,
|
||||
multi_modal_data=decoder_dummy_data.multi_modal_data
|
||||
or encoder_dummy_data.multi_modal_data,
|
||||
multi_modal_placeholders=decoder_dummy_data.
|
||||
|
Reference in New Issue
Block a user