[UX] Replace VLLM_ALL2ALL_BACKEND with --all2all-backend (#26732)

Signed-off-by: mgoin <mgoin64@gmail.com>
Michael Goin
2025-10-13 21:12:52 -04:00
committed by GitHub
parent 8317f72354
commit 3e051bda82
12 changed files with 90 additions and 51 deletions

View File

@ -34,10 +34,10 @@ To enable the DBO system pass in the `--enable-dbo` argument to your vllm serve
* `--dbo-decode-token-threshold` the minimum number of tokens in a decode-only batch required to enable DBO for that batch
* `--dbo-prefill-token-threshold` the minimum number of tokens in a batch containing at least one prefill required to enable DBO for that batch
Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `VLLM_ALL2ALL_BACKEND` environment variable must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests.
Currently, DBO is only supported with DeepEP, so DeepEP must be installed and the `--all2all-backend` argument must be set to `deepep_low_latency` if your workload is primarily decode requests, or `deepep_high_throughput` if your workload is primarily prefill requests.
Below is a command that will spin up a two DP rank server with expert parallelism and DBO enabled.
EX: `VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo`
EX: `vllm serve deepseek-ai/DeepSeek-V2-Lite --trust-remote-code --data-parallel-size 2 --enable-expert-parallel --enable-dbo --all2all-backend deepep_low_latency`
Note that there must be at least two GPUs visible in `CUDA_VISIBLE_DEVICES`.
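For readers using the offline API instead of `vllm serve`, here is a minimal sketch of the same setup, assuming the `LLM` entry point forwards these engine arguments (including the `--enable-dbo` and `--all2all-backend` options above) the way it does for other parallel settings; running it requires at least two visible GPUs and the DeepEP kernels:

```python
# Minimal offline sketch of the DBO example above (assumes >= 2 GPUs and DeepEP installed).
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
    data_parallel_size=2,                  # two DP ranks
    enable_expert_parallel=True,           # expert parallelism
    enable_dbo=True,                       # dual-batch overlap
    all2all_backend="deepep_low_latency",  # decode-heavy workloads
)
out = llm.generate(["Hello, DBO!"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```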

View File

@ -14,13 +14,16 @@ Before using EP, you need to install the necessary dependencies. We are actively
### Backend Selection Guide
vLLM provides three communication backends for EP:
vLLM provides multiple communication backends for EP. Use `--all2all-backend` to select one:
| Backend | Use Case | Features | Best For |
|---------|----------|----------|----------|
| `pplx` | Single node | Chunked prefill support | Development, best for intra-node deployments |
| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout | High-throughput scenarios, prefill-dominated workloads |
| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout | Low-latency scenarios, decode-dominated workloads |
| `allgather_reducescatter` | Default backend | Standard all2all using allgather/reducescatter primitives | General purpose, works with any EP+DP configuration |
| `pplx` | Single node | Chunked prefill support, efficient intra-node communication | Single-node deployments, development |
| `deepep_high_throughput` | Multi-node prefill | Grouped GEMM with continuous layout, optimized for prefill | Prefill-dominated workloads, high-throughput scenarios |
| `deepep_low_latency` | Multi-node decode | CUDA graph support, masked layout, optimized for decode | Decode-dominated workloads, low-latency scenarios |
| `flashinfer_all2allv` | MNNVL systems | FlashInfer alltoallv kernels for multi-node NVLink | Systems with NVLink across nodes |
| `naive` | Testing/debugging | Simple broadcast-based implementation | Debugging, not recommended for production |
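The table reduces to a simple decision rule. The helper below is a hypothetical sketch (not part of vLLM) that encodes it and returns a value suitable for `--all2all-backend`:

```python
# Hypothetical helper encoding the backend-selection table above.
def pick_all2all_backend(multi_node: bool, decode_heavy: bool,
                         mnnvl: bool = False, debug: bool = False) -> str:
    """Return a value for --all2all-backend based on deployment shape.

    `allgather_reducescatter` remains the general-purpose default when
    none of the specialized cases below apply.
    """
    if debug:
        return "naive"                # testing/debugging only
    if mnnvl:
        return "flashinfer_all2allv"  # NVLink across nodes
    if not multi_node:
        return "pplx"                 # single-node deployments
    return "deepep_low_latency" if decode_heavy else "deepep_high_throughput"

print(pick_all2all_backend(multi_node=True, decode_heavy=False))
# -> deepep_high_throughput
```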
## Single Node Deployment
@ -47,11 +50,11 @@ The following command serves a `DeepSeek-V3-0324` model with 1-way tensor parall
```bash
# Single node EP deployment with pplx backend
VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 \
vllm serve deepseek-ai/DeepSeek-V3-0324 \
--tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU
vllm serve deepseek-ai/DeepSeek-V3-0324 \
--tensor-parallel-size 1 \ # Tensor parallelism across 1 GPU
--data-parallel-size 8 \ # Data parallelism across 8 processes
--enable-expert-parallel # Enable expert parallelism
--enable-expert-parallel \ # Enable expert parallelism
--all2all-backend pplx # Use pplx communication backend
```
## Multi-Node Deployment
@ -70,8 +73,8 @@ The following example deploys `DeepSeek-V3-0324` across 2 nodes using `deepep_lo
```bash
# Node 1 (Primary - handles incoming requests)
VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \
vllm serve deepseek-ai/DeepSeek-V3-0324 \
vllm serve deepseek-ai/DeepSeek-V3-0324 \
--all2all-backend deepep_low_latency \
--tensor-parallel-size 1 \ # TP size per node
--enable-expert-parallel \ # Enable EP
--data-parallel-size 16 \ # Total DP size across all nodes
@ -81,8 +84,8 @@ VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \
--api-server-count=8 # Number of API servers for load handling (scaling this out to the total number of ranks is recommended)
# Node 2 (Secondary - headless mode, no API server)
VLLM_ALL2ALL_BACKEND=deepep_low_latency VLLM_USE_DEEP_GEMM=1 \
vllm serve deepseek-ai/DeepSeek-V3-0324 \
vllm serve deepseek-ai/DeepSeek-V3-0324 \
--all2all-backend deepep_low_latency \
--tensor-parallel-size 1 \ # TP size per node
--enable-expert-parallel \ # Enable EP
--data-parallel-size 16 \ # Total DP size across all nodes
@ -169,11 +172,12 @@ Single node deployment with EPLB enabled:
```bash
# Single node with EPLB load balancing
VLLM_ALL2ALL_BACKEND=pplx VLLM_USE_DEEP_GEMM=1 vllm serve deepseek-ai/DeepSeek-V3-0324 \
--tensor-parallel-size 1 \ # Tensor parallelism
--data-parallel-size 8 \ # Data parallelism
--enable-expert-parallel \ # Enable EP
--enable-eplb \ # Enable load balancer
vllm serve deepseek-ai/DeepSeek-V3-0324 \
--tensor-parallel-size 1 \ # Tensor parallelism
--data-parallel-size 8 \ # Data parallelism
--enable-expert-parallel \ # Enable EP
--all2all-backend pplx \ # Use pplx communication backend
--enable-eplb \ # Enable load balancer
--eplb-config '{"window_size":1000,"step_interval":3000,"num_redundant_experts":2,"log_balancedness":true}'
```

View File

@ -113,6 +113,25 @@ class ParallelConfig:
with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1
will have experts [1, 3]. This strategy can help improve load balancing
for grouped expert models with no redundant experts."""
all2all_backend: (
Literal[
"naive",
"pplx",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter",
"flashinfer_all2allv",
]
| None
) = None
"""All2All backend for MoE expert parallel communication. If not set, uses
the value from VLLM_ALL2ALL_BACKEND environment variable. Available options:
- "naive": Naive all2all implementation using broadcasts
- "allgather_reducescatter": All2all based on allgather and reducescatter
- "pplx": Use pplx kernels
- "deepep_high_throughput": Use deepep high-throughput kernels
- "deepep_low_latency": Use deepep low-latency kernels
- "flashinfer_all2allv": Use flashinfer alltoallv kernels for mnnvl"""
num_redundant_experts: int | None = None
"""`num_redundant_experts` is deprecated and has been replaced with
`eplb_config.num_redundant_experts`. This will be removed in v0.12.0.
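A minimal sketch of setting the new field directly, assuming `ParallelConfig` can be constructed standalone as in vLLM's config tests; in normal use the value arrives via `--all2all-backend` through `EngineArgs`:

```python
# Sketch: the field takes one of the Literal values listed above.
from vllm.config import ParallelConfig

pc = ParallelConfig(all2all_backend="deepep_high_throughput")
print(pc.all2all_backend)  # -> deepep_high_throughput
```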
@ -341,7 +360,7 @@ class ParallelConfig:
@property
def use_sequence_parallel_moe(self) -> bool:
return (
envs.VLLM_ALL2ALL_BACKEND
self.all2all_backend
in (
"allgather_reducescatter",
"naive",
@ -390,7 +409,7 @@ class ParallelConfig:
factors.append(self.tensor_parallel_size)
factors.append(self.enable_expert_parallel)
factors.append(self.data_parallel_size)
factors.append(envs.VLLM_ALL2ALL_BACKEND)
factors.append(self.all2all_backend)
factors.append(self.enable_eplb)
if self.enable_eplb:
factors.append(self.eplb_config.log_balancedness)
@ -400,6 +419,16 @@ class ParallelConfig:
return hashlib.sha256(str(factors).encode()).hexdigest()
def __post_init__(self) -> None:
# Set all2all_backend from env var if not specified, with deprecation warning
if self.all2all_backend is None:
self.all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if envs.is_set("VLLM_ALL2ALL_BACKEND"):
logger.warning_once(
"VLLM_ALL2ALL_BACKEND environment variable is deprecated and "
"will be removed in a future release. Please use the "
"--all2all-backend command-line argument instead."
)
# Forward deprecated fields to their new location
if self.num_redundant_experts is not None:
self.eplb_config.num_redundant_experts = self.num_redundant_experts
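The resolution order implemented above is: an explicit `all2all_backend` (CLI flag or constructor argument) wins; only when it is left unset does the deprecated `VLLM_ALL2ALL_BACKEND` environment variable apply, with a one-time warning. A small sketch under the same standalone-construction assumption as before:

```python
# Sketch of the precedence implemented in __post_init__ above.
import os
os.environ["VLLM_ALL2ALL_BACKEND"] = "pplx"  # deprecated path

from vllm.config import ParallelConfig

print(ParallelConfig().all2all_backend)
# -> pplx (falls back to the env var and emits the deprecation warning)

print(ParallelConfig(all2all_backend="deepep_low_latency").all2all_backend)
# -> deepep_low_latency (explicit value takes precedence over the env var)
```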

View File

@ -523,13 +523,13 @@ class VllmConfig:
)
if self.parallel_config.enable_dbo:
a2a_backend = envs.VLLM_ALL2ALL_BACKEND
a2a_backend = self.parallel_config.all2all_backend
assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], (
"Microbatching currently only supports the deepep_low_latency and "
f"deepep_high_throughput all2all backend. {a2a_backend} is not "
"supported. To fix set the VLLM_ALL2ALL_BACKEND environment "
"variable to deepep_low_latency or deepep_high_throughput and "
"install the DeepEP kernels."
"supported. To fix use --all2all-backend=deepep_low_latency or "
"--all2all-backend=deepep_high_throughput and install the DeepEP"
" kernels."
)
if not self.model_config.disable_cascade_attn:

View File

@ -111,6 +111,7 @@ class DeviceCommunicatorBase:
self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
use_ep = False
all2all_backend = None
from vllm.config import get_current_vllm_config
config = get_current_vllm_config()
@ -119,9 +120,11 @@ class DeviceCommunicatorBase:
# where all data parallel ranks execute forward together),
# we initialize the all2all manager used in expert parallel.
use_ep = config.parallel_config.data_parallel_size > 1
all2all_backend = config.parallel_config.all2all_backend
self.is_ep_communicator = "ep" in unique_name
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_backend = all2all_backend
self.all2all_manager: All2AllManagerBase | None = None
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:

View File

@ -91,33 +91,32 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend == "naive":
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
elif all2all_backend == "allgather_reducescatter":
elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
elif all2all_backend == "pplx":
elif self.all2all_backend == "pplx":
from .all2all import PPLXAll2AllManager
self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
elif all2all_backend == "deepep_high_throughput":
elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
elif all2all_backend == "deepep_low_latency":
elif self.all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
elif all2all_backend == "flashinfer_all2allv":
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
if is_global_first_rank():
logger.info(

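The branch chain above is effectively a name-to-manager registry with the imports kept local to the selected branch. A hypothetical registry-style equivalent (not the code in this commit), assuming the manager classes live in `vllm.distributed.device_communicators.all2all` as the relative imports above suggest:

```python
# Hypothetical registry-style equivalent of the dispatch above.
def make_all2all_manager(backend: str, cpu_group):
    from vllm.distributed.device_communicators import all2all as a2a
    registry = {
        "naive": a2a.NaiveAll2AllManager,
        "allgather_reducescatter": a2a.AgRsAll2AllManager,
        "pplx": a2a.PPLXAll2AllManager,
        "deepep_high_throughput": a2a.DeepEPHTAll2AllManager,
        "deepep_low_latency": a2a.DeepEPLLAll2AllManager,
        "flashinfer_all2allv": a2a.FlashInferAllToAllManager,
    }
    if backend not in registry:
        raise ValueError(f"Unknown all2all backend: {backend}")
    return registry[backend](cpu_group)  # only the selected manager is constructed
```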
View File

@ -6,7 +6,6 @@ import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
import vllm.envs as envs
from vllm.logger import init_logger
from .base_device_communicator import DeviceCommunicatorBase
@ -24,15 +23,14 @@ class XpuCommunicator(DeviceCommunicatorBase):
):
super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
all2all_backend = envs.VLLM_ALL2ALL_BACKEND
if all2all_backend != "naive":
if self.all2all_backend != "naive":
logger.warning(
"`%s` all2all manager is not supported on XPU."
"`%s` all2all manager is not supported on XPU. "
"Falling back to `naive` all2all manager for XPU.",
all2all_backend,
self.all2all_backend,
)
all2all_backend = "naive"
if all2all_backend == "naive":
self.all2all_backend = "naive"
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)

View File

@ -371,6 +371,7 @@ class EngineArgs:
data_parallel_hybrid_lb: bool = False
data_parallel_backend: str = ParallelConfig.data_parallel_backend
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
all2all_backend: str | None = ParallelConfig.all2all_backend
enable_dbo: bool = ParallelConfig.enable_dbo
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold
@ -763,6 +764,9 @@ class EngineArgs:
parallel_group.add_argument(
"--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
)
parallel_group.add_argument(
"--all2all-backend", **parallel_kwargs["all2all_backend"]
)
parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"])
parallel_group.add_argument(
"--dbo-decode-token-threshold",
@ -1461,6 +1465,7 @@ class EngineArgs:
data_parallel_backend=self.data_parallel_backend,
data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend,
enable_dbo=self.enable_dbo,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
dbo_prefill_token_threshold=self.dbo_prefill_token_threshold,

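The same option is available programmatically; a minimal sketch, assuming the usual `EngineArgs` dataclass construction (the `all2all_backend` field is the one added above):

```python
# Sketch: setting the new engine argument from Python instead of the CLI.
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite",
    enable_expert_parallel=True,
    all2all_backend="pplx",
)
# args.create_engine_config() carries this into ParallelConfig.all2all_backend.
print(args.all2all_backend)  # -> pplx
```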
View File

@ -641,6 +641,7 @@ class FusedMoEParallelConfig:
ep_rank: int
use_ep: bool # whether to use EP or not
all2all_backend: str # all2all backend for MoE communication
@property
def use_all2all_kernels(self):
@ -648,21 +649,18 @@ class FusedMoEParallelConfig:
@property
def use_pplx_kernels(self):
return self.use_all2all_kernels and envs.VLLM_ALL2ALL_BACKEND == "pplx"
return self.use_all2all_kernels and self.all2all_backend == "pplx"
@property
def use_deepep_ht_kernels(self):
return (
self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
and self.all2all_backend == "deepep_high_throughput"
)
@property
def use_deepep_ll_kernels(self):
return (
self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
)
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@staticmethod
def make(
@ -762,6 +760,7 @@ class FusedMoEParallelConfig:
ep_size=1,
ep_rank=0,
use_ep=False,
all2all_backend=vllm_parallel_config.all2all_backend,
)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
@ -777,6 +776,7 @@ class FusedMoEParallelConfig:
ep_size=ep_size,
ep_rank=ep_rank,
use_ep=True,
all2all_backend=vllm_parallel_config.all2all_backend,
)

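To make the gating above concrete: the pplx/DeepEP kernel paths are only taken when all2all kernels are in use at all, and the backend string then selects among them. Below is a self-contained, hypothetical mini-version (not the real `FusedMoEParallelConfig`, whose `use_all2all_kernels` condition also involves the parallel sizes):

```python
# Hypothetical mini-version of the kernel-gating properties above.
from dataclasses import dataclass

@dataclass
class MoEParallelSketch:
    use_ep: bool
    dp_size: int
    all2all_backend: str

    @property
    def use_all2all_kernels(self) -> bool:
        # stand-in for the real condition (EP together with data-parallel MoE)
        return self.use_ep and self.dp_size > 1

    @property
    def use_pplx_kernels(self) -> bool:
        return self.use_all2all_kernels and self.all2all_backend == "pplx"

    @property
    def use_deepep_ll_kernels(self) -> bool:
        return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"

cfg = MoEParallelSketch(use_ep=True, dp_size=2, all2all_backend="deepep_low_latency")
print(cfg.use_pplx_kernels, cfg.use_deepep_ll_kernels)  # False True
```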
View File

@ -58,7 +58,7 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv"
enable_alltoallv = moe.moe_parallel_config.all2all_backend == "flashinfer_all2allv"
return create_flashinfer_prepare_finalize(
use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv
)

View File

@ -192,7 +192,7 @@ class CudaPlatformBase(Platform):
compilation_config = vllm_config.compilation_config
if (
envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
parallel_config.all2all_backend == "deepep_high_throughput"
and parallel_config.data_parallel_size > 1
and compilation_config.cudagraph_mode != CUDAGraphMode.NONE
):
@ -204,7 +204,7 @@ class CudaPlatformBase(Platform):
"kernels are optimized for prefill and are incompatible with "
"CUDA Graphs. "
"In order to use CUDA Graphs for decode-optimized workloads, "
"set VLLM_ALL2ALL_BACKEND to another option, such as "
"use --all2all-backend with another option, such as "
"deepep_low_latency, pplx, or allgather_reducescatter."
)
compilation_config.cudagraph_mode = CUDAGraphMode.NONE

View File

@ -356,9 +356,10 @@ class CoreEngineActorManager:
)
device_str = current_platform.ray_device_key
all2all_backend = vllm_config.parallel_config.all2all_backend
if envs.VLLM_RAY_DP_PACK_STRATEGY == "fill" and (
envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
all2all_backend == "deepep_high_throughput"
or all2all_backend == "deepep_low_latency"
):
raise ValueError(
"DeepEP kernels require EP ranks [0,7] (same for [8,15], ...) "