[CLI][Doc] Formalize --mm-encoder-tp-mode (#23190)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -129,6 +129,51 @@ Data parallelism replicates the entire model across multiple GPU sets and proces
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.

### Batch-level DP for Multi-Modal Encoders

By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
in order to reduce the memory and compute load on each GPU.

However, since multi-modal encoders are very small compared to language decoders,
there is relatively little gain from TP. On the other hand, TP incurs significant communication
overhead because an all-reduce is performed after every layer.

Given this, it may be advantageous to instead shard the batched input data across the TP ranks,
essentially performing batch-level DP. This has been shown to improve throughput by around 10% for
`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
batch-level DP can provide another 40% increase to throughput compared to regular TP.
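
To make this concrete, here is a minimal, self-contained sketch of the batch-level DP idea (illustrative only, not vLLM internals): each TP rank keeps a full copy of the encoder and encodes its own slice of the batch, so per-layer all-reduces are replaced by a single gather of the outputs.

```python
# Conceptual sketch of batch-level DP (illustrative only, not vLLM code):
# every rank holds the full encoder weights and encodes a slice of the batch.

def shard_batch(batch: list, tp_size: int) -> list[list]:
    """Split a batch into one contiguous shard per TP rank (ceil division)."""
    per_rank = (len(batch) + tp_size - 1) // tp_size
    return [batch[i * per_rank:(i + 1) * per_rank] for i in range(tp_size)]

images = [f"image_{i}" for i in range(10)]  # stand-ins for encoder inputs
for rank, shard in enumerate(shard_batch(images, tp_size=4)):
    # In vLLM this would be the full vision encoder running on each rank;
    # here we only show which inputs each rank would process.
    print(f"rank {rank} encodes {shard}")
```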

Nevertheless, since the weights of the multi-modal encoder are replicated across each TP rank,
there is a minor increase in memory consumption, which may cause OOM if the model already barely fits.

You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-72B-Instruct",
    # Create two EngineCore instances, one per DP rank
    data_parallel_size=2,
    # Within each EngineCore instance:
    # The vision encoder uses TP=4 (not DP=2) to shard the input data
    # The language decoder uses TP=4 to shard the weights as usual
    tensor_parallel_size=4,
    mm_encoder_tp_mode="data",
)
```
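
When launching a server, the same option is exposed as the `--mm-encoder-tp-mode` CLI flag that this PR formalizes, e.g. `vllm serve Qwen/Qwen2.5-VL-72B-Instruct --tensor-parallel-size 4 --data-parallel-size 2 --mm-encoder-tp-mode data`.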

!!! important
    Batch-level DP is not to be confused with API request-level DP
    (which is instead controlled by `data_parallel_size`).

    The availability of batch-level DP depends on the model implementation.
    Currently, the following models support `mm_encoder_tp_mode="data"`:

    - Llama4 (<gh-pr:18368>)
    - Qwen2.5-VL (<gh-pr:22742>)
    - Step3 (<gh-pr:22697>)
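
Support is wired up per model. As the model diffs later in this commit show, each implementation simply derives a boolean from the resolved mode; below is a simplified, self-contained sketch of that gating pattern (the `MultiModalConfig` here is a toy stand-in, not the real vLLM class):

```python
from dataclasses import dataclass
from typing import Literal

MMEncoderTPMode = Literal["weights", "data"]


@dataclass
class MultiModalConfig:  # toy stand-in for vllm.config.MultiModalConfig
    mm_encoder_tp_mode: MMEncoderTPMode = "weights"


# The gating pattern used in each supporting model's __init__:
multimodal_config = MultiModalConfig(mm_encoder_tp_mode="data")
use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
print(use_data_parallel)  # True -> the vision encoder runs batch-level DP
```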

## Input Processing

### Parallel Processing

@@ -258,6 +258,7 @@ TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
                       "processed_logits"]
MMEncoderTPMode = Literal["weights", "data"]


@config

@@ -438,6 +439,19 @@ class ModelConfig:
    `mm_processor_cache_gb * (api_server_count + data_parallel_size)`.

    Set to `0` to disable this cache completely (not recommended)."""
    mm_encoder_tp_mode: MMEncoderTPMode = "weights"
    """Indicates how to optimize multi-modal encoder inference using
    tensor parallelism (TP).

    - `"weights"`: Within the same vLLM engine, split the weights of
        each layer across TP ranks. (default TP behavior)
    - `"data"`: Within the same vLLM engine, split the batched input data
        across TP ranks to process the data in parallel, while hosting
        the full weights on each TP rank.
        This batch-level DP is not to be confused with API request-level
        DP (which is controlled by `--data-parallel-size`).
        This is only supported on a per-model basis and falls back to
        `"weights"` if the encoder does not support DP."""
    override_neuron_config: dict[str, Any] = field(default_factory=dict)
    """Initialize non-default neuron config or override default neuron config
    that are specific to Neuron devices, this argument will be used to

@@ -856,8 +870,10 @@ class ModelConfig:
            media_io_kwargs=self.media_io_kwargs,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            interleave_mm_strings=self.interleave_mm_strings,
            skip_mm_profiling=self.skip_mm_profiling)
            skip_mm_profiling=self.skip_mm_profiling,
        )

        return None

@@ -2547,6 +2563,22 @@ class MultiModalConfig:
    Set to `0` to disable this cache completely (not recommended).
    """

    mm_encoder_tp_mode: MMEncoderTPMode = "weights"
    """
    Indicates how to optimize multi-modal encoder inference using
    tensor parallelism (TP).

    - `"weights"`: Within the same vLLM engine, split the weights of
        each layer across TP ranks. (default TP behavior)
    - `"data"`: Within the same vLLM engine, split the batched input data
        across TP ranks to process the data in parallel, while hosting
        the full weights on each TP rank.
        This batch-level DP is not to be confused with API request-level
        DP (which is controlled by `--data-parallel-size`).
        This is only supported on a per-model basis and falls back to
        `"weights"` if the encoder does not support DP.
    """

    interleave_mm_strings: bool = False
    """
    Enable fully interleaved support for multimodal prompts.

@@ -137,10 +137,6 @@ class ParallelConfig:
    rank: int = 0
    """Global rank in distributed setup."""

    enable_multimodal_encoder_data_parallel: bool = False
    """ Use data parallelism instead of tensor parallelism for vision encoder.
    Only support LLama4 for now"""

    @property
    def world_size_across_dp(self) -> int:
        """world_size_across_dp is TPxPPxDP, it is the size of the world

@@ -28,12 +28,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                         DeviceConfig, DistributedExecutorBackend,
                         GuidedDecodingBackend, HfOverrides, KVEventsConfig,
                         KVTransferConfig, LoadConfig, LogprobsMode,
                         LoRAConfig, MambaDType, ModelConfig, ModelDType,
                         ModelImpl, MultiModalConfig, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
                         RunnerOption, SchedulerConfig, SchedulerPolicy,
                         SpeculativeConfig, TaskOption, TokenizerMode,
                         VllmConfig, get_attr_docs, get_field)
                         LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
                         ModelDType, ModelImpl, MultiModalConfig,
                         ObservabilityConfig, ParallelConfig, PoolerConfig,
                         PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
                         SchedulerPolicy, SpeculativeConfig, TaskOption,
                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins

@@ -352,6 +352,7 @@ class EngineArgs:
        MultiModalConfig.mm_processor_kwargs
    disable_mm_preprocessor_cache: bool = False  # DEPRECATED
    mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
    mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
    skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
    # LoRA fields
    enable_lora: bool = False

@@ -434,16 +435,14 @@ class EngineArgs:
    use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
    pt_load_map_location: str = LoadConfig.pt_load_map_location

    enable_multimodal_encoder_data_parallel: bool = \
        ParallelConfig.enable_multimodal_encoder_data_parallel
    # DEPRECATED
    enable_multimodal_encoder_data_parallel: bool = False

    logits_processors: Optional[list[Union[
        str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
    """Custom logitproc types"""

    async_scheduling: bool = SchedulerConfig.async_scheduling
    # DEPRECATED
    enable_prompt_adapter: bool = False

    kv_sharing_fast_prefill: bool = \
        CacheConfig.kv_sharing_fast_prefill

@@ -685,7 +684,8 @@ class EngineArgs:
            **parallel_kwargs["worker_extension_cls"])
        parallel_group.add_argument(
            "--enable-multimodal-encoder-data-parallel",
            **parallel_kwargs["enable_multimodal_encoder_data_parallel"])
            action="store_true",
            deprecated=True)

        # KV cache arguments
        cache_kwargs = get_kwargs(CacheConfig)

@@ -735,6 +735,8 @@ class EngineArgs:
        multimodal_group.add_argument("--disable-mm-preprocessor-cache",
                                      action="store_true",
                                      deprecated=True)
        multimodal_group.add_argument(
            "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"])
        multimodal_group.add_argument(
            "--interleave-mm-strings",
            **multimodal_kwargs["interleave_mm_strings"])
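
This hunk registers the new `--mm-encoder-tp-mode` flag, while the old flag is kept only as a deprecated alias that maps onto `mm_encoder_tp_mode="data"` (see the next hunk). At the Python API level the migration is one field rename; a hedged sketch of the equivalence, using only the `EngineArgs` fields visible in this diff:

```python
from vllm.engine.arg_utils import EngineArgs

# Deprecated spelling (warns, removed in v0.13):
old_args = EngineArgs(model="Qwen/Qwen2.5-VL-72B-Instruct",
                      enable_multimodal_encoder_data_parallel=True)

# Formalized spelling introduced by this PR:
new_args = EngineArgs(model="Qwen/Qwen2.5-VL-72B-Instruct",
                      mm_encoder_tp_mode="data")
```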
@@ -909,6 +911,14 @@ class EngineArgs:

        self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

        if self.enable_multimodal_encoder_data_parallel:
            logger.warning(
                "`--enable-multimodal-encoder-data-parallel` is deprecated "
                "and will be removed in v0.13. "
                "Please use `--mm-encoder-tp-mode data` instead.")

            self.mm_encoder_tp_mode = "data"

        return ModelConfig(
            model=self.model,
            hf_config_path=self.hf_config_path,

@@ -947,6 +957,7 @@ class EngineArgs:
            config_format=self.config_format,
            mm_processor_kwargs=self.mm_processor_kwargs,
            mm_processor_cache_gb=self.mm_processor_cache_gb,
            mm_encoder_tp_mode=self.mm_encoder_tp_mode,
            override_neuron_config=self.override_neuron_config,
            override_pooler_config=self.override_pooler_config,
            logits_processor_pattern=self.logits_processor_pattern,

@@ -1258,8 +1269,6 @@ class EngineArgs:
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
            worker_extension_cls=self.worker_extension_cls,
            enable_multimodal_encoder_data_parallel=self.
            enable_multimodal_encoder_data_parallel,
        )

        if model_config.is_multimodal_model:

@@ -728,8 +728,8 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.use_data_parallel = (vllm_config.parallel_config.
                                  enable_multimodal_encoder_data_parallel)
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config

@@ -877,8 +877,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.use_data_parallel = (vllm_config.parallel_config.
                                  enable_multimodal_encoder_data_parallel)
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.multimodal_config = multimodal_config

@@ -882,8 +882,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,

        self.config = config
        self.multimodal_config = multimodal_config
        self.use_data_parallel = (vllm_config.parallel_config.
                                  enable_multimodal_encoder_data_parallel)
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

        if multimodal_config.get_limit_per_prompt("image"):
            self.vision_model = Step3VisionTransformer(