[Bugfix] Fix accuracy issue of TRTLLM FP8 MOE and improve logging (#25895)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Author: Pavani Majety
Date: 2025-09-30 07:51:31 -07:00 (committed by GitHub)
Parent: f4db5e6de1
Commit: ef283548f7

This change restricts FlashInfer FP8 MOE selection to Blackwell (SM100) GPUs, makes the DeepGemm and CUTLASS block-scaled grouped-GEMM paths yield when FlashInfer is active, and casts e_score_correction_bias to the activation dtype before invoking the TRTLLM block-scale kernel, returning its result directly.

2 changed files with 29 additions and 17 deletions


@@ -434,14 +434,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        self.weight_block_size = self.quant_config.weight_block_size
        self.block_quant = self.weight_block_size is not None
        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        self.fused_experts: Optional[
            mk.FusedMoEModularKernel] = None  # type: ignore
        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
            )
        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)
@@ -450,14 +445,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        if current_platform.is_rocm():
            self.use_marlin = False
        # First check for Flashinfer MOE on Blackwell GPUs
        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
        if (current_platform.is_cuda()
                and current_platform.is_device_capability(100)
                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
            logger.info_once(
                f"Detected Blackwell GPUs, using FlashInfer "
                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
        # Check for DeepGemm support.
        self.allow_deep_gemm = False
        if envs.VLLM_USE_DEEP_GEMM:
            if not has_deep_gemm():
                logger.warning_once("Failed to import DeepGemm kernels.")
            elif not self.block_quant:
                logger.warning_once("Model is not block quantized. Not using "
                                    "DeepGemm kernels")
                logger.warning_once("Model is not block quantized. Not using"
                                    " DeepGemm kernels")
            elif self.flashinfer_moe_backend:
                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
                                 " enabled.")
            elif (is_deep_gemm_supported()):
                logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                self.allow_deep_gemm = True
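The net effect of the reordered checks above is a strict precedence: FlashInfer claims FP8 MOE first, and only on Blackwell, with DeepGemm considered afterwards. A minimal sketch of that ladder, using hypothetical booleans in place of the real env flags and platform probes:

from typing import Optional


def pick_fp8_moe_backend(is_sm100: bool, use_flashinfer: bool,
                         use_deep_gemm: bool, deep_gemm_available: bool,
                         block_quant: bool) -> Optional[str]:
    """Hypothetical distillation of the selection order in __init__."""
    # FlashInfer is checked first and only engages on Blackwell (SM100).
    if is_sm100 and use_flashinfer:
        return "flashinfer"
    # DeepGemm needs its env flag, the kernels, and block quantization;
    # the early return above keeps it mutually exclusive with FlashInfer.
    if use_deep_gemm and deep_gemm_available and block_quant:
        return "deep_gemm"
    return None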
@@ -471,15 +479,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            logger.debug_once("Model is not block quantized. Not using "
                              "CutlassBlockScaledGroupedGemm kernels")
        elif (current_platform.is_cuda()
              and current_platform.is_device_capability(100)):
              and current_platform.is_device_capability(100)
              and not self.flashinfer_moe_backend):
            logger.info_once(
                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
            )
                "Using CutlassBlockScaledGroupedGemm kernels for Fp8 MOE "
                "on SM100.")
            self.allow_cutlass_block_scaled_grouped_gemm = True
        else:
            logger.warning_once(
                "CutlassBlockScaledGroupedGemm not supported on the current "
                "platform.")

    def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                       intermediate_size_per_partition: int,
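The CUTLASS path gets the same exclusion: on SM100 it is only enabled when FlashInfer has not already claimed the layer. A condensation of the new condition, with hypothetical booleans standing in for the real checks:

def allow_cutlass_grouped_gemm(is_sm100: bool, block_quant: bool,
                               flashinfer_active: bool) -> bool:
    # CutlassBlockScaledGroupedGemm now requires SM100, block
    # quantization, and no FlashInfer backend already selected.
    return is_sm100 and block_quant and not flashinfer_active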
@@ -934,7 +939,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
            assert (renormalize and use_grouped_topk
                    and custom_routing_function is None)
            result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
            e_score_correction_bias = (e_score_correction_bias.to(
                x.dtype) if e_score_correction_bias is not None else None)
            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                routing_logits=router_logits.to(torch.float32),
                routing_bias=e_score_correction_bias,
                x=x,
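The accuracy fix itself is in this hunk: the routing bias is cast to the activation dtype before the TRTLLM block-scale kernel is called, and the kernel's output is now returned instead of being assigned to an unused local. A small standalone illustration of the cast pattern (tensor shapes here are hypothetical):

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)   # activations
bias = torch.randn(8, dtype=torch.float32)    # routing bias

# Mirror of the fix: align the bias dtype with the activations,
# while tolerating a missing bias.
bias = bias.to(x.dtype) if bias is not None else None
assert bias is None or bias.dtype == x.dtype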


@@ -27,7 +27,8 @@ def is_deep_gemm_supported() -> bool:
    is_supported_arch = current_platform.is_cuda() and (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability(100))
    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
    return (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
            and not envs.VLLM_USE_FLASHINFER_MOE_FP8)

@functools.cache
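is_deep_gemm_supported() now also reports False whenever the FlashInfer FP8 MOE flag is set, so the two backends can never both be advertised as usable. The updated predicate restated with plain booleans (names are hypothetical):

def deep_gemm_supported(use_deep_gemm: bool, deep_gemm_available: bool,
                        is_sm90_or_sm100: bool,
                        use_flashinfer_moe_fp8: bool) -> bool:
    # DeepGemm support is mutually exclusive with FlashInfer FP8 MOE.
    return (use_deep_gemm and deep_gemm_available and is_sm90_or_sm100
            and not use_flashinfer_moe_fp8)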
@@ -46,6 +47,10 @@ def is_deep_gemm_e8m0_used() -> bool:
        logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
        return False
    if envs.VLLM_USE_FLASHINFER_MOE_FP8:
        logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
        return False
    if current_platform.is_device_capability(100) and \
            envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")