Update flashinfer CUTLASS MoE Kernel (#21408)

Signed-off-by: Shu Wang. <shuw@nvidia.com>
Author: Shu Wang <shuw@nvidia.com>
Date: 2025-07-24 10:13:31 -05:00
Committed by: GitHub
Parent: e8cb0d0495
Commit: 1b25f1fe75

3 changed files with 8 additions and 8 deletions

@@ -11,7 +11,7 @@ from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (
     extract_required_args, moe_kernel_quantize_input)
-from vllm.utils.flashinfer import block_scale_interleave
+from vllm.utils.flashinfer import nvfp4_block_scale_interleave

 def get_local_sizes(local_tokens):
@@ -92,7 +92,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             dim=0,
             sizes=get_local_sizes(local_tokens))
         a1_m, a1_n = a1q.shape
-        a1q_scale = block_scale_interleave(a1q_scale)
+        a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

         return a1q, a1q_scale, None, topk_ids, topk_weights
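
Only the symbol name changes at this call site: FlashInfer renamed block_scale_interleave to nvfp4_block_scale_interleave, and the helper still swizzles NVFP4 block scales into the layout the CUTLASS kernel reads. A minimal sketch of the call pattern, assuming a uint8 scale tensor; the exact shape/dtype contract belongs to FlashInfer and is not shown in this diff:

import torch
from vllm.utils.flashinfer import nvfp4_block_scale_interleave

# Illustrative block-scale tensor; real shapes/dtypes come from fp4 quantization.
a1q_scale = torch.randint(0, 256, (128, 64), dtype=torch.uint8, device="cuda")
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)  # swizzled into kernel layout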

@@ -1254,8 +1254,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             x, layer.w13_weight, layer.w2_weight), (
                 "Flashinfer CUTLASS Fused MoE not applicable!")
-        a1_gscale = torch.min(layer.w13_input_scale_quant)
-        a2_gscale = torch.min(layer.w2_input_scale_quant)
+        a1_gscale = layer.w13_input_scale_quant
+        a2_gscale = layer.w2_input_scale_quant
         extra_expert_args = {
             'g1_alphas': layer.g1_alphas,
             'g2_alphas': layer.g2_alphas,
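
This hunk changes semantics, not just a name: instead of collapsing the per-expert input scales to a single scalar with torch.min, the full tensors are forwarded so the updated kernel can apply a distinct scale per expert. A toy before/after illustration (tensor values are made up):

import torch

# Hypothetical per-expert activation input scales for three experts.
w13_input_scale_quant = torch.tensor([0.9, 1.1, 1.3])

a1_gscale_before = torch.min(w13_input_scale_quant)  # one scalar shared by all experts
a1_gscale_after = w13_input_scale_quant              # per-expert scales, passed through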

@@ -69,8 +69,8 @@ flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
 flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                     "cutlass_fused_moe")
 fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")
-block_scale_interleave = _lazy_import_wrapper("flashinfer",
-                                              "block_scale_interleave")
+nvfp4_block_scale_interleave = _lazy_import_wrapper(
+    "flashinfer", "nvfp4_block_scale_interleave")

 # Special case for autotune since it returns a context manager
 autotune = _lazy_import_wrapper(
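
_lazy_import_wrapper defers resolving FlashInfer symbols until first call, so importing vllm.utils.flashinfer succeeds even when FlashInfer is not installed. Its implementation is not part of this diff; a plausible minimal sketch of the lazy-import pattern it implements:

import importlib
from typing import Any, Callable

def _lazy_import_wrapper(module_name: str, attr_name: str) -> Callable[..., Any]:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # The import happens only on first use, not at module import time.
        module = importlib.import_module(module_name)
        return getattr(module, attr_name)(*args, **kwargs)
    return wrapper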
@@ -95,7 +95,7 @@ def has_flashinfer_cutlass_fused_moe() -> bool:
     required_functions = [
         ("flashinfer.fused_moe", "cutlass_fused_moe"),
         ("flashinfer", "fp4_quantize"),
-        ("flashinfer", "block_scale_interleave"),
+        ("flashinfer", "nvfp4_block_scale_interleave"),
     ]

     for module_name, attr_name in required_functions:
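
The hunk is truncated after the for line; the loop body below is an assumed reconstruction of the availability probe (return False on the first unresolvable symbol), not necessarily vLLM's exact code:

import importlib

def has_flashinfer_cutlass_fused_moe() -> bool:
    required_functions = [
        ("flashinfer.fused_moe", "cutlass_fused_moe"),
        ("flashinfer", "fp4_quantize"),
        ("flashinfer", "nvfp4_block_scale_interleave"),
    ]
    for module_name, attr_name in required_functions:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # FlashInfer (or the submodule) is not installed.
            return False
        if getattr(module, attr_name, None) is None:
            return False
    return True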
@@ -110,7 +110,7 @@ __all__ = [
     "flashinfer_trtllm_fp8_block_scale_moe",
     "flashinfer_cutlass_fused_moe",
     "fp4_quantize",
-    "block_scale_interleave",
+    "nvfp4_block_scale_interleave",
     "autotune",
     "has_flashinfer_moe",
     "has_flashinfer_cutlass_fused_moe",