mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[gpt-oss] flashinfer mxfp4 (#22339)
Signed-off-by: simon-mo <xmo@berkeley.edu> Signed-off-by: Yongye Zhu <zyy1102000@gmail.com> Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
12
vllm/envs.py
12
vllm/envs.py
@ -154,6 +154,8 @@ if TYPE_CHECKING:
|
||||
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
|
||||
VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
|
||||
VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
|
||||
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -932,6 +934,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_USE_FLASHINFER_MOE_FP4":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))),
|
||||
|
||||
# If set to 1, use the FlashInfer
|
||||
# MXFP8 (activation) x MXFP4 (weight) MoE backend.
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8", "0"))),
|
||||
|
||||
# If set to 1, use the FlashInfer
|
||||
# BF16 (activation) x MXFP4 (weight) MoE backend.
|
||||
"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16":
|
||||
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))),
|
||||
|
||||
# Control the cache sized used by the xgrammar compiler. The default
|
||||
# of 512 MB should be enough for roughly 1000 JSON schemas.
|
||||
# It can be changed with this variable if needed for some reason.
|
||||
|
@ -33,7 +33,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
|
||||
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
|
||||
round_up)
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
@ -719,6 +720,12 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
self.global_num_experts = num_experts + num_redundant_experts
|
||||
|
||||
# we padding globally so EP buffer allocation works
|
||||
if quant_config and quant_config.get_name() == "mxfp4" and (
|
||||
envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
# For smuggling this layer into the fused moe custom op
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
@ -1064,6 +1071,18 @@ class FusedMoE(torch.nn.Module):
|
||||
shard_id: str,
|
||||
expert_id: int,
|
||||
return_success: bool = False) -> Optional[bool]:
|
||||
|
||||
if self.quant_config and self.quant_config.get_name() == "mxfp4":
|
||||
# (FIXME) for gpt-oss all experts are combined
|
||||
if "bias" in weight_name:
|
||||
dim1 = loaded_weight.shape[1]
|
||||
param.data[:, :dim1].copy_(loaded_weight)
|
||||
else:
|
||||
dim1 = loaded_weight.shape[1]
|
||||
dim2 = loaded_weight.shape[2]
|
||||
param.data[:, :dim1, :dim2].copy_(loaded_weight)
|
||||
return True if return_success else None
|
||||
|
||||
expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
|
||||
if expert_id == -1:
|
||||
# Failed to load this param since it's not local to this rank
|
||||
@ -1476,13 +1495,20 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
og_hidden_states = hidden_states.shape[-1]
|
||||
if self.hidden_size != og_hidden_states:
|
||||
hidden_states = F.pad(hidden_states,
|
||||
(0, self.hidden_size - og_hidden_states),
|
||||
mode='constant',
|
||||
value=0.0)
|
||||
# TODO: Once the OOM issue for the TPU backend is resolved, we will
|
||||
# switch to using the moe_forward custom op.
|
||||
if current_platform.is_tpu():
|
||||
return self.forward_impl(hidden_states, router_logits)
|
||||
else:
|
||||
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
|
||||
self.layer_name)
|
||||
return torch.ops.vllm.moe_forward(
|
||||
hidden_states, router_logits,
|
||||
self.layer_name)[..., :og_hidden_states]
|
||||
|
||||
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
|
||||
full_router_logits: torch.Tensor):
|
||||
|
@ -37,6 +37,7 @@ QuantizationMethods = Literal[
|
||||
"auto-round",
|
||||
"rtn",
|
||||
"inc",
|
||||
"mxfp4",
|
||||
]
|
||||
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||
|
||||
@ -110,6 +111,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .marlin import MarlinConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
from .mxfp4 import Mxfp4Config
|
||||
from .neuron_quant import NeuronQuantConfig
|
||||
from .ptpc_fp8 import PTPCFp8Config
|
||||
from .qqq import QQQConfig
|
||||
@ -148,6 +150,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"auto-round": AutoRoundConfig,
|
||||
"rtn": RTNConfig,
|
||||
"inc": INCConfig,
|
||||
"mxfp4": Mxfp4Config,
|
||||
}
|
||||
# Update the `method_to_config` with customized quantization methods.
|
||||
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
|
||||
|
387
vllm/model_executor/layers/quantization/mxfp4.py
Normal file
387
vllm/model_executor/layers/quantization/mxfp4.py
Normal file
@ -0,0 +1,387 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
||||
FusedMoEMethodBase)
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
_can_support_mxfp4)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import next_power_of_2, round_up
|
||||
|
||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
# from flashinfer.fused_moe import cutlass_fused_moe
|
||||
from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
|
||||
shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
|
||||
|
||||
|
||||
class Mxfp4Config(QuantizationConfig):
|
||||
|
||||
def __init__(self, ignored_layers: Optional[list[str]] = None):
|
||||
super().__init__()
|
||||
self.ignored_layers = ignored_layers
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 100
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "mxfp4"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.ignored_layers and is_layer_skipped(
|
||||
prefix=prefix,
|
||||
ignored_layers=self.ignored_layers,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
raise NotImplementedError("Mxfp4 linear layer is not implemented")
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Mxfp4MoEMethod(layer.moe_config)
|
||||
elif isinstance(layer, Attention):
|
||||
raise NotImplementedError(
|
||||
"Mxfp4 attention layer is not implemented")
|
||||
return None
|
||||
|
||||
|
||||
class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__()
|
||||
self.topk_indices_dtype = None
|
||||
self.moe = moe
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
self.num_experts = num_experts
|
||||
weight_dtype = torch.uint8
|
||||
scale_dtype = torch.uint8
|
||||
|
||||
# FIXME (zyongye): ship after torch and safetensors support mxfp4
|
||||
# is_torch_mxfp4_available = (
|
||||
# hasattr(torch, "float4_e2m1fn_x2") and
|
||||
# hasattr(torch, "float8_e8m0fnu"))
|
||||
# if is_torch_mxfp4_available:
|
||||
# weight_dtype = torch.float4_e2m1fn_x2
|
||||
# scale_dtype = torch.float8_e8m0fnu
|
||||
|
||||
mxfp4_block = 32
|
||||
|
||||
intermediate_size_per_partition_after_pad = \
|
||||
intermediate_size_per_partition
|
||||
# pad the intermediate size to be a multiple of 2 * mxfp4_block
|
||||
# for to hold non-uniform sharded tensor as well as swizzling
|
||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 256)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
self.intermediate_size = intermediate_size_per_partition_after_pad
|
||||
self.hidden_size = hidden_size
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // 2,
|
||||
dtype=weight_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w13_weight_scale = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
hidden_size // mxfp4_block,
|
||||
dtype=scale_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
|
||||
w13_bias = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition_after_pad,
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // 2,
|
||||
dtype=weight_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition_after_pad // mxfp4_block,
|
||||
dtype=scale_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
||||
hidden_size,
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
layer.gemm1_alpha = Parameter(torch.tensor(
|
||||
[1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
layer.gemm1_beta = Parameter(torch.tensor(
|
||||
[1.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
layer.gemm1_clamp_limit = Parameter(torch.tensor(
|
||||
[7.0] * self.num_experts, dtype=torch.float32).cuda(),
|
||||
requires_grad=False)
|
||||
sf_block_size = 32 # mxfp4 block size
|
||||
|
||||
assert (layer.w13_weight.dim() == 3
|
||||
and layer.w13_weight.shape[0] == self.num_experts
|
||||
and layer.w13_weight.shape[1] == self.intermediate_size * 2
|
||||
and layer.w13_weight.shape[2] == self.hidden_size // 2)
|
||||
assert (layer.w13_weight_scale.dim() == 3
|
||||
and layer.w13_weight_scale.shape[0] == self.num_experts
|
||||
and layer.w13_weight_scale.shape[1]
|
||||
== self.intermediate_size * 2
|
||||
and layer.w13_weight_scale.shape[2]
|
||||
== self.hidden_size // sf_block_size)
|
||||
assert (layer.w2_weight.dim() == 3
|
||||
and layer.w2_weight.shape[0] == self.num_experts
|
||||
and layer.w2_weight.shape[1] == self.hidden_size and
|
||||
layer.w2_weight.shape[2] == self.intermediate_size // 2)
|
||||
assert (layer.w2_weight_scale.dim() == 3
|
||||
and layer.w2_weight_scale.shape[1] == self.hidden_size
|
||||
and layer.w2_weight_scale.shape[2]
|
||||
== self.intermediate_size // sf_block_size)
|
||||
assert (layer.w13_bias.dim() == 2
|
||||
and layer.w13_bias.shape[0] == self.num_experts
|
||||
and layer.w13_bias.shape[1] == self.intermediate_size * 2)
|
||||
assert (layer.w2_bias.dim() == 2
|
||||
and layer.w2_bias.shape[0] == self.num_experts
|
||||
and layer.w2_bias.shape[1] == self.hidden_size)
|
||||
|
||||
w13_weight_scale = layer.w13_weight_scale.data
|
||||
w2_weight_scale = layer.w2_weight_scale.data
|
||||
w13_weight = layer.w13_weight.data
|
||||
w2_weight = layer.w2_weight.data
|
||||
w13_bias = layer.w13_bias.data.to(torch.float32)
|
||||
w2_bias = layer.w2_bias.data.to(torch.float32)
|
||||
|
||||
# Swap w1 and w3 as the defenition of
|
||||
# swiglu is different in the trtllm-gen
|
||||
def swap_every_two_rows(x, axis=-1):
|
||||
shape = x.shape
|
||||
if axis < 0:
|
||||
axis = len(shape) + axis
|
||||
|
||||
# Create a new shape with pairs swapped along specified axis
|
||||
new_shape = list(shape)
|
||||
new_shape[axis] = shape[axis] // 2
|
||||
new_shape.insert(axis + 1, 2)
|
||||
|
||||
# Reshape to expose pairs, swap them, and reshape back
|
||||
x = x.reshape(*new_shape)
|
||||
x = x.flip(axis + 1)
|
||||
new_shape = list(shape)
|
||||
return x.reshape(*new_shape)
|
||||
|
||||
w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
|
||||
w13_weight = swap_every_two_rows(w13_weight, -2)
|
||||
w13_bias = swap_every_two_rows(w13_bias, -1)
|
||||
|
||||
# Do not interleave as the checkpoint is already interleaved
|
||||
|
||||
# Shuffle weights and scaling factors for transposed mma output
|
||||
gemm1_weights_mxfp4_shuffled = []
|
||||
gemm1_scales_mxfp4_shuffled = []
|
||||
gemm2_weights_mxfp4_shuffled = []
|
||||
gemm2_scales_mxfp4_shuffled = []
|
||||
gemm1_bias_shuffled = []
|
||||
gemm2_bias_shuffled = []
|
||||
epilogue_tile_m = 128 # FIXME: this depends on the kernel internals
|
||||
for i in range(self.num_experts):
|
||||
gemm1_weights_mxfp4_shuffled.append(
|
||||
shuffle_matrix_a(w13_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm1_scales_mxfp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm1_bias_shuffled.append(
|
||||
shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m))
|
||||
|
||||
gemm2_weights_mxfp4_shuffled.append(
|
||||
shuffle_matrix_a(w2_weight[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm2_scales_mxfp4_shuffled.append(
|
||||
shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
|
||||
epilogue_tile_m))
|
||||
gemm2_bias_shuffled.append(
|
||||
shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
|
||||
epilogue_tile_m))
|
||||
|
||||
w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
|
||||
w13_weight_scale = torch.stack(
|
||||
gemm1_scales_mxfp4_shuffled).reshape(
|
||||
self.num_experts, 2 * self.intermediate_size,
|
||||
self.hidden_size // sf_block_size).view(
|
||||
torch.float8_e4m3fn)
|
||||
|
||||
w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
|
||||
w2_weight_scale = torch.stack(gemm2_scales_mxfp4_shuffled).reshape(
|
||||
self.num_experts, self.hidden_size, self.intermediate_size //
|
||||
sf_block_size).view(torch.float8_e4m3fn)
|
||||
|
||||
layer.w13_weight = Parameter(w13_weight, requires_grad=False)
|
||||
layer.w13_weight_scale = Parameter(w13_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale = Parameter(w2_weight_scale,
|
||||
requires_grad=False)
|
||||
layer.w13_bias = Parameter(
|
||||
torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
|
||||
requires_grad=False)
|
||||
layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
|
||||
self.num_experts, -1),
|
||||
requires_grad=False)
|
||||
return
|
||||
|
||||
def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
|
||||
# Number of tokens in the input tensor.
|
||||
num_tokens = x.shape[0]
|
||||
# Factor to account for the imbalance of the experts.
|
||||
# factor equals to the
|
||||
# max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
|
||||
# - 1.0 means perfect expert distribution.
|
||||
# - > 1.0 means some experts have more
|
||||
# tokens than the perfect distribution.
|
||||
# - < 1.0 does not make sense.
|
||||
imbalance_factor = 1.3
|
||||
# Calculate the number of tokens per expert
|
||||
# assuming perfect distribution.
|
||||
num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
|
||||
# Apply the imbalance factor.
|
||||
num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
|
||||
# And pad the number to the next power of 2.
|
||||
tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
|
||||
# Cap to 8-64 tokens per CTA tile
|
||||
# as it's the range supported by the kernel.
|
||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||
|
||||
return tile_tokens_dim
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if enable_eplb:
|
||||
raise NotImplementedError("EPLB is not supported for mxfp4")
|
||||
|
||||
assert _can_support_mxfp4(
|
||||
use_grouped_topk, topk_group, num_expert_group, expert_map,
|
||||
custom_routing_function, e_score_correction_bias,
|
||||
apply_router_weight_on_input, scoring_func, activation,
|
||||
expert_load_view, logical_to_physical_map,
|
||||
logical_replica_count), ("MXFP4 are not supported\
|
||||
with this configuration.")
|
||||
|
||||
if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
|
||||
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
|
||||
assert not self.moe.use_ep, (
|
||||
"EP is not supported for flashinfer mxfp4 moe backend yet.")
|
||||
if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
|
||||
assert x.dtype == torch.bfloat16
|
||||
x_quant = x
|
||||
x_scale = None
|
||||
else:
|
||||
x_quant, x_scale = mxfp8_quantize(x, False) # to mxfp8
|
||||
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
|
||||
trtllm_gen_output = trtllm_fp4_block_scale_moe(
|
||||
router_logits.to(torch.bfloat16),
|
||||
None, # routing_bias
|
||||
x_quant,
|
||||
x_scale,
|
||||
layer.w13_weight, # uint8 (e2m1 x 2)
|
||||
layer.w13_weight_scale, # uint8 (e4m3 x 2)
|
||||
layer.w13_bias, # fp32 per expert per channel
|
||||
layer.gemm1_alpha, # fp32 per expert
|
||||
layer.gemm1_beta, # fp32 per expert
|
||||
layer.gemm1_clamp_limit, # fp32 per expert
|
||||
layer.w2_weight, # uint8 (e2m1 x 2)
|
||||
layer.w2_weight_scale, # ue8m0
|
||||
layer.w2_bias, # fp32 per expert per channel
|
||||
None, # output1_scale_scalar
|
||||
None, # output1_scale_gate_scalar
|
||||
None, # output2_scale_scalar
|
||||
self.num_experts,
|
||||
top_k,
|
||||
None, # n_group
|
||||
None, # topk_group
|
||||
self.intermediate_size, # padded to multiple of 256
|
||||
0, # local_expert_offset
|
||||
self.num_experts, # local num experts
|
||||
None,
|
||||
self._get_tile_tokens_dim(x, top_k),
|
||||
1 if renormalize else 0, # routing_method_type, renormalize
|
||||
True, # do finalize
|
||||
)[0]
|
||||
return trtllm_gen_output
|
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils import direct_register_custom_op
|
||||
@ -7,6 +9,26 @@ from vllm.utils import direct_register_custom_op
|
||||
OCP_MX_BLOCK_SIZE = 32
|
||||
|
||||
|
||||
def _can_support_mxfp4(use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
scoring_func: str = "softmax",
|
||||
activation: str = "silu",
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None):
|
||||
return not (use_grouped_topk or topk_group or num_expert_group
|
||||
or expert_map or custom_routing_function
|
||||
or e_score_correction_bias or apply_router_weight_on_input
|
||||
or scoring_func != "softmax" or activation != "silu"
|
||||
or expert_load_view or logical_to_physical_map
|
||||
or logical_replica_count)
|
||||
|
||||
|
||||
def _dequant_mxfp4(x: torch.Tensor, scale: torch.Tensor,
|
||||
float_dtype: torch.dtype) -> torch.Tensor:
|
||||
try:
|
||||
|
Reference in New Issue
Block a user