Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Bugfix] Fix accuracy issue of TRTLLM FP8 MOE and improve logging (#25895)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
@@ -434,14 +434,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant = self.weight_block_size is not None
 
-        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore
-        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
-            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
-            logger.info_once(
-                f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels"
-            )
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = (not current_platform.has_device_capability(89)
@@ -450,14 +445,27 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         if current_platform.is_rocm():
             self.use_marlin = False
 
+        # First check for Flashinfer MOE on Blackwell GPUs
+        self.flashinfer_moe_backend: Optional[FlashinferMoeBackend] = None
+        if (current_platform.is_cuda()
+                and current_platform.is_device_capability(100)
+                and envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe()):
+            self.flashinfer_moe_backend = get_flashinfer_moe_backend()
+            logger.info_once(
+                f"Detected Blackwell GPUs, using FlashInfer "
+                f"{self.flashinfer_moe_backend.value} kernels for FP8 MOE.")
+
         # Check for DeepGemm support.
         self.allow_deep_gemm = False
         if envs.VLLM_USE_DEEP_GEMM:
             if not has_deep_gemm():
                 logger.warning_once("Failed to import DeepGemm kernels.")
             elif not self.block_quant:
-                logger.warning_once("Model is not block quantized. Not using "
-                                    "DeepGemm kernels")
+                logger.warning_once("Model is not block quantized. Not using"
+                                    " DeepGemm kernels")
+            elif self.flashinfer_moe_backend:
+                logger.info_once("DeepGemm disabled: FlashInfer MOE is"
+                                 " enabled.")
             elif (is_deep_gemm_supported()):
                 logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
                 self.allow_deep_gemm = True
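Taken together, the two hunks above move FlashInfer backend detection behind a Blackwell (SM100) capability check and make DeepGemm yield to it. The sketch below is a minimal, illustrative restatement of that precedence only; the helper name and flat boolean parameters are hypothetical and not vLLM APIs.

def pick_fp8_moe_backend(is_cuda: bool, capability: int,
                         use_flashinfer_env: bool, flashinfer_available: bool,
                         use_deep_gemm_env: bool, deep_gemm_available: bool,
                         block_quant: bool) -> str:
    # FlashInfer FP8 MoE is now considered only on Blackwell (SM100) GPUs.
    if (is_cuda and capability == 100 and use_flashinfer_env
            and flashinfer_available):
        return "flashinfer"
    # DeepGemm stays a fallback: it needs block quantization and is skipped
    # once FlashInfer has been selected.
    if use_deep_gemm_env and deep_gemm_available and block_quant:
        return "deep_gemm"
    return "default"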
@@ -471,15 +479,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             logger.debug_once("Model is not block quantized. Not using "
                               "CutlassBlockScaledGroupedGemm kernels")
         elif (current_platform.is_cuda()
-              and current_platform.is_device_capability(100)):
+              and current_platform.is_device_capability(100)
+              and not self.flashinfer_moe_backend):
             logger.info_once(
-                "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod."
-            )
+                "Using CutlassBlockScaledGroupedGemm kernels for Fp8 MOE "
+                "on SM100.")
             self.allow_cutlass_block_scaled_grouped_gemm = True
         else:
             logger.warning_once(
                 "CutlassBlockScaledGroupedGemm not supported on the current "
                 "platform.")
 
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
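The added "not self.flashinfer_moe_backend" condition keeps the CUTLASS block-scaled grouped GEMM path from being enabled alongside FlashInfer. A hypothetical post-construction sanity check (not vLLM code) that states the invariant the new guards enforce, using only attribute names that appear in the diff:

def flashinfer_excludes_other_fp8_paths(method) -> bool:
    # When no FlashInfer backend was chosen, there is nothing to check.
    if getattr(method, "flashinfer_moe_backend", None) is None:
        return True
    # With FlashInfer active, neither DeepGemm nor the CUTLASS block-scaled
    # grouped GEMM path should have been enabled.
    return not (getattr(method, "allow_deep_gemm", False)
                or getattr(method, "allow_cutlass_block_scaled_grouped_gemm",
                           False))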
@@ -934,7 +939,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe  # noqa: E501, F401
             assert (renormalize and use_grouped_topk
                     and custom_routing_function is None)
-            result = torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+            e_score_correction_bias = (e_score_correction_bias.to(
+                x.dtype) if e_score_correction_bias is not None else None)
+            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
                 routing_logits=router_logits.to(torch.float32),
                 routing_bias=e_score_correction_bias,
                 x=x,
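This hunk is the accuracy fix named in the commit title: the per-expert routing bias (e_score_correction_bias) is cast to the activation dtype before it reaches the TRTLLM block-scale FP8 MoE kernel, and the kernel call's value is now returned directly (the old code bound it to result). A minimal standalone sketch of just the dtype alignment, with made-up tensor shapes and no vLLM imports:

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)  # activations fed to the kernel
bias = torch.randn(16, dtype=torch.float32)  # routing bias, often kept in fp32
# The cast added in this PR: align the bias dtype with the activations.
bias = bias.to(x.dtype) if bias is not None else None
assert bias.dtype == x.dtype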
@@ -27,7 +27,8 @@ def is_deep_gemm_supported() -> bool:
     is_supported_arch = current_platform.is_cuda() and (
         current_platform.is_device_capability(90)
         or current_platform.is_device_capability(100))
-    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
+    return (envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch
+            and not envs.VLLM_USE_FLASHINFER_MOE_FP8)
 
 
 @functools.cache
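Since is_deep_gemm_supported() now also requires the FlashInfer FP8 MoE flag to be off, the interplay of the two environment flags can be checked from a plain Python session. The module path below is assumed from the function names in this hunk; adjust it to your checkout if it differs.

import os

# Request FlashInfer FP8 MoE; DeepGemm should then report itself unsupported
# even though VLLM_USE_DEEP_GEMM is set.
os.environ["VLLM_USE_FLASHINFER_MOE_FP8"] = "1"
os.environ["VLLM_USE_DEEP_GEMM"] = "1"

from vllm.utils.deep_gemm import is_deep_gemm_supported  # path assumed

print(is_deep_gemm_supported())  # expected: False after this change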
@@ -46,6 +47,10 @@ def is_deep_gemm_e8m0_used() -> bool:
         logger.info_once("DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found")
         return False
 
+    if envs.VLLM_USE_FLASHINFER_MOE_FP8:
+        logger.info_once("DeepGEMM E8M0 disabled: FlashInfer MOE is enabled.")
+        return False
+
     if current_platform.is_device_capability(100) and \
             envs.VLLM_USE_DEEP_GEMM_E8M0:
         logger.info_once("DeepGEMM E8M0 enabled on Blackwell GPU.")
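For readers unfamiliar with the term in this hunk: E8M0 refers to power-of-two scale factors (an 8-bit exponent, no mantissa) that the DeepGEMM path can use on Blackwell. As a rough, illustrative sketch only (not vLLM's or DeepGEMM's implementation), rounding a positive FP32 scale up to such a power of two looks like:

import math

def to_power_of_two_scale(scale: float) -> float:
    # Keep only an exponent: round the (positive) scale up to the next power
    # of two so values quantized with it do not overflow the intended range.
    return 2.0 ** math.ceil(math.log2(scale))

print(to_power_of_two_scale(0.37))  # 0.5
print(to_power_of_two_scale(3.0))   # 4.0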