[ROCM] MoE fp4 CK kernel (#26545)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Commit 0925b28a8e (parent 99722d5f0e), committed via GitHub.
```diff
@@ -46,6 +46,11 @@ def is_rocm_aiter_moe_enabled() -> bool:
     )
 
 
+@cache
+def use_mxfp4_aiter_moe() -> bool:
+    return current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
+
+
 @cache
 def is_rocm_aiter_fusion_shared_expert_enabled() -> bool:
     return (
```
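For readers unfamiliar with this flag, here is a minimal sketch (not part of the diff) of the `@cache`-gated check that `use_mxfp4_aiter_moe()` introduces; the direct `os.environ` reads and the `PLATFORM` variable below are illustrative stand-ins for `current_platform` and `vllm.envs`.

```python
# Illustrative sketch of the @cache-gated feature-flag pattern used above:
# the check runs once, then the result is memoized for the process lifetime.
import os
from functools import cache


@cache
def use_mxfp4_aiter_moe_sketch() -> bool:
    # Stand-ins for current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER.
    platform_is_rocm = os.environ.get("PLATFORM", "") == "rocm"
    rocm_use_aiter = os.environ.get("VLLM_ROCM_USE_AITER", "0") == "1"
    return platform_is_rocm and rocm_use_aiter


print(use_mxfp4_aiter_moe_sketch())  # evaluated once; later env changes are not seen
```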
```diff
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled,
+    use_mxfp4_aiter_moe,
 )
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     prepare_moe_fp8_layer_for_marlin,
```
```diff
@@ -472,22 +473,22 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
                 "not implemented. Please open an issue."
             )
 
-        if not current_platform.supports_mx():
-            self.emulate = True
+        self.emulate = not current_platform.supports_mx() or not (
+            use_mxfp4_aiter_moe() and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
+        )
+        if self.emulate:
             logger.warning_once(
-                "The current platform does not support native MXFP4/MXFP6 "
+                f"The current mode (supports_mx={current_platform.supports_mx()}, "
+                f"use_mxfp4_aiter_moe={use_mxfp4_aiter_moe()}, "
+                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
+                "does not support native MXFP4/MXFP6 "
                 "computation. Simulated weight dequantization and activation "
                 "QDQ (quantize and dequantize) will be used, with the linear "
                 "layers computed in high precision."
             )
         else:
-            self.emulate = True
             logger.warning_once(
-                "The current platform supports native MXFP4/MXFP6 "
-                "computation, but kernels are not yet integrated in vLLM. "
-                "Simulated weight dequantization and activation "
-                "QDQ (quantize and dequantize) will be used, with the linear "
-                "layers computed in high precision."
+                "The current mode supports native MoE MXFP4 computation"
             )
 
     def get_packed_dim(self, dim: int, quant_dtype: str):
```
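The rewritten emulate logic above reduces to a single predicate: the native aiter CK fp4 path is taken only when the platform supports MX formats, the AITER MXFP4 MoE flag is on, and the scheme is exactly `w_mxfp4_a_mxfp4`. The sketch below restates it as a standalone function (names are illustrative) so the native-vs-emulated branching is easy to check.

```python
# Hedged restatement of the emulate decision from the hunk above.
def should_emulate(supports_mx: bool, aiter_mxfp4_moe: bool, ocp_mx_scheme: str) -> bool:
    return not supports_mx or not (aiter_mxfp4_moe and ocp_mx_scheme == "w_mxfp4_a_mxfp4")


assert should_emulate(True, True, "w_mxfp4_a_mxfp4") is False   # native CK fp4 kernel
assert should_emulate(True, False, "w_mxfp4_a_mxfp4") is True   # AITER path disabled
assert should_emulate(True, True, "w_mxfp6_a_mxfp6") is True    # scheme not covered
assert should_emulate(False, True, "w_mxfp4_a_mxfp4") is True   # no native MX support
```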
```diff
@@ -568,6 +569,24 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)
 
+    def process_weights_after_loading(self, layer):
+        if self.emulate:
+            return
+
+        from aiter.utility.fp4_utils import e8m0_shuffle
+
+        # Pre-shuffle weight scales
+        s0, s1, _ = layer.w13_weight_scale.shape
+        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
+        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
+        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
+
+        s0, s1, _ = layer.w2_weight_scale.shape
+        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
+        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
+        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+        torch.cuda.empty_cache()
+
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
```
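The new `process_weights_after_loading` reshapes each 3-D scale tensor to 2-D, shuffles it, and restores the original shape in place. The sketch below shows that reshape round trip with a no-op stub in place of aiter's `e8m0_shuffle`, whose internal layout is not spelled out in this diff; the tensor shape is also illustrative.

```python
# Sketch of the scale pre-shuffle flow; e8m0_shuffle_stub is a hypothetical
# stand-in for aiter.utility.fp4_utils.e8m0_shuffle (the real call permutes
# the e8m0 scale layout expected by the CK fp4 MoE kernel).
import torch


def e8m0_shuffle_stub(scales_2d: torch.Tensor) -> torch.Tensor:
    return scales_2d.clone()  # no-op stand-in


# Illustrative shape: (num_experts, rows, K // 32) e8m0 scales stored as uint8.
w13_weight_scale = torch.randint(0, 255, (8, 4, 64), dtype=torch.uint8)
s0, s1, _ = w13_weight_scale.shape
flat = w13_weight_scale.view(s0 * s1, -1)      # collapse leading dims for the shuffle
shuffled = e8m0_shuffle_stub(flat)
w13_weight_scale = shuffled.view(s0, s1, -1)   # restore the original 3-D shape
print(w13_weight_scale.shape)                  # torch.Size([8, 4, 64])
```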
```diff
@@ -611,8 +630,6 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
                 "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet."
             )
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
        topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
```
```diff
@@ -628,17 +645,44 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
             indices_type=self.topk_indices_dtype,
         )
 
-        out = fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            inplace=True,
-            activation=activation,
-            global_num_experts=global_num_experts,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
-            quant_config=self.moe_quant_config,
-        )
+        if not self.emulate:
+            from aiter import ActivationType, QuantType
+            from aiter.fused_moe import fused_moe
+
+            aiter_acts = {
+                ActivationType.No.name.lower(): ActivationType.No,
+                ActivationType.Silu.name.lower(): ActivationType.Silu,
+                ActivationType.Gelu.name.lower(): ActivationType.Gelu,
+            }
+            assert activation in aiter_acts, (
+                f"Aiter CK fp4 MoE doesn't support activation {activation}"
+            )
+            out = fused_moe(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                quant_type=QuantType.per_1x32,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                activation=aiter_acts[activation],
+                doweight_stage1=False,
+            )
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+
+            out = fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                quant_config=self.moe_quant_config,
+            )
         return out
```
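The non-emulated branch above maps the string activation name onto aiter's `ActivationType` before calling `fused_moe`. The sketch below reproduces just that lookup with a local enum standing in for the aiter type, so it runs without aiter installed.

```python
# Sketch of the activation-name lookup used before dispatching to the aiter
# kernel; this Enum is a stand-in for aiter.ActivationType, not the real class.
from enum import Enum


class ActivationType(Enum):
    No = 0
    Silu = 1
    Gelu = 2


aiter_acts = {a.name.lower(): a for a in ActivationType}

activation = "silu"
assert activation in aiter_acts, f"Aiter CK fp4 MoE doesn't support activation {activation}"
print(aiter_acts[activation])  # ActivationType.Silu
```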