[Distributed] Add enable_expert_parallel arg (#14305)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Author: Tyler Michael Smith
Date: 2025-03-06 13:54:45 -05:00
Committed by: GitHub
Parent: cd579352bf
Commit: cc2f9b32c8
5 changed files with 27 additions and 21 deletions

examples/offline_inference/data_parallel.py

@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # usage:
-# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
-#     python examples/offline_inference/data_parallel.py
+# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
 # we need to have a launcher to create multiple data parallel
 # ranks. And each rank will create a vLLM instance to process its own prompts.
 import os
@@ -55,7 +54,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
     # Create an LLM.
     llm = LLM(model="ibm-research/PowerMoE-3b",
               tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True)
+              enforce_eager=True,
+              enable_expert_parallel=True)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
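With this change, the data-parallel example opts into expert parallelism through the LLM constructor instead of the removed VLLM_TEST_ENABLE_EP environment variable. A minimal single-process sketch of the same API surface (model name taken from the example above; the two-GPU sizing is illustrative):

# Hedged sketch of the new keyword argument; assumes an MoE checkpoint and
# at least two GPUs. Mirrors the example file changed above.
from vllm import LLM, SamplingParams

llm = LLM(model="ibm-research/PowerMoE-3b",
          tensor_parallel_size=2,
          enforce_eager=True,
          enable_expert_parallel=True)  # new flag: shard experts, not MoE weights
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
for output in outputs:
    print(output.outputs[0].text)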

vllm/config.py

@@ -754,7 +754,7 @@ class ModelConfig:
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")

-        if envs.VLLM_TEST_ENABLE_EP:
+        if parallel_config.enable_expert_parallel:
             self._verify_with_expert_parallelism()

         pipeline_parallel_size = parallel_config.pipeline_parallel_size
@@ -1334,6 +1334,7 @@ class ParallelConfig:
     # IP of the data parallel master.
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
+    enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.

     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
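The flag now lives on ParallelConfig, so the ModelConfig check above reads ordinary config state instead of an environment variable. A minimal sketch, assuming ParallelConfig keeps dataclass-style defaults for its other fields:

# Hedged illustration: construct the config directly and read the new field.
from vllm.config import ParallelConfig

pc = ParallelConfig(tensor_parallel_size=2, enable_expert_parallel=True)
if pc.enable_expert_parallel:
    # Same gate as the _verify_with_expert_parallelism() call above.
    print("MoE layers will shard experts (EP) rather than TP-shard their weights.")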

vllm/engine/arg_utils.py

@@ -114,6 +114,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
@@ -440,6 +441,11 @@ class EngineArgs:
             type=int,
             default=EngineArgs.tensor_parallel_size,
             help='Number of tensor parallel replicas.')
+        parser.add_argument(
+            '--enable-expert-parallel',
+            action='store_true',
+            help='Use expert parallelism instead of tensor parallelism '
+            'for MoE layers.')
         parser.add_argument(
             '--max-parallel-loading-workers',
             type=int,
@@ -1207,6 +1213,7 @@ class EngineArgs:
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             tokenizer_pool_config=TokenizerPoolConfig.create_config(
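End to end, the same switch is exposed as --enable-expert-parallel on the command line and as a keyword on EngineArgs; create_engine_config() is assumed here to carry it into the ParallelConfig built above. A rough sketch:

# Programmatic equivalent of passing --enable-expert-parallel on the CLI
# (hedged sketch; building the config will fetch the model's HF config).
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="ibm-research/PowerMoE-3b",
                         tensor_parallel_size=2,
                         enable_expert_parallel=True)
vllm_config = engine_args.create_engine_config()
assert vllm_config.parallel_config.enable_expert_parallel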

vllm/envs.py

@@ -86,7 +86,6 @@ if TYPE_CHECKING:
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
     VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
     VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
-    VLLM_TEST_ENABLE_EP: bool = False
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
@@ -579,12 +578,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
     ),
-
-    # If set, vLLM will use the experimental expert parallel implementation on
-    # the FusedMoE layer, using tensor parallelism size as expert parallelism
-    # size.
-    "VLLM_TEST_ENABLE_EP":
-    lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))),
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.

vllm/model_executor/layers/fused_moe/layer.py

@@ -7,7 +7,6 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn.parameter import UninitializedParameter

-import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -342,14 +341,6 @@ class FusedMoE(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        # For smuggling this layer into the fused moe custom op
-        compilation_config = get_current_vllm_config().compilation_config
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError("Duplicate layer name: {}".format(prefix))
-        compilation_config.static_forward_context[prefix] = self
-        self.layer_name = prefix
-        self.use_direct_call = not envs.VLLM_TEST_ENABLE_EP
-
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
         self.tp_size = (tp_size if tp_size is not None else
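For context on the "smuggling" comment: when the MoE forward pass goes through a custom op rather than a direct call on the module, the op receives only the layer's prefix string, so the layer must be discoverable by name. An illustrative, self-contained sketch of that registry pattern (not the vLLM implementation; the names below are made up):

import torch

# Toy stand-in for compilation_config.static_forward_context.
_forward_context: dict[str, torch.nn.Module] = {}

def register_layer(prefix: str, layer: torch.nn.Module) -> None:
    # Same duplicate-name guard as in the code above.
    if prefix in _forward_context:
        raise ValueError("Duplicate layer name: {}".format(prefix))
    _forward_context[prefix] = layer

def forward_by_name(prefix: str, x: torch.Tensor) -> torch.Tensor:
    # A custom op carries only the name, so it resolves the module at call time.
    return _forward_context[prefix](x)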
@@ -361,7 +352,21 @@
                         if self.dp_size == 1 else get_dp_group().rank_in_group)
         self.global_num_experts = num_experts

-        if envs.VLLM_TEST_ENABLE_EP:
+        # Use expert parallelism instead of tensor parallelism?
+        vllm_config = get_current_vllm_config()
+        use_ep = (vllm_config.parallel_config.enable_expert_parallel
+                  and self.tp_size > 1)
+
+        # For smuggling this layer into the fused moe custom op
+        self.use_direct_call = self.dp_size == 1
+        if self.use_direct_call:
+            compilation_config = vllm_config.compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError("Duplicate layer name: {}".format(prefix))
+            compilation_config.static_forward_context[prefix] = self
+            self.layer_name = prefix
+
+        if use_ep:
             # Set TP size to 1 to adjust for EP and adjust EP size and rank
             # for DP attention.
             self.ep_rank = tp_rank + self.tp_size * self.dp_rank
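The expert-parallel layout itself: each (dp_rank, tp_rank) pair folds into a single EP rank via the formula above. Only the ep_rank line is visible in this hunk; the ep_size used below is an assumption (EP spanning all TP x DP ranks, with the MoE layers then seeing tp_size == 1), based on the comment about adjusting EP size and rank for DP attention.

# Worked example of the rank folding (illustrative only, not vLLM code).
def ep_layout(tp_size: int, dp_size: int):
    for dp_rank in range(dp_size):
        for tp_rank in range(tp_size):
            ep_rank = tp_rank + tp_size * dp_rank  # formula from the diff
            ep_size = tp_size * dp_size            # assumed: EP spans TP x DP
            yield dp_rank, tp_rank, ep_rank, ep_size

# With TP=2 and DP=2, the four ranks become EP ranks 0..3:
for dp_rank, tp_rank, ep_rank, ep_size in ep_layout(tp_size=2, dp_size=2):
    print(f"dp_rank={dp_rank} tp_rank={tp_rank} -> ep_rank={ep_rank} of {ep_size}")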