[Performance] Run shared_experts on a separate CUDA stream (in parallel with the FusedMoE)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>

Author: Alexander Matveev
Date:   2025-10-08 09:37:59 -07:00
Parent: 314fa8abbf
Commit: 6f30ab9ab3

3 changed files with 112 additions and 21 deletions
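The core of this change is a fork/join pattern on two CUDA streams: the shared experts run on a side stream while the routed (fused) experts run on the main stream, and the main stream joins before the two outputs are combined. A minimal sketch of that pattern, not vLLM's exact code; `routed_experts` and `shared_experts` stand in for the real modules:

    import torch

    main_stream = torch.cuda.current_stream()
    side_stream = torch.cuda.Stream()

    def overlapped_moe(hidden_states, routed_experts, shared_experts):
        # Fork: the side stream must observe all prior writes to hidden_states.
        side_stream.wait_stream(main_stream)
        with torch.cuda.stream(side_stream):
            # clone() gives the side stream a private copy, so later
            # main-stream ops on hidden_states cannot race with this read.
            shared_out = shared_experts(hidden_states.clone())
        # Meanwhile, the routed experts run on the main stream.
        routed_out = routed_experts(hidden_states)
        # Join: do not consume shared_out until the side stream is done.
        main_stream.wait_stream(side_stream)
        return shared_out, routed_out  # combined by the caller, as in the commit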

File 1 of 3

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
 from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
@@ -1073,6 +1074,20 @@ class FusedMoE(CustomOp):
         n_shared_experts: int | None = None,
     ):
         super().__init__()
+
+        # TODO: Allow disabling of the separate shared experts stream for
+        # debug purposes. Remove this after more extensive testing with
+        # TP/DP and other execution modes.
+        disable_shared_experts_stream = os.environ.get(
+            "DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM", None
+        )
+
+        if disable_shared_experts_stream is not None:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
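Since the check is only for the variable's presence (`is not None`), any value disables the stream. A hypothetical way to turn it off when debugging, set before the engine constructs its FusedMoE layers:

    import os

    # Any value works: the code checks presence, not truthiness.
    os.environ["DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM"] = "1"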
@@ -1322,6 +1337,10 @@ class FusedMoE(CustomOp):
     def shared_experts(self) -> torch.nn.Module | None:
         return None
 
+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -2144,6 +2163,7 @@ class FusedMoE(CustomOp):
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2192,11 +2212,24 @@ class FusedMoE(CustomOp):
             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For chunked execution, we start the shared experts stream
+                    # here (note that there is no concurrency with the router/gate).
+                    current_stream = torch.cuda.current_stream()
+                    self.shared_experts_stream.wait_stream(current_stream)
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # The staged_hidden_states clone() is necessary here to
+                        # avoid a conflict with the main stream.
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
             else:
                 shared_output = None
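The clone() is the standard remedy for cross-stream tensor use: PyTorch's caching allocator tracks a tensor's lifetime only on the stream where it was allocated, so a buffer produced on the main stream and read on a side stream can be freed and reused by the main stream too early. A sketch of the two usual options, with clone() being the one this commit uses; `side_stream`, `staged_hidden_states`, and `shared_experts_module` are illustrative names:

    import torch

    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        # Option A (used by this commit): take a private copy whose lifetime
        # is tracked on the side stream.
        staged_copy = staged_hidden_states.clone()
        # Option B (alternative): keep the original buffer alive for the side
        # stream's reads instead of copying.
        # staged_hidden_states.record_stream(side_stream)
        shared_output = shared_experts_module(staged_copy)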
@@ -2225,9 +2258,14 @@ class FusedMoE(CustomOp):
                 logical_replica_count=self.logical_replica_count,
             )
 
-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream.
+                if self.shared_experts_stream is not None:
+                    current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
@@ -2297,8 +2335,34 @@ class FusedMoE(CustomOp):
         self.ensure_moe_quant_config()
 
-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here, since we want it
+            # to run in parallel with the router/gate (the next op below).
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If a router/gate is provided, apply it here.
+        # (Note: this code runs only when "overlapped mode" is on, to allow
+        # parallel execution of the shared experts with the FusedMoE via a
+        # separate CUDA stream.)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )
 
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
         )
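Note the reordering here: the side stream is launched before the gate runs, so the shared experts overlap not only with the fused experts but also with the router GEMM. One way to confirm the overlap on a real run is to capture a trace and check that the shared-experts kernels land on a second stream lane; a sketch where `model` and `inputs` are placeholders:

    import torch
    from torch.profiler import ProfilerActivity, profile

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        out = model(inputs)  # placeholder forward pass
        torch.cuda.synchronize()
    # Open the trace in chrome://tracing or Perfetto and compare stream lanes.
    prof.export_chrome_trace("moe_overlap.json")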
@@ -2306,10 +2370,16 @@ class FusedMoE(CustomOp):
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run the shared experts in parallel on a separate stream.
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # The hidden_states clone() is necessary here to avoid a
+                    # conflict with the main stream.
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None
@@ -2353,9 +2423,14 @@ class FusedMoE(CustomOp):
             logical_replica_count=self.logical_replica_count,
         )
 
-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here.
+            if self.shared_experts_stream is not None:
+                current_stream.wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,

File 2 of 3

@@ -18,25 +18,36 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
         shared_experts: torch.nn.Module | None,
+        gate: torch.nn.Module | None = None,
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
+
         # Disable shared expert overlap if EP is enabled or if we are using
         # flashinfer + DP, since there is nothing to be gained in those cases.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and not (
+                self.use_ep
+                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+            )
             and self._shared_experts is not None
         )
+        self._gate = gate
 
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return self._shared_experts if self.use_overlapped else None
 
+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return self._gate if self.use_overlapped else None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
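With the new `gate` argument, a model hands both modules to SharedFusedMoE and lets it decide whether overlapping is worthwhile. A hypothetical construction; the keyword values are illustrative and are forwarded to FusedMoE:

    experts = SharedFusedMoE(
        shared_experts=shared_mlp,   # the always-on expert MLP
        gate=router_gate,            # module returning (logits, aux) per token
        use_overlapped=True,         # auto-disabled under EP or flashinfer + DP
        num_experts=64,
        top_k=6,
        hidden_size=4096,
    )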

File 3 of 3

@@ -227,6 +227,7 @@ class DeepseekV2MoE(nn.Module):
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
+            gate=self.gate,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -264,9 +265,13 @@ class DeepseekV2MoE(nn.Module):
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
 
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if isinstance(self.experts, SharedFusedMoE) and self.experts.use_overlapped:
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
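When overlap is enabled, the caller deliberately skips its own gate and passes `hidden_states` in the `router_logits` slot as a placeholder; FusedMoE then computes the real logits via its `gate` property after the shared-experts stream has been launched (see the forward_impl hunk above). The contract, condensed with explanatory comments (names as in the DeepseekV2MoE code):

    if isinstance(self.experts, SharedFusedMoE) and self.experts.use_overlapped:
        # Placeholder logits: FusedMoE recomputes router_logits internally
        # from hidden_states, after kicking off the shared-experts stream.
        fused_moe_out = self.experts(
            hidden_states=hidden_states, router_logits=hidden_states
        )
    else:
        # Non-overlapped path: the caller applies the gate itself.
        router_logits, _ = self.gate(hidden_states)
        fused_moe_out = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )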