[Model] Interface to enable batch-level DP support (#23733)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Cyrus Leung authored 2025-08-27 21:41:22 +08:00, committed by GitHub
commit fe8d7b6f03 (parent 16dc4052b0)
8 changed files with 38 additions and 4 deletions


@@ -168,8 +168,11 @@ llm = LLM(

 Batch-level DP is not to be confused with API request-level DP
 (which is instead controlled by `data_parallel_size`).

-The availability of batch-level DP is based on model implementation.
-Currently, the following models support `mm_encoder_tp_mode="data"`:
+Batch-level DP needs to be implemented on a per-model basis,
+and enabled by setting `supports_encoder_tp_data = True` in the model class.
+Regardless, you need to set `mm_encoder_tp_mode="data"` in engine arguments to use this feature.
+
+Known supported models:

 - Llama4 (<gh-pr:18368>)
 - MiniCPM-V-4 (<gh-pr:23327>)
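For reference, a minimal sketch of how the documented engine argument is passed via the `vllm.LLM` entrypoint shown earlier in this file; the model ID and TP size below are illustrative, not recommendations:

```python
from vllm import LLM

# Illustrative values: any model whose class sets
# `supports_encoder_tp_data = True` (e.g. Llama4, MiniCPM-V-4) qualifies.
llm = LLM(
    model="openbmb/MiniCPM-V-4",  # example model ID; substitute your own
    tensor_parallel_size=4,       # language-model weights sharded across 4 GPUs
    mm_encoder_tp_mode="data",    # encoder weights replicated; the encoder batch is split across ranks instead
)
```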


@@ -872,6 +872,13 @@ class ModelConfig:
    def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
        if self._model_info.supports_multimodal:
            if (self.mm_encoder_tp_mode == "data" and
                    not self._model_info.supports_multimodal_encoder_tp_data):
                logger.warning_once(
                    "This model does not support `--mm-encoder-tp-mode data`. "
                    "Falling back to `--mm-encoder-tp-mode weights`.")
                self.mm_encoder_tp_mode = "weights"

            return MultiModalConfig(
                limit_per_prompt=self.limit_mm_per_prompt,
                media_io_kwargs=self.media_io_kwargs,
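From the user's side, the guard above behaves roughly as sketched below; the model name is a placeholder for any multi-modal model whose class does not set the flag:

```python
from vllm import LLM

# Placeholder model name: stands in for a multi-modal model whose class
# does not declare `supports_encoder_tp_data = True`.
llm = LLM(
    model="org/some-multimodal-model",
    mm_encoder_tp_mode="data",
)
# Expected behavior per the hunk above: a one-time warning is logged,
#   "This model does not support `--mm-encoder-tp-mode data`. "
#   "Falling back to `--mm-encoder-tp-mode weights`."
# and the engine continues with mm_encoder_tp_mode="weights" instead of raising.
```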


@@ -52,6 +52,12 @@ class SupportsMultiModal(Protocol):
    MRO of your model class.
    """

    supports_encoder_tp_data: ClassVar[bool] = False
    """
    A flag that indicates whether this model supports
    `multimodal_config.mm_encoder_tp_mode="data"`.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """
@@ -137,6 +143,11 @@ def supports_multimodal(
    return getattr(model, "supports_multimodal", False)


def supports_multimodal_encoder_tp_data(
        model: Union[type[object], object]) -> bool:
    return getattr(model, "supports_encoder_tp_data", False)


@runtime_checkable
class SupportsMultiModalWithRawInput(SupportsMultiModal, Protocol):
    """The interface required for all multi-modal models."""


@@ -1521,6 +1521,8 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
        ],
    }

    supports_encoder_tp_data = True

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        assert self.version == (4, 0)


@@ -716,6 +716,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):


@@ -868,6 +868,8 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
            "model.": "language_model.model.",
        })

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):


@@ -27,8 +27,10 @@ from vllm.transformers_utils.dynamic_module import (
 from .interfaces import (has_inner_state, has_noops, is_attention_free,
                          is_hybrid, supports_cross_encoding,
-                         supports_multimodal, supports_multimodal_raw_input,
-                         supports_pp, supports_transcription, supports_v0_only)
+                         supports_multimodal,
+                         supports_multimodal_encoder_tp_data,
+                         supports_multimodal_raw_input, supports_pp,
+                         supports_transcription, supports_v0_only)
 from .interfaces_base import (get_default_pooling_type, is_pooling_model,
                               is_text_generation_model)
@@ -324,6 +326,7 @@ class _ModelInfo:
    supports_cross_encoding: bool
    supports_multimodal: bool
    supports_multimodal_raw_input: bool
    supports_multimodal_encoder_tp_data: bool
    supports_pp: bool
    has_inner_state: bool
    is_attention_free: bool
@@ -343,6 +346,8 @@ class _ModelInfo:
            supports_cross_encoding=supports_cross_encoding(model),
            supports_multimodal=supports_multimodal(model),
            supports_multimodal_raw_input=supports_multimodal_raw_input(model),
            supports_multimodal_encoder_tp_data=
            supports_multimodal_encoder_tp_data(model),
            supports_pp=supports_pp(model),
            has_inner_state=has_inner_state(model),
            is_attention_free=is_attention_free(model),
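A condensed sketch of the registry pattern above: capability flags are snapshotted into a plain dataclass once, so config-time checks such as `ModelConfig._init_multimodal_config` do not have to re-inspect the model class. The `Toy`-prefixed names are illustrative only:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyModelInfo:
    supports_multimodal: bool
    supports_multimodal_encoder_tp_data: bool

    @classmethod
    def from_model_cls(cls, model_cls: type) -> "ToyModelInfo":
        # Mirror the duck-typed helpers: absent attributes default to False.
        return cls(
            supports_multimodal=getattr(model_cls, "supports_multimodal", False),
            supports_multimodal_encoder_tp_data=getattr(
                model_cls, "supports_encoder_tp_data", False),
        )
```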


@@ -867,6 +867,8 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
            "lm_head.": "language_model.lm_head.",
        })

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):