Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00

Comparing copilot/fi... with v0.10.1 (3 commits)
Commits:
  aab549870d
  ba6928cf13
  befedf86a8
@@ -20,7 +20,15 @@ from openai.types.chat.chat_completion_message import (
 from openai.types.responses import (ResponseFunctionToolCall,
                                     ResponseInputItemParam, ResponseOutputItem,
                                     ResponsePrompt, ResponseReasoningItem,
-                                    ResponseStatus, ResponseTextConfig)
+                                    ResponseStatus)
+
+# Backward compatibility for OpenAI client versions
+try:  # For older openai versions (< 1.100.0)
+    from openai.types.responses import ResponseTextConfig
+except ImportError:  # For newer openai versions (>= 1.100.0)
+    from openai.types.responses import (ResponseFormatTextConfig as
+                                        ResponseTextConfig)
+
 from openai.types.responses.response import ToolChoice
 from openai.types.responses.tool import Tool
 from openai.types.shared import Metadata, Reasoning
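The try/except alias keeps every downstream reference spelled ResponseTextConfig no matter which openai client version is installed. A minimal sketch of the same pattern for any renamed symbol (the library and class names below are hypothetical, not from this diff):

# Hypothetical: `Config` was renamed to `TextConfig` in somelib 2.0.
try:
    # Older releases (< 2.0) still export the original name.
    from somelib import Config
except ImportError:
    # Newer releases (>= 2.0) only export the new name; alias it back
    # so the rest of the codebase needs no version checks.
    from somelib import TextConfig as Config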
@@ -762,11 +762,11 @@ class FusedMoE(CustomOp):
         self.global_num_experts = num_experts + num_redundant_experts

         # we padding globally so EP buffer allocation works
-        if (quant_config and quant_config.get_name() == "mxfp4"
-                and (current_platform.is_rocm()
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
-            hidden_size = round_up(hidden_size, 256)
+        if quant_config and quant_config.get_name() == "mxfp4":
+            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
+                should_use_flashinfer_mxfp4)
+            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+                hidden_size = round_up(hidden_size, 256)

         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
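The padding keeps hidden_size aligned to 256 so expert-parallel buffer allocation is uniform across ranks. A sketch of the arithmetic, assuming vllm.utils.round_up is the usual ceiling-to-multiple helper:

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(2880, 256) == 3072   # a 2880-wide hidden size pads up
assert round_up(4096, 256) == 4096   # already aligned, unchanged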
@@ -6,6 +6,7 @@ import torch
 from torch.nn.parameter import Parameter

 from vllm import envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
@@ -26,12 +27,38 @@ from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
                         next_power_of_2, round_up)
 from vllm.utils.flashinfer import has_flashinfer

-if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
-    # from flashinfer.fused_moe import cutlass_fused_moe
-    from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
-                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
 logger = init_logger(__name__)

+
+def _should_use_flashinfer_mxfp4_bf16():
+    """Determine if FlashInfer MXFP4 BF16 should be used."""
+    # If explicitly set, respect the setting
+    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
+        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+
+    # Enable by default on SM100 if MXFP8 is not explicitly enabled
+    if (current_platform.is_device_capability(100) and has_flashinfer()
+            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
+        logger.info_once(
+            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
+            "For faster performance, consider setting "
+            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
+            "though this may impact accuracy.")
+        return True
+
+    return False
+
+
+def _should_use_flashinfer_mxfp4_mxfp8():
+    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
+    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+
+
+def should_use_flashinfer_mxfp4():
+    return (_should_use_flashinfer_mxfp4_mxfp8()
+            or _should_use_flashinfer_mxfp4_bf16())
+
+
 class Mxfp4Config(QuantizationConfig):
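Taken together, the helpers give explicitly set environment variables priority and otherwise default to the BF16 backend on Blackwell when FlashInfer is present. A self-contained sketch of that precedence, modeling each env var as Optional[bool] (None meaning unset); resolve_backend and its return labels are illustrative, not vLLM API:

from typing import Optional

def resolve_backend(mxfp8_env: Optional[bool],
                    bf16_env: Optional[bool],
                    sm100_with_flashinfer: bool) -> str:
    # Mirrors _should_use_flashinfer_mxfp4_mxfp8(): truthy env only.
    use_mxfp8 = bool(mxfp8_env)
    # Mirrors _should_use_flashinfer_mxfp4_bf16(): an explicit setting wins;
    # otherwise default on for Blackwell when MXFP8 was not explicitly set.
    if bf16_env is not None:
        use_bf16 = bf16_env
    else:
        use_bf16 = sm100_with_flashinfer and mxfp8_env is None
    if use_mxfp8:
        return "flashinfer-mxfp8"
    if use_bf16:
        return "flashinfer-bf16"
    return "no-flashinfer-mxfp4"

assert resolve_backend(None, None, True) == "flashinfer-bf16"     # Blackwell default
assert resolve_backend(True, None, True) == "flashinfer-mxfp8"    # explicit MXFP8
assert resolve_backend(None, False, True) == "no-flashinfer-mxfp4"  # explicit opt-out
assert resolve_backend(None, None, False) == "no-flashinfer-mxfp4"  # non-SM100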
@@ -87,12 +114,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.moe = moe
         self.use_marlin = self._should_use_marlin()

+        if current_platform.is_device_capability(100) and not has_flashinfer():
+            logger.warning_once(
+                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
+                "is not available. This may result in degraded performance. "
+                "Please `pip install vllm[flashinfer]` for best results.")
+
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
             return envs.VLLM_MXFP4_USE_MARLIN
         if current_platform.is_cuda() and \
-                not current_platform.has_device_capability(100):
-            if not current_platform.is_device_capability(90):
+                not current_platform.is_device_capability(100):
+            if not current_platform.has_device_capability(90):
                 # marlin kernel has better performance on ampere
                 return True
         if not has_triton_kernels():
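The swap matters because the two probes have different semantics: is_device_capability is an exact match, while has_device_capability is an at-least check. A toy illustration of the difference, assuming capabilities encoded as major * 10 + minor (so SM80 is 80):

def is_device_capability(actual: int, query: int) -> bool:
    return actual == query    # exact match, e.g. "is this exactly SM100?"

def has_device_capability(actual: int, query: int) -> bool:
    return actual >= query    # at-least, e.g. "is this SM90 or newer?"

# With the corrected condition an SM80 (Ampere) GPU reaches the Marlin path:
actual = 80
assert not is_device_capability(actual, 100)    # not exactly Blackwell
assert not has_device_capability(actual, 90)    # below Hopper -> use Marlin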
@@ -138,8 +171,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # for to hold non-uniform sharded tensor as well as swizzling
             # other padding to increase performance
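mxfp4_block here refers to the 32-element scaling block of the MXFP4 format (per the OCP MX spec), so the alignment target is 2 * 32 = 64. A sketch of that padding under this assumption; pad_intermediate is an illustrative helper, not vLLM code:

mxfp4_block = 32

def pad_intermediate(size_per_partition: int) -> int:
    multiple = 2 * mxfp4_block
    return ((size_per_partition + multiple - 1) // multiple) * multiple

assert pad_intermediate(2880) == 2880   # already a multiple of 64
assert pad_intermediate(2850) == 2880   # non-uniform shard rounded up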
@@ -230,8 +262,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
+            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                           requires_grad=False)
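The 1.702 loaded into gemm1_alpha is the standard sigmoid-based GELU approximation constant (gelu(x) ≈ x * sigmoid(1.702 * x)). A sketch of the gated activation it parameterizes, assuming the gpt-oss-style SwiGLU variant (clamping omitted); this formulation is illustrative and not read from the diff itself:

import torch

def swiglu(gate: torch.Tensor, up: torch.Tensor,
           alpha: float = 1.702) -> torch.Tensor:
    # Sigmoid-weighted gate branch multiplied with the linear branch.
    return gate * torch.sigmoid(alpha * gate) * (up + 1)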
@@ -478,11 +510,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 logical_replica_count), (
                     "MXFP4 are not supported with this configuration.")

-        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        if should_use_flashinfer_mxfp4():
+            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
             assert not self.moe.use_ep, (
                 "EP is not supported for flashinfer mxfp4 moe backend yet.")
-            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
+            if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
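On the BF16 path the activations pass through unquantized (x_quant = x, no scales), while the MXFP8 path quantizes them first via FlashInfer's mxfp8_quantize. A runnable toy version of that dispatch; quantize_mxfp8 below is a stand-in stub, not FlashInfer's real API:

import torch

def quantize_mxfp8(x: torch.Tensor):
    # Stub: the real kernel returns an fp8 tensor plus per-block scales.
    return x, torch.ones(x.numel() // 32)

x = torch.randn(4, 64, dtype=torch.bfloat16)
use_bf16_path = True  # stands in for _should_use_flashinfer_mxfp4_bf16()

if use_bf16_path:
    x_quant, x_scale = x, None            # BF16 backend consumes x directly
else:
    x_quant, x_scale = quantize_mxfp8(x)  # MXFP8 backend quantizes first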