[Performance] Run shared_experts on a separate CUDA stream (in parallel with the FusedMoE)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Author: Alexander Matveev
Date: 2025-10-08 09:37:59 -07:00
Parent: 314fa8abbf
Commit: 6f30ab9ab3
3 changed files with 112 additions and 21 deletions
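This change lets the shared experts of a fused MoE layer run on a dedicated CUDA stream while the main stream runs the router/gate and the routed experts, joining the streams again before the outputs are combined. The snippet below is a minimal, self-contained sketch of that stream-overlap pattern; the function and argument names are illustrative stand-ins rather than code from this diff, and it assumes a CUDA device with all modules already on the GPU.

import torch
import torch.nn as nn


def overlapped_moe_forward(
    hidden_states: torch.Tensor,
    gate: nn.Module,              # router gate (illustrative stand-in)
    shared_experts: nn.Module,    # dense shared-experts MLP (illustrative)
    routed_experts,               # callable(hidden_states, router_logits) -> Tensor
    shared_stream: torch.cuda.Stream,
) -> torch.Tensor:
    """Sketch of the overlap pattern: shared experts on a side stream."""
    main_stream = torch.cuda.current_stream()

    # The side stream must see all work already queued on the main stream
    # (e.g. the layer that produced hidden_states).
    shared_stream.wait_stream(main_stream)
    with torch.cuda.stream(shared_stream):
        # clone() so the side stream never races with later reuse of
        # hidden_states on the main stream.
        shared_out = shared_experts(hidden_states.clone())

    # Meanwhile the gate and the routed experts run on the main stream.
    router_logits = gate(hidden_states)
    routed_out = routed_experts(hidden_states, router_logits)

    # Join: the main stream waits for the shared experts before combining.
    main_stream.wait_stream(shared_stream)
    return routed_out + shared_out

In the diff itself the two results are returned as a (shared_output, final_hidden_states) tuple rather than summed, and the clone()/wait_stream() calls appear in both the chunked and the non-chunked forward paths.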

File 1 of 3: FusedMoE layer (class FusedMoE)

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
 from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
@@ -1073,6 +1074,20 @@ class FusedMoE(CustomOp):
         n_shared_experts: int | None = None,
     ):
         super().__init__()
+
+        # TODO: Allow disabling of the separate shared experts stream for
+        # debug purposes. Remove this after more extensive testing with
+        # TP/DP and other execution modes.
+        disable_shared_experts_stream = os.environ.get(
+            "DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM", None
+        )
+
+        if disable_shared_experts_stream is not None:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -1322,6 +1337,10 @@ class FusedMoE(CustomOp):
     def shared_experts(self) -> torch.nn.Module | None:
         return None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -2144,6 +2163,7 @@ class FusedMoE(CustomOp):
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2192,11 +2212,24 @@ class FusedMoE(CustomOp):
             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For the chunked path, we start the shared experts stream
+                    # here (note that there is no concurrency with the
+                    # router/gate).
+                    current_stream = torch.cuda.current_stream()
+                    self.shared_experts_stream.wait_stream(current_stream)
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # The staged_hidden_states clone() is necessary here to
+                        # avoid a conflict with the main stream.
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
             else:
                 shared_output = None
@@ -2225,9 +2258,14 @@ class FusedMoE(CustomOp):
                 logical_replica_count=self.logical_replica_count,
             )

-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream
+                if self.shared_experts_stream is not None:
+                    current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
@@ -2297,8 +2335,34 @@ class FusedMoE(CustomOp):
         self.ensure_moe_quant_config()

-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here, since we want it
+            # to run in parallel with the router/gate (the next op below).
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If a router/gate was provided, apply it here.
+        # (Note: this code runs only when "overlapped mode" is on, to allow
+        # parallel execution of the shared experts with the FusedMoE via a
+        # separate cuda stream.)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )

         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
@@ -2306,10 +2370,16 @@ class FusedMoE(CustomOp):
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run the shared experts in parallel on a separate stream.
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # The hidden_states clone() is necessary here to avoid a
+                    # conflict with the main stream.
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None
@@ -2353,9 +2423,14 @@ class FusedMoE(CustomOp):
             logical_replica_count=self.logical_replica_count,
         )

-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            if self.shared_experts_stream is not None:
+                current_stream.wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,

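The new stream can be switched off for debugging through the DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM environment variable added above. The layer only checks that the variable is set, so any value disables the separate stream, for example (set before the FusedMoE layer is constructed):

import os

# Any value works; the layer only checks that the variable is present.
os.environ["DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM"] = "1"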
File 2 of 3: SharedFusedMoE wrapper (class SharedFusedMoE)

@@ -18,25 +18,36 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
         shared_experts: torch.nn.Module | None,
+        gate: torch.nn.Module | None = None,
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts

         # Disable shared expert overlap if EP is disabled or we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and not (
+                self.use_ep
+                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+            )
             and self._shared_experts is not None
         )
+        self._gate = gate

     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return self._shared_experts if self.use_overlapped else None

+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return self._gate if self.use_overlapped else None
+
     def forward(
         self,
         hidden_states: torch.Tensor,

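SharedFusedMoE now also accepts the router gate and exposes it through a gate property that mirrors shared_experts: both return the wrapped module only while use_overlapped is True, so in overlapped mode the base FusedMoE.forward_impl can apply the gate itself, while non-overlapped callers keep computing router_logits on their own. The overlap is also no longer disabled for the flashinfer cutlass kernels unless data parallelism is actually in use (dp_size > 1).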
File 3 of 3: DeepseekV2 MoE module (class DeepseekV2MoE)

@@ -227,6 +227,7 @@ class DeepseekV2MoE(nn.Module):
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
+            gate=self.gate,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -264,9 +265,13 @@ class DeepseekV2MoE(nn.Module):
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)

-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if isinstance(self.experts, SharedFusedMoE) and self.experts.use_overlapped:
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )
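With overlap enabled, DeepseekV2MoE no longer applies the gate itself; it forwards hidden_states through the router_logits argument as a placeholder, and FusedMoE.forward_impl computes the real router logits via the new gate property. Since the shared-experts stream is synchronized with the main stream before the gate runs, the shared experts can execute concurrently with the gate and the routed experts, which is the point of this commit.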