Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[Performance] Run shared_experts on a separate cuda stream (in parallel with the FusedMoE)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
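For orientation, here is a minimal, self-contained sketch of the pattern this commit applies (plain PyTorch, not vLLM code; shared_mlp and routed_moe are placeholder callables): the shared experts run on a side CUDA stream while the routed experts run on the current stream, with wait_stream() ordering the work on both ends.

import torch

def overlapped_moe_sketch(hidden_states, shared_mlp, routed_moe):
    # Side stream for the shared experts; the routed experts stay on the
    # current ("main") stream.
    side_stream = torch.cuda.Stream()
    main_stream = torch.cuda.current_stream()

    # The side stream must see all prior work that produced hidden_states.
    side_stream.wait_stream(main_stream)
    with torch.cuda.stream(side_stream):
        # clone() so the side stream does not race with later reuse of the
        # hidden_states buffer by the main stream.
        shared_out = shared_mlp(hidden_states.clone())

    # Launched on the main stream; runs concurrently with the side stream.
    routed_out = routed_moe(hidden_states)

    # The main stream must not consume shared_out before the side stream is done.
    main_stream.wait_stream(side_stream)
    return shared_out, routed_out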
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import os
 from abc import abstractmethod
 from collections.abc import Callable, Iterable
 from contextlib import nullcontext
@@ -1073,6 +1074,20 @@ class FusedMoE(CustomOp):
         n_shared_experts: int | None = None,
     ):
         super().__init__()
+
+        # TODO: Allow disabling of the separate shared experts stream for
+        # debug purposes. Remove this after more extensive testing with
+        # TP/DP and other execution modes.
+        disable_shared_experts_stream = os.environ.get(
+            "DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM", None
+        )
+
+        if disable_shared_experts_stream is not None:
+            logger.info_once("Disabling MoE shared_experts cuda stream")
+            self.shared_experts_stream = None
+        else:
+            self.shared_experts_stream = torch.cuda.Stream()
+
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
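A note on the toggle added above: os.environ.get() returns a string for any set value, so the check is presence-based rather than value-based. A small sketch of the assumed usage (the variable is read once, when the FusedMoE layer is constructed):

import os

# Setting the variable to any value, even "0", before the FusedMoE layers are
# constructed disables the separate stream; leaving it unset keeps the default
# of creating one torch.cuda.Stream() per FusedMoE layer.
os.environ["DISABLE_MOE_SHARED_EXPERTS_CUDA_STREAM"] = "1"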
@@ -1322,6 +1337,10 @@ class FusedMoE(CustomOp):
     def shared_experts(self) -> torch.nn.Module | None:
         return None
 
+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return None
+
     @property
     def tp_size(self):
         return self.moe_parallel_config.tp_size
@@ -2144,6 +2163,7 @@ class FusedMoE(CustomOp):
         self,
         full_hidden_states: torch.Tensor,
         full_router_logits: torch.Tensor,
+        has_separate_shared_experts: bool,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.batched_hidden_states is not None
         assert self.batched_router_logits is not None
@@ -2192,11 +2212,24 @@ class FusedMoE(CustomOp):
 
             # If there are shared experts but we are not using a modular kernel,
             # the shared experts must be called here
-            if (
-                not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-                and self.shared_experts is not None
-            ):
-                shared_output = self.shared_experts(staged_hidden_states)
+            if has_separate_shared_experts:
+                assert self.shared_experts is not None
+
+                if self.shared_experts_stream is not None:
+                    # For the chunked path, start the shared experts stream here
+                    # (note that there is no concurrency with the router/gate).
+                    current_stream = torch.cuda.current_stream()
+                    self.shared_experts_stream.wait_stream(current_stream)
+
+                    with torch.cuda.stream(self.shared_experts_stream):
+                        # The staged_hidden_states clone() is necessary here
+                        # to avoid a conflict with the main stream.
+                        shared_output = self.shared_experts(
+                            staged_hidden_states.clone()
+                        )
+                else:
+                    shared_output = self.shared_experts(staged_hidden_states)
+
             else:
                 shared_output = None
 
@@ -2225,9 +2258,14 @@ class FusedMoE(CustomOp):
                 logical_replica_count=self.logical_replica_count,
             )
 
-            if shared_output is not None:
+            if has_separate_shared_experts:
                 assert not isinstance(final_hidden_states, tuple)
                 assert self.shared_experts is not None
+
+                # Here we finish the shared experts stream
+                if self.shared_experts_stream is not None:
+                    current_stream.wait_stream(self.shared_experts_stream)
+
                 final_hidden_states = (
                     shared_output,
                     final_hidden_states,
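A hedged illustration (plain PyTorch, not vLLM code) of what "finishing" the stream above means: wait_stream() records a device-side dependency, so later kernels on the main stream are ordered after the side-stream work, but the CPU is not blocked.

import torch

main = torch.cuda.current_stream()
side = torch.cuda.Stream()

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

side.wait_stream(main)      # side stream sees the fully initialized a and b
with torch.cuda.stream(side):
    y = a @ b               # matmul runs on the side stream

main.wait_stream(side)      # device-side join; this call returns immediately
z = y.sum()                 # main-stream op, safely ordered after the matmul
print(z.item())             # .item() is the only host synchronization here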
@@ -2297,8 +2335,34 @@ class FusedMoE(CustomOp):
 
         self.ensure_moe_quant_config()
 
-        if self.use_dp_chunking:
-            return self.forward_impl_chunked(hidden_states, router_logits)
+        has_separate_shared_experts = (
+            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+            and self.shared_experts is not None
+        )
+
+        use_chunked_impl = self.use_dp_chunking
+
+        if (
+            has_separate_shared_experts
+            and not use_chunked_impl
+            and self.shared_experts_stream is not None
+        ):
+            # Start the separate shared experts stream here, since we want it
+            # to run in parallel with the router/gate (the next op below).
+            current_stream = torch.cuda.current_stream()
+            self.shared_experts_stream.wait_stream(current_stream)
+
+        # If a router/gate is provided, apply it here.
+        # (Note: this code runs only when "overlapped mode" is on, to allow
+        # parallel execution of the shared experts with the FusedMoE via a
+        # separate cuda stream.)
+        if self.gate is not None:
+            router_logits, _ = self.gate(hidden_states)
+
+        if use_chunked_impl:
+            return self.forward_impl_chunked(
+                hidden_states, router_logits, has_separate_shared_experts
+            )
 
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1 and not self.quant_method.using_modular_kernel
@@ -2306,10 +2370,16 @@ class FusedMoE(CustomOp):
 
         # If there are shared experts but we are not using a modular kernel, the
         # shared experts must be called here
-        if (
-            not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
-            and self.shared_experts is not None
-        ):
-            shared_output = self.shared_experts(hidden_states)
+        if has_separate_shared_experts:
+            assert self.shared_experts is not None
+
+            if self.shared_experts_stream is not None:
+                # Run the shared experts in parallel on a separate stream.
+                with torch.cuda.stream(self.shared_experts_stream):
+                    # The hidden_states clone() is necessary here to avoid
+                    # a conflict with the main stream.
+                    shared_output = self.shared_experts(hidden_states.clone())
+            else:
+                shared_output = self.shared_experts(hidden_states)
         else:
             shared_output = None
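On the clone() in the hunk above: when a tensor produced on the main stream is read on a side stream, later main-stream work may overwrite the buffer, and PyTorch's per-stream caching allocator may recycle its memory. A rough sketch of the two standard guards (an assumed simplification; the patch itself uses the clone approach):

import torch

main = torch.cuda.current_stream()
side = torch.cuda.Stream()
x = torch.randn(1024, 1024, device="cuda")

side.wait_stream(main)
with torch.cuda.stream(side):
    # Guard 1 (what this patch does): give the side stream its own copy, so
    # the main stream is free to overwrite or reuse x afterwards.
    y1 = torch.relu(x.clone())

    # Guard 2 (general PyTorch alternative): mark x as in use on the side
    # stream so the caching allocator delays reuse of its memory.
    x.record_stream(side)
    y2 = torch.relu(x)

main.wait_stream(side)
print(y1.shape, y2.shape)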
@@ -2353,9 +2423,14 @@ class FusedMoE(CustomOp):
             logical_replica_count=self.logical_replica_count,
         )
 
-        if shared_output is not None:
+        if has_separate_shared_experts:
             assert not isinstance(final_hidden_states, tuple)
             assert self.shared_experts is not None
+
+            # Wait for the parallel shared experts stream to finish here
+            if self.shared_experts_stream is not None:
+                current_stream.wait_stream(self.shared_experts_stream)
+
             final_hidden_states = (
                 shared_output,
                 final_hidden_states,
@@ -18,25 +18,36 @@ class SharedFusedMoE(FusedMoE):
     def __init__(
         self,
         shared_experts: torch.nn.Module | None,
+        gate: torch.nn.Module | None = None,
         use_overlapped: bool = True,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
 
         # Disable shared expert overlap if EP is disabled or we are not using
         # flashinfer + DP since there is nothing to be gained in this case.
         # Disabling the overlap optimization also prevents the shared experts
         # from being hidden from torch.compile.
         self.use_overlapped = (
             use_overlapped
-            and not (self.use_ep or self.use_flashinfer_cutlass_kernels)
+            and not (
+                self.use_ep
+                or (self.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+            )
             and self._shared_experts is not None
         )
 
+        self._gate = gate
+
     @property
     def shared_experts(self) -> torch.nn.Module | None:
         return self._shared_experts if self.use_overlapped else None
 
+    @property
+    def gate(self) -> torch.nn.Module | None:
+        return self._gate if self.use_overlapped else None
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -227,6 +227,7 @@ class DeepseekV2MoE(nn.Module):
 
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
+            gate=self.gate,
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
@@ -264,9 +265,13 @@ class DeepseekV2MoE(nn.Module):
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
 
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        fused_moe_out = self.experts(
-            hidden_states=hidden_states, router_logits=router_logits
-        )
+        if isinstance(self.experts, SharedFusedMoE) and self.experts.use_overlapped:
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=hidden_states
+            )
+        else:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            fused_moe_out = self.experts(
+                hidden_states=hidden_states, router_logits=router_logits
+            )