DP/EP Support for gpt-oss with deepep-ht comm kernel on SM100 (#23608)
This commit is contained in:
@ -255,7 +255,7 @@ class DeviceCommunicatorBase:
|
||||
if module.__class__.__name__ == "FusedMoE"
|
||||
]
|
||||
for module in moe_modules:
|
||||
module.quant_method.init_prepare_finalize()
|
||||
module.quant_method.init_prepare_finalize(module)
|
||||
|
||||
def dispatch(
|
||||
self, hidden_states: torch.Tensor,
|
||||
|
@ -450,6 +450,12 @@ class FusedMoEConfig:
|
||||
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
from vllm.model_executor.layers.quantization.mxfp4 import (
|
||||
Mxfp4Config)
|
||||
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config)
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
|
||||
quant_dtype = "mxfp8"
|
||||
|
||||
from vllm.model_executor.layers.quantization.modelopt import (
|
||||
ModelOptNvFp4Config)
|
||||
if quant_dtype is None and isinstance(quant_config,
|
||||
|
@ -200,7 +200,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
# Note: init_prepare_finalize should only be called by
|
||||
# prepare_communication_buffer_for_model.
|
||||
def init_prepare_finalize(self):
|
||||
def init_prepare_finalize(self, layer: torch.nn.Module):
|
||||
assert self.moe is not None
|
||||
prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
|
||||
|
||||
@ -211,7 +211,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
assert self.fused_experts is None, \
|
||||
f"Attempt to override experts for {id(self)}!"
|
||||
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
||||
experts = self.select_gemm_impl(prepare_finalize, self.moe)
|
||||
experts = self.select_gemm_impl(prepare_finalize, self.moe, layer)
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
experts,
|
||||
@ -221,6 +221,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# based on the all2all implementation, select the appropriate
|
||||
# gemm implementation
|
||||
@ -273,6 +274,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
# TODO(bnell): Remove. Every layer should have an moe config object.
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
if (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts):
|
||||
|
197
vllm/model_executor/layers/fused_moe/trtllm_moe.py
Normal file
197
vllm/model_executor/layers/fused_moe/trtllm_moe.py
Normal file
@ -0,0 +1,197 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceNoOP)
|
||||
from vllm.utils import next_power_of_2
|
||||
|
||||
|
||||
class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
moe: FusedMoEConfig,
|
||||
gemm1_alpha,
|
||||
gemm1_beta,
|
||||
gemm1_clamp_limit,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
max_capture_size,
|
||||
):
|
||||
super().__init__(moe.quant_config)
|
||||
self.moe = moe
|
||||
self.gemm1_alpha = gemm1_alpha
|
||||
self.gemm1_beta = gemm1_beta
|
||||
self.gemm1_clamp_limit = gemm1_clamp_limit
|
||||
self.w13_bias = w13_bias
|
||||
self.w2_bias = w2_bias
|
||||
self.max_capture_size = max_capture_size
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
return TopKWeightAndReduceNoOP()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
# The workspaces for this implementation are managed by flashinfer.
|
||||
# TODO(varun) : workspace1 is could be used as the output tensor. This
|
||||
# is error-prone. Allow the `workspace_shapes` to return None workspaces
|
||||
workspace1 = (M, K)
|
||||
workspace2 = (0, 0)
|
||||
output = (M, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
|
||||
local_num_experts: int):
|
||||
# Number of tokens in the input tensor.
|
||||
num_tokens = x.shape[0]
|
||||
# Factor to account for the imbalance of the experts.
|
||||
# factor equals to the
|
||||
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
|
||||
# 1.0 means perfect expert distribution.
|
||||
# > 1.0 means some experts have more tokens than the perfect
|
||||
# distribution.
|
||||
# < 1.0 does not make sense.
|
||||
imbalance_factor = 1.3
|
||||
# Calculate the number of tokens per expert assuming perfect
|
||||
# distribution.
|
||||
num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
|
||||
# Apply the imbalance factor.
|
||||
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
|
||||
# And pad the number to the next power of 2.
|
||||
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
|
||||
# Cap to 8-64 tokens per CTA tile as it's the range supported by the
|
||||
# kernel.
|
||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
topk = topk_ids.size(-1)
|
||||
local_num_experts = w1.size(0)
|
||||
intermediate_size = w2.size(1)
|
||||
local_expert_offset = self.moe.ep_rank * local_num_experts
|
||||
|
||||
x_quant = hidden_states
|
||||
x_scale = a1q_scale
|
||||
if x_scale is not None:
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
|
||||
*x_quant.shape[:-1], -1)
|
||||
|
||||
packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
|
||||
torch.bfloat16).view(torch.int16)
|
||||
|
||||
assert w1_scale is not None
|
||||
assert w2_scale is not None
|
||||
kwargs = {
|
||||
"topk_ids":
|
||||
packed_tensor,
|
||||
"routing_bias":
|
||||
None,
|
||||
"hidden_states":
|
||||
x_quant,
|
||||
"hidden_states_scale":
|
||||
x_scale,
|
||||
"gemm1_weights":
|
||||
w1,
|
||||
"gemm1_weights_scale":
|
||||
w1_scale,
|
||||
"gemm1_bias":
|
||||
self.w13_bias,
|
||||
"gemm1_alpha":
|
||||
self.gemm1_alpha,
|
||||
"gemm1_beta":
|
||||
self.gemm1_beta,
|
||||
"gemm1_clamp_limit":
|
||||
self.gemm1_clamp_limit,
|
||||
"gemm2_weights":
|
||||
w2,
|
||||
"gemm2_weights_scale":
|
||||
w2_scale,
|
||||
"gemm2_bias":
|
||||
self.w2_bias,
|
||||
"output1_scale_scalar":
|
||||
None,
|
||||
"output1_scale_gate_scalar":
|
||||
None,
|
||||
"output2_scale_scalar":
|
||||
None,
|
||||
"num_experts":
|
||||
global_num_experts,
|
||||
"top_k":
|
||||
topk,
|
||||
"n_group":
|
||||
None,
|
||||
"topk_group":
|
||||
None,
|
||||
"intermediate_size":
|
||||
intermediate_size,
|
||||
"local_expert_offset":
|
||||
local_expert_offset,
|
||||
"local_num_experts":
|
||||
local_num_experts,
|
||||
"routed_scaling_factor":
|
||||
None,
|
||||
"tile_tokens_dim":
|
||||
self._get_tile_tokens_dim(x_quant, topk, local_num_experts),
|
||||
"routing_method_type":
|
||||
1,
|
||||
"do_finalize":
|
||||
True,
|
||||
"output":
|
||||
output,
|
||||
"tune_max_num_tokens":
|
||||
self.max_capture_size,
|
||||
}
|
||||
|
||||
from flashinfer import trtllm_fp4_block_scale_routed_moe
|
||||
trtllm_fp4_block_scale_routed_moe(**kwargs)
|
||||
return output
|
@ -12,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
per_token_group_quant_int8, per_token_quant_int8)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
quant_dequant_mxfp4)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
mxfp8_quantize)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv
|
||||
@ -177,6 +179,18 @@ def _mxfp4_quantize(
|
||||
return A, None
|
||||
|
||||
|
||||
def _mxfp8_quantize(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert A_scale is None
|
||||
assert not per_act_token_quant
|
||||
assert block_shape is None
|
||||
return mxfp8_quantize(A)
|
||||
|
||||
|
||||
def moe_kernel_quantize_input(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
@ -195,6 +209,8 @@ def moe_kernel_quantize_input(
|
||||
is_sf_swizzled_layout=is_fp4_scale_swizzled)
|
||||
elif quant_dtype == "mxfp4":
|
||||
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == "mxfp8":
|
||||
return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
else:
|
||||
return A, A_scale
|
||||
|
||||
|
@ -322,6 +322,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
"""Return the appropriate GEMM experts implementation."""
|
||||
experts = select_nvfp4_gemm_impl(
|
||||
@ -719,10 +720,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
dtype=torch.int64)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
self, prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# cutlass path
|
||||
if self.use_cutlass:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
|
@ -897,6 +897,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
BatchedTritonOrDeepGemmExperts, TritonOrDeepGemmExperts)
|
||||
|
@ -311,6 +311,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
experts = select_cutlass_fp8_gemm_impl(
|
||||
moe,
|
||||
@ -1032,6 +1033,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
experts = select_nvfp4_gemm_impl(
|
||||
moe,
|
||||
|
@ -10,6 +10,8 @@ from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.trtllm_moe import TrtLlmGenExperts
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -445,6 +447,91 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: mk.FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
layer: torch.nn.Module,
|
||||
) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
if (prepare_finalize.activation_format ==
|
||||
mk.FusedMoEActivationFormat.BatchedExperts):
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 does not support batched experts format for EP")
|
||||
else:
|
||||
if should_use_flashinfer_mxfp4():
|
||||
# B200 code-path
|
||||
kwargs = {
|
||||
"gemm1_alpha": layer.gemm1_alpha,
|
||||
"gemm1_beta": layer.gemm1_beta,
|
||||
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
|
||||
"w13_bias": layer.w13_bias,
|
||||
"w2_bias": layer.w2_bias,
|
||||
"max_capture_size": self.max_capture_size,
|
||||
}
|
||||
return TrtLlmGenExperts(moe, **kwargs)
|
||||
else:
|
||||
# Use matmul_ogs from triton_kernels here!
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 does not support non-batched experts format for EP")
|
||||
|
||||
def _route_and_experts(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert isinstance(self.fused_experts, mk.FusedMoEModularKernel)
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=self.topk_indices_dtype,
|
||||
enable_eplb=enable_eplb,
|
||||
expert_map=expert_map,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count)
|
||||
|
||||
return self.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@ -503,6 +590,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
activation=activation,
|
||||
expert_map=expert_map)
|
||||
|
||||
if self.fused_experts is not None:
|
||||
return self._route_and_experts(
|
||||
layer,
|
||||
x,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
use_grouped_topk,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
custom_routing_function,
|
||||
scoring_func,
|
||||
e_score_correction_bias,
|
||||
apply_router_weight_on_input,
|
||||
activation,
|
||||
enable_eplb,
|
||||
expert_load_view,
|
||||
logical_to_physical_map,
|
||||
logical_replica_count,
|
||||
)
|
||||
|
||||
assert _can_support_mxfp4(
|
||||
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
||||
custom_routing_function, e_score_correction_bias,
|
||||
|
@ -66,11 +66,10 @@ def _can_support_mxfp4(use_grouped_topk: bool = False,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None):
|
||||
return not (use_grouped_topk or topk_group or num_expert_group
|
||||
or expert_map or custom_routing_function
|
||||
or e_score_correction_bias or apply_router_weight_on_input
|
||||
or scoring_func != "softmax" or activation != "swigluoai"
|
||||
or expert_load_view or logical_to_physical_map
|
||||
or logical_replica_count)
|
||||
or custom_routing_function or e_score_correction_bias
|
||||
or apply_router_weight_on_input or scoring_func != "softmax"
|
||||
or activation != "swigluoai" or expert_load_view
|
||||
or logical_to_physical_map or logical_replica_count)
|
||||
|
||||
|
||||
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
|
||||
|
20
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
Normal file
20
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py
Normal file
@ -0,0 +1,20 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
try:
|
||||
from flashinfer import mxfp8_quantize
|
||||
except ImportError as err:
|
||||
raise ImportError("The package `flashinfer` is required to do "
|
||||
"MX-FP8 quantization. Please install it with" \
|
||||
"`pip install flashinfer`") from err
|
||||
|
||||
return mxfp8_quantize(x, is_sf_swizzled_layout=False)
|
Reference in New Issue
Block a user