From fe8d7b6f03e7d8a36ffb6931397fc81ee594dd64 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 27 Aug 2025 21:41:22 +0800 Subject: [PATCH] [Model] Interface to enable batch-level DP support (#23733) Signed-off-by: DarkLight1337 Signed-off-by: Cyrus Leung Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- docs/configuration/optimization.md | 7 +++++-- vllm/config/__init__.py | 7 +++++++ vllm/model_executor/models/interfaces.py | 11 +++++++++++ vllm/model_executor/models/minicpmv.py | 2 ++ vllm/model_executor/models/mllama4.py | 2 ++ vllm/model_executor/models/qwen2_5_vl.py | 2 ++ vllm/model_executor/models/registry.py | 9 +++++++-- vllm/model_executor/models/step3_vl.py | 2 ++ 8 files changed, 38 insertions(+), 4 deletions(-) diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index a8eab9985c..b11ccb5c00 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -168,8 +168,11 @@ llm = LLM( Batch-level DP is not to be confused with API request-level DP (which is instead controlled by `data_parallel_size`). -The availability of batch-level DP is based on model implementation. -Currently, the following models support `mm_encoder_tp_mode="data"`: +Batch-level DP needs to be implemented on a per-model basis, +and enabled by setting `supports_encoder_tp_data = True` in the model class. +Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature. + +Known supported models: - Llama4 () - MiniCPM-V-4 () diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index ac6f51df95..e3fb6d796d 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -872,6 +872,13 @@ class ModelConfig: def _init_multimodal_config(self) -> Optional["MultiModalConfig"]: if self._model_info.supports_multimodal: + if (self.mm_encoder_tp_mode == "data" and + not self._model_info.supports_multimodal_encoder_tp_data): + logger.warning_once( + "This model does not support `--mm-encoder-tp-mode data`. " + "Falling back to `--mm-encoder-tp-mode weights`.") + self.mm_encoder_tp_mode = "weights" + return MultiModalConfig( limit_per_prompt=self.limit_mm_per_prompt, media_io_kwargs=self.media_io_kwargs, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 22f005849e..506732fed3 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ + supports_encoder_tp_data: ClassVar[bool] = False + """ + A flag that indicates whether this model supports + `multimodal_config.mm_encoder_tp_mode="data"`. 
+ """ + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: """ @@ -137,6 +143,11 @@ def supports_multimodal( return getattr(model, "supports_multimodal", False) +def supports_multimodal_encoder_tp_data( + model: Union[type[object], object]) -> bool: + return getattr(model, "supports_encoder_tp_data", False) + + @runtime_checkable class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol): """The interface required for all multi-modal models.""" diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2d785c30fd..0181bfeebd 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -1521,6 +1521,8 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA): ], } + supports_encoder_tp_data = True + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__(vllm_config=vllm_config, prefix=prefix) assert self.version == (4, 0) diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py index 595bdd17cf..ac9b968f7a 100644 --- a/vllm/model_executor/models/mllama4.py +++ b/vllm/model_executor/models/mllama4.py @@ -716,6 +716,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, "gate_up_proj": ["gate_proj", "up_proj"], } + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 648ba81eb3..b528083b7c 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -868,6 +868,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, "model.": "language_model.model.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"): diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 196b5f35e1..80eac78cdf 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -27,8 +27,10 @@ from vllm.transformers_utils.dynamic_module import ( from .interfaces import (has_inner_state, has_noops, is_attention_free, is_hybrid, supports_cross_encoding, - supports_multimodal, supports_multimodal_raw_input, - supports_pp, supports_transcription, supports_v0_only) + supports_multimodal, + supports_multimodal_encoder_tp_data, + supports_multimodal_raw_input, supports_pp, + supports_transcription, supports_v0_only) from .interfaces_base import (get_default_pooling_type, is_pooling_model, is_text_generation_model) @@ -324,6 +326,7 @@ class _ModelInfo: supports_cross_encoding: bool supports_multimodal: bool supports_multimodal_raw_input: bool + supports_multimodal_encoder_tp_data: bool supports_pp: bool has_inner_state: bool is_attention_free: bool @@ -343,6 +346,8 @@ class _ModelInfo: supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_multimodal_raw_input=supports_multimodal_raw_input(model), + supports_multimodal_encoder_tp_data= + supports_multimodal_encoder_tp_data(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), is_attention_free=is_attention_free(model), diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index f8877b584b..f379d2c15f 100644 --- 
a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -867,6 +867,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, "lm_head.": "language_model.lm_head.", }) + supports_encoder_tp_data = True + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]: if modality.startswith("image"):
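
For reference, below is a minimal usage sketch of the interface this patch introduces. Only `supports_encoder_tp_data`, the `supports_multimodal_encoder_tp_data()` helper, `mm_encoder_tp_mode="data"`, and the warning-based fallback come from the patch itself; the checkpoint name, `tensor_parallel_size=8`, and the assertion are illustrative assumptions rather than part of the change.

```python
from vllm import LLM
from vllm.model_executor.models.interfaces import (
    supports_multimodal_encoder_tp_data)
from vllm.model_executor.models.qwen2_5_vl import (
    Qwen2_5_VLForConditionalGeneration)

# The registry records the per-class `supports_encoder_tp_data` flag via this
# helper; Qwen2.5-VL opts in through the one-line class attribute added above.
assert supports_multimodal_encoder_tp_data(Qwen2_5_VLForConditionalGeneration)

# User side: request batch-level DP for the multimodal encoder. If the selected
# model class does not set `supports_encoder_tp_data = True`, ModelConfig logs
# a warning and falls back to `mm_encoder_tp_mode="weights"` instead of failing.
llm = LLM(
    model="Qwen/Qwen2.5-VL-72B-Instruct",  # illustrative checkpoint
    tensor_parallel_size=8,                # illustrative GPU count
    mm_encoder_tp_mode="data",
)
```

Because the config-level check downgrades rather than errors, `mm_encoder_tp_mode="data"` is safe to pass for any multimodal model: classes that have not opted in simply keep the default weight-parallel encoder.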