mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[UX] Replace VLLM_ALL2ALL_BACKEND with --all2all-backend (#26732)
Signed-off-by: mgoin <mgoin64@gmail.com>
@@ -34,10 +34,10 @@ To enable the DBO system pass in the `--enable-dbo` argument to your vllm serve
 * `--dbo-decode-token-threshold` the minimum number of tokens in a decode-only batch required to enable DBO for that batch
 * `--dbo-prefill-token-threshold` the minimum number of tokens in a batch containing at least one prefill required to enable DBO for that batch

-Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `VLLM_ALL2ALL_BACKEND` environment variable must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests.
+Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `--all2all-backend` argument must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests.

 Below is a command that will spin up a two DP rank server with expert parallelism and DBO enabled.

-EX: `VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo`
+EX: `vllm serve deepseek-ai/DeepSeek-V2-Lite --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo --all2all-backend deepep_low_latency`

 Note that there must be at least two GPUs visible in `CUDA_VISIBLE_DEVICES`
@@ -14,13 +14,16 @@ Before using EP, you need to install the necessary dependencies. We are actively

 ### Backend Selection Guide

-vLLM provides three communication backends for EP:
+vLLM provides multiple communication backends for EP. Use `--all2all-backend` to select one:

 | Backend | Use Case | Features | Best For |
 |---------|----------|----------|----------|
-| `pplx` | Single node | Chunked prefill support | Development, best for intra-node deployments |
-| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout | High-throughput scenarios, prefill-dominated workloads |
-| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout | Low-latency scenarios, decode-dominated workloads |
+| `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration |
+| `pplx` | Single node | Chunked prefill support, efficient intra-node communication | Single-node deployments, development |
+| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios |
+| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios |
+| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes |
+| `naive` | Testing/debugging | Simple broadcast-based implementation | Debugging, not recommended for production |

 ## Single Node Deployment
@@ -47,11 +50,11 @@ The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parall

 ```bash
 # Single node EP deployment with pplx backend
-VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 \
-vllm serve deepseek-ai/DeepSeek-V3-0324 \
-    --tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU
+vllm serve deepseek-ai/DeepSeek-V3-0324 \
+    --tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU
     --data-parallel-size 8 \ # Data parallelism across 8 processes
-    --enable-expert-parallel # Enable expert parallelism
+    --enable-expert-parallel \ # Enable expert parallelism
+    --all2all-backend pplx # Use pplx communication backend
 ```

 ## Multi-Node Deployment
@@ -70,8 +73,8 @@ The following example deploys `DeepSeek-V3-0324` across 2 nodes using `deepep_lo

 ```bash
 # Node 1 (Primary - handles incoming requests)
-VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \
-vllm serve deepseek-ai/DeepSeek-V3-0324 \
+vllm serve deepseek-ai/DeepSeek-V3-0324 \
+    --all2all-backend deepep_low_latency \
     --tensor-parallel-size 1 \ # TP size per node
     --enable-expert-parallel \ # Enable EP
     --data-parallel-size 16 \ # Total DP size across all nodes
@@ -81,8 +84,8 @@ VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \
     --api-server-count=8 # Number of API servers for load handling (scaling this out to the total number of ranks is recommended)

 # Node 2 (Secondary - headless mode, no API server)
-VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \
-vllm serve deepseek-ai/DeepSeek-V3-0324 \
+vllm serve deepseek-ai/DeepSeek-V3-0324 \
+    --all2all-backend deepep_low_latency \
     --tensor-parallel-size 1 \ # TP size per node
     --enable-expert-parallel \ # Enable EP
     --data-parallel-size 16 \ # Total DP size across all nodes
@@ -169,11 +172,12 @@ Single node deployment with EPLB enabled:

 ```bash
 # Single node with EPLB load balancing
-VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 vllm serve deepseek-ai/DeepSeek-V3-0324 \
-    --tensor-parallel-size 1 \ # Tensor parallelism
-    --data-parallel-size 8 \ # Data parallelism
-    --enable-expert-parallel \ # Enable EP
-    --enable-eplb \ # Enable load balancer
+vllm serve deepseek-ai/DeepSeek-V3-0324 \
+    --tensor-parallel-size 1 \ # Tensor parallelism
+    --data-parallel-size 8 \ # Data parallelism
+    --enable-expert-parallel \ # Enable EP
+    --all2all-backend pplx \ # Use pplx communication backend
+    --enable-eplb \ # Enable load balancer
     --eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
 ```
@@ -113,6 +113,25 @@ class ParallelConfig:
     with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
     will have experts [1, 3]. This strategy can help improve load balancing
     for grouped expert models with no redundant experts."""
+    all2all_backend: (
+        Literal[
+            "naive",
+            "pplx",
+            "deepep_high_throughput",
+            "deepep_low_latency",
+            "allgather_reducescatter",
+            "flashinfer_all2allv",
+        ]
+        | None
+    ) = None
+    """All2All backend for MoE expert parallel communication. If not set, uses
+    the value from VLLM_ALL2ALL_BACKEND environment variable. Available options:
+    - "naive": Naive all2all implementation using broadcasts
+    - "allgather_reducescatter": All2all based on allgather and reducescatter
+    - "pplx": Use pplx kernels
+    - "deepep_high_throughput": Use deepep high-throughput kernels
+    - "deepep_low_latency": Use deepep low-latency kernels
+    - "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
     num_redundant_experts: int | None = None
     """`num_redundant_experts` is deprecated and has been replaced with
     `eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
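For illustration, a minimal sketch of setting the new field programmatically (assuming `ParallelConfig` is importable from `vllm.config`, as elsewhere in this diff; the parallel sizes are placeholders):

```python
from vllm.config import ParallelConfig

# Select the all2all backend on the config object instead of exporting the
# deprecated VLLM_ALL2ALL_BACKEND environment variable.
parallel_config = ParallelConfig(
    data_parallel_size=2,
    enable_expert_parallel=True,
    all2all_backend="deepep_low_latency",
)
```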
@@ -341,7 +360,7 @@ class ParallelConfig:
     @property
     def use_sequence_parallel_moe(self) -> bool:
         return (
-            envs.VLLM_ALL2ALL_BACKEND
+            self.all2all_backend
             in (
                 "allgather_reducescatter",
                 "naive",
@@ -390,7 +409,7 @@ class ParallelConfig:
         factors.append(self.tensor_parallel_size)
         factors.append(self.enable_expert_parallel)
         factors.append(self.data_parallel_size)
-        factors.append(envs.VLLM_ALL2ALL_BACKEND)
+        factors.append(self.all2all_backend)
         factors.append(self.enable_eplb)
         if self.enable_eplb:
             factors.append(self.eplb_config.log_balancedness)
@@ -400,6 +419,16 @@
         return hashlib.sha256(str(factors).encode()).hexdigest()

     def __post_init__(self) -> None:
+        # Set all2all_backend from env var if not specified, with deprecation warning
+        if self.all2all_backend is None:
+            self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND
+            if envs.is_set("VLLM_ALL2ALL_BACKEND"):
+                logger.warning_once(
+                    "VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
+                    "will be removed in a future release. Please use the "
+                    "--all2all-backend command-line argument instead."
+                )
+
         # Forward deprecated fields to their new location
         if self.num_redundant_experts is not None:
             self.eplb_config.num_redundant_experts = self.num_redundant_experts
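The precedence implemented above is: an explicit `--all2all-backend` value wins; otherwise the deprecated environment variable is honored with a warning; otherwise the default backend applies. A standalone sketch of that fallback order (not vLLM code; the default follows the backend table above):

```python
import os
import warnings


def resolve_all2all_backend(cli_value: str | None) -> str:
    """Standalone illustration of the fallback order used in __post_init__."""
    if cli_value is not None:
        # An explicit --all2all-backend always wins.
        return cli_value
    env_value = os.environ.get("VLLM_ALL2ALL_BACKEND")
    if env_value is not None:
        warnings.warn(
            "VLLM_ALL2ALL_BACKEND is deprecated; use --all2all-backend instead",
            DeprecationWarning,
            stacklevel=2,
        )
        return env_value
    # Documented default backend.
    return "allgather_reducescatter"


print(resolve_all2all_backend(None))  # -> allgather_reducescatter
```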
@@ -523,13 +523,13 @@ class VllmConfig:
            )

        if self.parallel_config.enable_dbo:
-            a2a_backend = envs.VLLM_ALL2ALL_BACKEND
+            a2a_backend = self.parallel_config.all2all_backend
            assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
                "Microbatching currently only supports the deepep_low_latency and "
                f"deepep_high_throughput all2all backend. {a2a_backend} is not "
-                "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "
-                "variable to deepep_low_latency or deepep_high_throughput and "
-                "install the DeepEP kernels."
+                "supported. To fix use --all2all-backend=deepep_low_latency or "
+                "--all2all-backend=deepep_high_throughput and install the DeepEP"
+                " kernels."
            )

        if not self.model_config.disable_cascade_attn:
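In effect, enabling DBO now validates the config field rather than the environment variable. A standalone sketch of the same check (not vLLM code; the error text is paraphrased):

```python
_DBO_COMPATIBLE_BACKENDS = ("deepep_low_latency", "deepep_high_throughput")


def check_dbo_backend(enable_dbo: bool, all2all_backend: str) -> None:
    """Raise if DBO is requested with a backend other than the DeepEP ones."""
    if enable_dbo and all2all_backend not in _DBO_COMPATIBLE_BACKENDS:
        raise ValueError(
            "Microbatching (DBO) requires --all2all-backend to be "
            f"deepep_low_latency or deepep_high_throughput, got {all2all_backend!r}"
        )


check_dbo_backend(enable_dbo=True, all2all_backend="deepep_low_latency")  # passes
```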
@@ -111,6 +111,7 @@ class DeviceCommunicatorBase:
         self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)

         use_ep = False
+        all2all_backend = None
         from vllm.config import get_current_vllm_config

         config = get_current_vllm_config()
@@ -119,9 +120,11 @@ class DeviceCommunicatorBase:
             # where all data parallel ranks execute forward together),
             # we initialize the all2all manager used in expert parallel.
             use_ep = config.parallel_config.data_parallel_size > 1
+            all2all_backend = config.parallel_config.all2all_backend

         self.is_ep_communicator = "ep" in unique_name
         self.use_all2all = self.is_ep_communicator and use_ep
+        self.all2all_backend = all2all_backend
         self.all2all_manager: All2AllManagerBase | None = None

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
@@ -91,33 +91,32 @@ class CudaCommunicator(DeviceCommunicatorBase):
             self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device)

         if self.use_all2all:
-            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
-            if all2all_backend == "naive":
+            if self.all2all_backend == "naive":
                 from .all2all import NaiveAll2AllManager

                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
-            elif all2all_backend == "allgather_reducescatter":
+            elif self.all2all_backend == "allgather_reducescatter":
                 from .all2all import AgRsAll2AllManager

                 self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
-            elif all2all_backend == "pplx":
+            elif self.all2all_backend == "pplx":
                 from .all2all import PPLXAll2AllManager

                 self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
-            elif all2all_backend == "deepep_high_throughput":
+            elif self.all2all_backend == "deepep_high_throughput":
                 from .all2all import DeepEPHTAll2AllManager

                 self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
-            elif all2all_backend == "deepep_low_latency":
+            elif self.all2all_backend == "deepep_low_latency":
                 from .all2all import DeepEPLLAll2AllManager

                 self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
-            elif all2all_backend == "flashinfer_all2allv":
+            elif self.all2all_backend == "flashinfer_all2allv":
                 from .all2all import FlashInferAllToAllManager

                 self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
             else:
-                raise ValueError(f"Unknown all2all backend: {all2all_backend}")
+                raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")

             if is_global_first_rank():
                 logger.info(
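The dispatch above is an if/elif chain keyed on `self.all2all_backend`. For reference, the same mapping written as a table (a hypothetical summary, not part of this change; it assumes the relative import `.all2all` resolves to `vllm.distributed.device_communicators.all2all`, and uses the manager classes imported in the chain above):

```python
from vllm.distributed.device_communicators.all2all import (
    AgRsAll2AllManager,
    DeepEPHTAll2AllManager,
    DeepEPLLAll2AllManager,
    FlashInferAllToAllManager,
    NaiveAll2AllManager,
    PPLXAll2AllManager,
)

# Backend name -> manager class, mirroring the if/elif chain above.
ALL2ALL_MANAGERS = {
    "naive": NaiveAll2AllManager,
    "allgather_reducescatter": AgRsAll2AllManager,
    "pplx": PPLXAll2AllManager,
    "deepep_high_throughput": DeepEPHTAll2AllManager,
    "deepep_low_latency": DeepEPLLAll2AllManager,
    "flashinfer_all2allv": FlashInferAllToAllManager,
}
```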
@@ -6,7 +6,6 @@ import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

-import vllm.envs as envs
 from vllm.logger import init_logger

 from .base_device_communicator import DeviceCommunicatorBase
@@ -24,15 +23,14 @@ class XpuCommunicator(DeviceCommunicatorBase):
     ):
         super().__init__(cpu_group, device, device_group, unique_name)
         if self.use_all2all:
-            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
-            if all2all_backend != "naive":
+            if self.all2all_backend != "naive":
                 logger.warning(
-                    "`%s` all2all manager is not supported on XPU."
+                    "`%s` all2all manager is not supported on XPU. "
                     "Falling back to `naive` all2all manager for XPU.",
-                    all2all_backend,
+                    self.all2all_backend,
                 )
-                all2all_backend = "naive"
-            if all2all_backend == "naive":
+                self.all2all_backend = "naive"
+            if self.all2all_backend == "naive":
                 from .all2all import NaiveAll2AllManager

                 self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
@@ -371,6 +371,7 @@ class EngineArgs:
     data_parallel_hybrid_lb: bool = False
     data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
+    all2all_backend: str | None = ParallelConfig.all2all_backend
     enable_dbo: bool = ParallelConfig.enable_dbo
     dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
     dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
@@ -763,6 +764,9 @@ class EngineArgs:
         parallel_group.add_argument(
             "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
         )
+        parallel_group.add_argument(
+            "--all2all-backend", **parallel_kwargs["all2all_backend"]
+        )
         parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
         parallel_group.add_argument(
             "--dbo-decode-token-threshold",
@@ -1461,6 +1465,7 @@ class EngineArgs:
             data_parallel_backend=self.data_parallel_backend,
             data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
             enable_expert_parallel=self.enable_expert_parallel,
+            all2all_backend=self.all2all_backend,
             enable_dbo=self.enable_dbo,
             dbo_decode_token_threshold=self.dbo_decode_token_threshold,
             dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,
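Since `all2all_backend` is now a regular `EngineArgs` field, it can also be set programmatically. A minimal sketch (assuming `EngineArgs` is importable from `vllm.engine.arg_utils`; the model and sizes are placeholders):

```python
from vllm.engine.arg_utils import EngineArgs

# Roughly equivalent to:
#   vllm serve deepseek-ai/DeepSeek-V2-Lite --data-parallel-size 2 \
#       --enable-expert-parallel --all2all-backend deepep_low_latency
engine_args = EngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite",
    data_parallel_size=2,
    enable_expert_parallel=True,
    all2all_backend="deepep_low_latency",
)
```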
@@ -641,6 +641,7 @@ class FusedMoEParallelConfig:
     ep_rank: int

     use_ep: bool # whether to use EP or not
+    all2all_backend: str # all2all backend for MoE communication

     @property
     def use_all2all_kernels(self):
@@ -648,21 +649,18 @@

     @property
     def use_pplx_kernels(self):
-        return self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "pplx"
+        return self.use_all2all_kernels and self.all2all_backend == "pplx"

     @property
     def use_deepep_ht_kernels(self):
         return (
             self.use_all2all_kernels
-            and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
+            and self.all2all_backend == "deepep_high_throughput"
         )

     @property
     def use_deepep_ll_kernels(self):
-        return (
-            self.use_all2all_kernels
-            and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
-        )
+        return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"

     @staticmethod
     def make(
@@ -762,6 +760,7 @@ class FusedMoEParallelConfig:
             ep_size=1,
             ep_rank=0,
             use_ep=False,
+            all2all_backend=vllm_parallel_config.all2all_backend,
         )
         # DP + EP / TP + EP / DP + TP + EP
         assert use_ep
@@ -777,6 +776,7 @@ class FusedMoEParallelConfig:
             ep_size=ep_size,
             ep_rank=ep_rank,
             use_ep=True,
+            all2all_backend=vllm_parallel_config.all2all_backend,
         )

@@ -58,7 +58,7 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
 ) -> mk.FusedMoEPrepareAndFinalize:
     """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
     use_dp = moe.moe_parallel_config.dp_size > 1
-    enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv"
+    enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
     return create_flashinfer_prepare_finalize(
         use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
     )
@@ -192,7 +192,7 @@ class CudaPlatformBase(Platform):

         compilation_config = vllm_config.compilation_config
         if (
-            envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
+            parallel_config.all2all_backend == "deepep_high_throughput"
             and parallel_config.data_parallel_size > 1
             and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
         ):
@@ -204,7 +204,7 @@ class CudaPlatformBase(Platform):
                 "kernels are optimized for prefill and are incompatible with "
                 "CUDA Graphs. "
                 "In order to use CUDA Graphs for decode-optimized workloads, "
-                "set VLLM_ALL2ALL_BACKEND to another option, such as "
+                "use --all2all-backend with another option, such as "
                 "deepep_low_latency, pplx, or allgather_reducescatter."
             )
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
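The rule enforced here: `deepep_high_throughput` with more than one DP rank disables CUDA graphs. A standalone sketch of that decision (not vLLM code; the mode names are simplified strings):

```python
def resolve_cudagraph_mode(
    all2all_backend: str, data_parallel_size: int, requested_mode: str
) -> str:
    """Return the effective CUDA graph mode under the prefill-kernel restriction."""
    if (
        all2all_backend == "deepep_high_throughput"
        and data_parallel_size > 1
        and requested_mode != "NONE"
    ):
        # DeepEP high-throughput kernels are prefill-optimized and incompatible
        # with CUDA graph capture, so fall back to running without CUDA graphs.
        return "NONE"
    return requested_mode


print(resolve_cudagraph_mode("deepep_high_throughput", 8, "FULL"))  # -> NONE
```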
@@ -356,9 +356,10 @@ class CoreEngineActorManager:
         )
         device_str = current_platform.ray_device_key

+        all2all_backend = vllm_config.parallel_config.all2all_backend
         if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and (
-            envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
-            or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
+            all2all_backend == "deepep_high_throughput"
+            or all2all_backend == "deepep_low_latency"
         ):
             raise ValueError(
                 "DeepEP kernels require EP ranks [0,7] (same for [8,15], ...) "