Mirror of https://github.com/vllm-project/vllm.git

[Distributed] Add enable_expert_parallel arg (#14305)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>

Commit cc2f9b32c8 (parent cd579352bf), committed via GitHub.
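This change promotes expert parallelism for MoE layers from the experimental VLLM_TEST_ENABLE_EP environment variable to a first-class enable_expert_parallel argument: the flag is added to ParallelConfig and EngineArgs (with a matching --enable-expert-parallel CLI option), consumed by ModelConfig verification and the FusedMoE layer, and the environment variable is removed. The hunks below cover the example script, the configuration and CLI plumbing, the environment-variable removal, and the FusedMoE changes.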
examples/offline_inference/data_parallel.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # usage:
-# VLLM_TEST_ENABLE_EP=1 VLLM_USE_V1=1 \
-#     python examples/offline_inference/data_parallel.py
+# VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
 # we need to have a launcher to create multiple data parallel
 # ranks. And each rank will create a vLLM instance to process its own prompts.
 import os
@@ -55,7 +54,8 @@ def main(dp_size, dp_rank, dp_master_ip, dp_master_port, GPUs_per_dp_rank):
     # Create an LLM.
     llm = LLM(model="ibm-research/PowerMoE-3b",
               tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True)
+              enforce_eager=True,
+              enable_expert_parallel=True)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
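Outside the data-parallel launcher, the new argument can be passed directly to the LLM constructor. A minimal offline sketch, assuming a 2-GPU node and the same model as the example above (prompt and sampling settings are illustrative):

from vllm import LLM, SamplingParams

# Expert parallelism replaces tensor parallelism only for the MoE layers;
# the rest of the model is still sharded across tensor_parallel_size GPUs.
llm = LLM(model="ibm-research/PowerMoE-3b",
          tensor_parallel_size=2,
          enforce_eager=True,
          enable_expert_parallel=True)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
for output in outputs:
    print(output.outputs[0].text)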
vllm/config.py
@@ -754,7 +754,7 @@ class ModelConfig:
                 " must be divisible by tensor parallel size "
                 f"({tensor_parallel_size}).")

-        if envs.VLLM_TEST_ENABLE_EP:
+        if parallel_config.enable_expert_parallel:
             self._verify_with_expert_parallelism()

         pipeline_parallel_size = parallel_config.pipeline_parallel_size
@@ -1334,6 +1334,7 @@ class ParallelConfig:
     # IP of the data parallel master.
     data_parallel_master_ip: str = "127.0.0.1"
     data_parallel_master_port: int = 29500  # Port of the data parallel master.
+    enable_expert_parallel: bool = False  # Use EP instead of TP for MoE layers.

     # Maximum number of multiple batches
     # when load model sequentially. To avoid RAM OOM when using tensor
vllm/engine/arg_utils.py
@@ -114,6 +114,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
     enable_prefix_caching: Optional[bool] = None
@@ -440,6 +441,11 @@ class EngineArgs:
             type=int,
             default=EngineArgs.tensor_parallel_size,
             help='Number of tensor parallel replicas.')
+        parser.add_argument(
+            '--enable-expert-parallel',
+            action='store_true',
+            help='Use expert parallelism instead of tensor parallelism '
+            'for MoE layers.')
         parser.add_argument(
             '--max-parallel-loading-workers',
             type=int,
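On the command line, the same behavior is exposed through the new flag. For example, a server launch might look like this (model name and parallel size are illustrative):

vllm serve ibm-research/PowerMoE-3b --tensor-parallel-size 2 --enable-expert-parallel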
@@ -1207,6 +1213,7 @@ class EngineArgs:
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             tokenizer_pool_config=TokenizerPoolConfig.create_config(
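The flag therefore flows from EngineArgs into ParallelConfig when the engine configuration is built. A minimal sketch, assuming create_engine_config() can be called with no arguments here (the model name is illustrative):

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="ibm-research/PowerMoE-3b",
                  tensor_parallel_size=2,
                  enable_expert_parallel=True)
# The resulting config carries the flag on its parallel_config.
config = args.create_engine_config()
assert config.parallel_config.enable_expert_parallel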
vllm/envs.py
@@ -86,7 +86,6 @@ if TYPE_CHECKING:
     VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
     VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
     VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
-    VLLM_TEST_ENABLE_EP: bool = False
     VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
     VLLM_RAY_PER_WORKER_GPUS: float = 1.0
     VLLM_RAY_BUNDLE_INDICES: str = ""
@@ -579,12 +578,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
     ),

-    # If set, vLLM will use the experimental expert parallel implementation on
-    # the FusedMoE layer, using tensor parallelism size as expert parallelism
-    # size.
-    "VLLM_TEST_ENABLE_EP":
-    lambda: bool(int(os.getenv("VLLM_TEST_ENABLE_EP", "0"))),
-
     # Number of GPUs per worker in Ray, if it is set to be a fraction,
     # it allows ray to schedule multiple actors on a single GPU,
     # so that users can colocate other actors on the same GPUs as vLLM.
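Note that VLLM_TEST_ENABLE_EP is removed outright rather than deprecated: a script or launcher that still sets it will simply get the default behavior (tensor parallelism for MoE layers) and needs to pass enable_expert_parallel=True, or the --enable-expert-parallel flag, instead.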
vllm/model_executor/layers/fused_moe/layer.py
@@ -7,7 +7,6 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn.parameter import UninitializedParameter

-import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -342,14 +341,6 @@ class FusedMoE(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        # For smuggling this layer into the fused moe custom op
-        compilation_config = get_current_vllm_config().compilation_config
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError("Duplicate layer name: {}".format(prefix))
-        compilation_config.static_forward_context[prefix] = self
-        self.layer_name = prefix
-        self.use_direct_call = not envs.VLLM_TEST_ENABLE_EP
-
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
         self.tp_size = (tp_size if tp_size is not None else
@@ -361,7 +352,21 @@
                        if self.dp_size == 1 else get_dp_group().rank_in_group)
         self.global_num_experts = num_experts

-        if envs.VLLM_TEST_ENABLE_EP:
+        # Use expert parallelism instead of tensor parallelism?
+        vllm_config = get_current_vllm_config()
+        use_ep = (vllm_config.parallel_config.enable_expert_parallel
+                  and self.tp_size > 1)
+
+        # For smuggling this layer into the fused moe custom op
+        self.use_direct_call = self.dp_size == 1
+        if self.use_direct_call:
+            compilation_config = vllm_config.compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError("Duplicate layer name: {}".format(prefix))
+            compilation_config.static_forward_context[prefix] = self
+            self.layer_name = prefix
+
+        if use_ep:
             # Set TP size to 1 to adjust for EP and adjust EP size and rank
             # for DP attention.
             self.ep_rank = tp_rank + self.tp_size * self.dp_rank
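The last line above flattens the tensor-parallel and data-parallel grid into a single expert-parallel rank. A small standalone sketch of that mapping; only the ep_rank expression comes from the diff, and the ep_size formula is an assumption since the hunk is truncated before the size adjustment:

# Sketch: how (tp_rank, dp_rank) pairs map to expert-parallel ranks when EP
# replaces TP for MoE layers. The ep_rank formula matches the diff above;
# ep_size = tp_size * dp_size is an assumption.
def ep_layout(tp_rank: int, tp_size: int, dp_rank: int, dp_size: int):
    ep_rank = tp_rank + tp_size * dp_rank
    ep_size = tp_size * dp_size  # assumed: every TP x DP worker becomes one EP rank
    return ep_rank, ep_size

# Example: tp_size=2, dp_size=2 gives EP ranks 0..3 across the four workers.
for dp_rank in range(2):
    for tp_rank in range(2):
        print(f"dp={dp_rank} tp={tp_rank} ->", ep_layout(tp_rank, 2, dp_rank, 2))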