DP/EP Support for gpt-oss with deepep-ht comm kernel on SM100 (#23608)

Yongye Zhu
2025-08-27 17:33:21 -04:00
committed by GitHub
parent 853c371fc3
commit 082cc07ef8
11 changed files with 365 additions and 12 deletions

View File

@@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
if module.__class__.__name__ == "FusedMoE"
]
for module in moe_modules:
-module.quant_method.init_prepare_finalize()
+module.quant_method.init_prepare_finalize(module)
def dispatch(
self, hidden_states: torch.Tensor,

View File

@@ -450,6 +450,12 @@ class FusedMoEConfig:
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Config)
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
quant_dtype = "mxfp8"
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
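
The new branch slots "mxfp8" into a first-match cascade: each check only fires while quant_dtype is still None, so Fp8Config keeps winning and Mxfp4Config only yields "mxfp8" under the flashinfer flag. A standalone sketch of the same resolution order (the config classes here are stand-ins for vLLM's, and the env flag is modeled as a plain bool):

from typing import Optional, Union
import torch

class Fp8Config: pass      # stand-in for vllm's Fp8Config
class Mxfp4Config: pass    # stand-in for vllm's Mxfp4Config

def resolve_quant_dtype(
        quant_config: object,
        use_flashinfer_mxfp4_mxfp8: bool) -> Optional[Union[torch.dtype, str]]:
    quant_dtype: Optional[Union[torch.dtype, str]] = None
    if quant_dtype is None and isinstance(quant_config, Fp8Config):
        quant_dtype = torch.float8_e4m3fn
    # mxfp4 checkpoints request mxfp8 activation quantization only when
    # the flashinfer MXFP4 x MXFP8 MoE path is enabled.
    if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
            and use_flashinfer_mxfp4_mxfp8):
        quant_dtype = "mxfp8"
    return quant_dtype

assert resolve_quant_dtype(Mxfp4Config(), True) == "mxfp8"
assert resolve_quant_dtype(Mxfp4Config(), False) is None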

View File

@@ -200,7 +200,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# Note: init_prepare_finalize should only be called by
# prepare_communication_buffer_for_model.
-def init_prepare_finalize(self):
+def init_prepare_finalize(self, layer: torch.nn.Module):
assert self.moe is not None
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
@@ -211,7 +211,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
assert self.fused_experts is None, \
f"Attempt to override experts for {id(self)}!"
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
-experts = self.select_gemm_impl(prepare_finalize, self.moe)
+experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
@@ -221,6 +221,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
@@ -273,6 +274,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
prepare_finalize: FusedMoEPrepareAndFinalize,
# TODO(bnell): Remove. Every layer should have an MoE config object.
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
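
Every override of select_gemm_impl now takes the owning layer as a third argument, which is how implementations reach tensors registered on the module (biases, clamp limits) that FusedMoEConfig does not carry. A hedged sketch of the new contract; the subclass and its body are illustrative only, not part of this diff:

import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe import (FusedMoEConfig,
                                                  FusedMoEMethodBase)

class MyQuantMoEMethod(FusedMoEMethodBase):  # hypothetical subclass
    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        moe: FusedMoEConfig,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
        # `layer` exposes module-level parameters (e.g. layer.w13_bias),
        # which is exactly what TrtLlmGenExperts consumes below.
        raise NotImplementedError("choose an experts impl per platform")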

View File

@@ -0,0 +1,197 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.utils import next_power_of_2
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe: FusedMoEConfig,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
w13_bias,
w2_bias,
max_capture_size,
):
super().__init__(moe.quant_config)
self.moe = moe
self.gemm1_alpha = gemm1_alpha
self.gemm1_beta = gemm1_beta
self.gemm1_clamp_limit = gemm1_clamp_limit
self.w13_bias = w13_bias
self.w2_bias = w2_bias
self.max_capture_size = max_capture_size
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
return TopKWeightAndReduceNoOP()
def workspace_shapes(
self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# The workspaces for this implementation are managed by flashinfer.
# TODO(varun): workspace1 could be used as the output tensor. This
# is error-prone. Allow `workspace_shapes` to return None workspaces.
workspace1 = (M, K)
workspace2 = (0, 0)
output = (M, K)
return (workspace1, workspace2, output, a.dtype)
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
local_num_experts: int):
# Number of tokens in the input tensor.
num_tokens = x.shape[0]
# Factor to account for the imbalance of the experts.
# The factor equals
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert.
# 1.0 means a perfect expert distribution.
# > 1.0 means some experts receive more tokens than a perfect
# distribution would assign.
# < 1.0 does not make sense.
imbalance_factor = 1.3
# Calculate the number of tokens per expert assuming perfect
# distribution.
num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
# Apply the imbalance factor.
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
# And pad the number to the next power of 2.
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
# Cap to 8-64 tokens per CTA tile as it's the range supported by the
# kernel.
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
return tile_tokens_dim
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
topk = topk_ids.size(-1)
local_num_experts = w1.size(0)
intermediate_size = w2.size(1)
local_expert_offset = self.moe.ep_rank * local_num_experts
x_quant = hidden_states
x_scale = a1q_scale
if x_scale is not None:
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*x_quant.shape[:-1], -1)
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
torch.bfloat16).view(torch.int16)
assert w1_scale is not None
assert w2_scale is not None
kwargs = {
"topk_ids":
packed_tensor,
"routing_bias":
None,
"hidden_states":
x_quant,
"hidden_states_scale":
x_scale,
"gemm1_weights":
w1,
"gemm1_weights_scale":
w1_scale,
"gemm1_bias":
self.w13_bias,
"gemm1_alpha":
self.gemm1_alpha,
"gemm1_beta":
self.gemm1_beta,
"gemm1_clamp_limit":
self.gemm1_clamp_limit,
"gemm2_weights":
w2,
"gemm2_weights_scale":
w2_scale,
"gemm2_bias":
self.w2_bias,
"output1_scale_scalar":
None,
"output1_scale_gate_scalar":
None,
"output2_scale_scalar":
None,
"num_experts":
global_num_experts,
"top_k":
topk,
"n_group":
None,
"topk_group":
None,
"intermediate_size":
intermediate_size,
"local_expert_offset":
local_expert_offset,
"local_num_experts":
local_num_experts,
"routed_scaling_factor":
None,
"tile_tokens_dim":
self._get_tile_tokens_dim(x_quant, topk, local_num_experts),
"routing_method_type":
1,
"do_finalize":
True,
"output":
output,
"tune_max_num_tokens":
self.max_capture_size,
}
from flashinfer import trtllm_fp4_block_scale_routed_moe
trtllm_fp4_block_scale_routed_moe(**kwargs)
return output
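
Two details of apply are easy to miss. The routing input is pre-packed: each int32 holds the expert id in its high 16 bits and the bf16 bit pattern of the routing weight in its low 16 (safe for non-negative weights, whose sign bit is clear, so the int16-to-int32 promotion does not sign-extend). And the tile heuristic rounds an inflated per-expert token estimate up to a power of two. A toy walk-through of both, reusing only operations from this file:

import torch
from vllm.utils import next_power_of_2

# Packing: 1 token routed to experts 3 and 17 with weights 0.75 / 0.25.
topk_ids = torch.tensor([[3, 17]], dtype=torch.int32)
topk_weights = torch.tensor([[0.75, 0.25]])
packed = (topk_ids << 16) | topk_weights.to(torch.bfloat16).view(torch.int16)
assert (packed >> 16).tolist() == [[3, 17]]  # expert ids survive in high bits

# Tile heuristic with hypothetical sizes: 256 tokens, top_k=4, 32 experts.
per_expert = int((256 * 4) // 32 * 1.3)              # 41 tokens per expert
tile = min(max(next_power_of_2(per_expert), 8), 64)  # rounded up, clamped
assert tile == 64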

View File

@@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_quantize)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
@@ -177,6 +179,18 @@ def _mxfp4_quantize(
return A, None
def _mxfp8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None
return mxfp8_quantize(A)
def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
@@ -195,6 +209,8 @@ def moe_kernel_quantize_input(
is_sf_swizzled_layout=is_fp4_scale_swizzled)
elif quant_dtype == "mxfp4":
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp8":
return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
return A, A_scale
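
A usage sketch of the new branch (keyword names as in this file; flashinfer and a CUDA device required). MXFP8 follows the OCP MX convention: fp8-e4m3 element values sharing one power-of-two scale per 32-element block, which is why the helper returns a (values, scales) pair instead of threading A_scale through:

import torch

a = torch.randn(128, 2880, dtype=torch.bfloat16, device="cuda")
a_q, a_scale = moe_kernel_quantize_input(
    A=a,
    A_scale=None,               # mxfp8 computes its own block scales
    quant_dtype="mxfp8",
    per_act_token_quant=False,  # asserted off in _mxfp8_quantize
)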

View File

@@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
"""Return the appropriate GEMM experts implementation."""
experts = select_nvfp4_gemm_impl(
@@ -719,10 +720,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
dtype=torch.int64)
def select_gemm_impl(
-self,
-prepare_finalize: FusedMoEPrepareAndFinalize,
-moe: FusedMoEConfig,
-) -> FusedMoEPermuteExpertsUnpermute:
+self, prepare_finalize: FusedMoEPrepareAndFinalize,
+moe: FusedMoEConfig,
+layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
# cutlass path
if self.use_cutlass:
from vllm.model_executor.layers.fused_moe import (

View File

@@ -897,6 +897,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)

View File

@@ -311,6 +311,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_cutlass_fp8_gemm_impl(
moe,
@@ -1032,6 +1033,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
experts = select_nvfp4_gemm_impl(
moe,

View File

@@ -10,6 +10,8 @@ from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -445,6 +447,91 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return tile_tokens_dim
def select_gemm_impl(
self,
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
layer: torch.nn.Module,
) -> mk.FusedMoEPermuteExpertsUnpermute:
if (prepare_finalize.activation_format ==
mk.FusedMoEActivationFormat.BatchedExperts):
raise NotImplementedError(
"Mxfp4 does not support batched experts format for EP")
else:
if should_use_flashinfer_mxfp4():
# B200 code-path
kwargs = {
"gemm1_alpha": layer.gemm1_alpha,
"gemm1_beta": layer.gemm1_beta,
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
"w13_bias": layer.w13_bias,
"w2_bias": layer.w2_bias,
"max_capture_size": self.max_capture_size,
}
return TrtLlmGenExperts(moe, **kwargs)
else:
# Use matmul_ogs from triton_kernels here!
raise NotImplementedError(
"Mxfp4 does not support non-batched experts format for EP")
def _route_and_experts(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None
) -> torch.Tensor:
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=self.topk_indices_dtype,
enable_eplb=enable_eplb,
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count)
return self.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
)
def apply(
self,
layer: torch.nn.Module,
@@ -503,6 +590,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
activation=activation,
expert_map=expert_map)
if self.fused_experts is not None:
return self._route_and_experts(
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
global_num_experts,
expert_map,
custom_routing_function,
scoring_func,
e_score_correction_bias,
apply_router_weight_on_input,
activation,
enable_eplb,
expert_load_view,
logical_to_physical_map,
logical_replica_count,
)
assert _can_support_mxfp4(
use_grouped_topk, topk_group, num_expert_group, expert_map,
custom_routing_function, e_score_correction_bias,
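
Structurally, apply now defers to _route_and_experts whenever init_prepare_finalize has installed a modular kernel, splitting MoE into two stages: routing (FusedMoE.select_experts) and a fused dispatch / expert-GEMM / combine pipeline. A framework-free toy of that split (names and the identity "experts" are illustrative only):

import torch

def select_experts(router_logits: torch.Tensor, top_k: int):
    probs = router_logits.softmax(dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    return topk_weights, topk_ids

def fused_experts(x: torch.Tensor, topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor) -> torch.Tensor:
    # Stand-in for FusedMoEModularKernel: dispatch -> expert GEMMs ->
    # combine. Every "expert" here is the identity, so only the routing
    # weights matter.
    return x * topk_weights.sum(dim=-1, keepdim=True)

x = torch.randn(4, 8)
weights, ids = select_experts(torch.randn(4, 16), top_k=2)
out = fused_experts(x, weights, ids)
assert out.shape == x.shape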

View File

@@ -66,11 +66,10 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None):
return not (use_grouped_topk or topk_group or num_expert_group
-or expert_map or custom_routing_function
-or e_score_correction_bias or apply_router_weight_on_input
-or scoring_func != "softmax" or activation != "swigluoai"
-or expert_load_view or logical_to_physical_map
-or logical_replica_count)
+or custom_routing_function or e_score_correction_bias
+or apply_router_weight_on_input or scoring_func != "softmax"
+or activation != "swigluoai" or expert_load_view
+or logical_to_physical_map or logical_replica_count)
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
try:
from flashinfer import mxfp8_quantize as flashinfer_mxfp8_quantize
except ImportError as err:
raise ImportError("The package `flashinfer` is required to do "
"MX-FP8 quantization. Please install it with "
"`pip install flashinfer`.") from err
return flashinfer_mxfp8_quantize(x, is_sf_swizzled_layout=False)
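
A direct usage sketch of the wrapper (needs flashinfer and a CUDA device; the exact dtype and layout of the returned scales are flashinfer's contract, expected to be one scale per 32-element block per the MX spec):

x = torch.randn(64, 1024, device="cuda", dtype=torch.bfloat16)
x_q, x_scale = mxfp8_quantize(x)  # fp8-e4m3 values + per-block scales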