mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[torch.compile] Unwrap fused_marlin_moe custom op (#26739)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
committed by GitHub
parent 8a0af6a561
commit 8ae169286f
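In brief: `fused_marlin_moe` had been registered as a torch custom op, and every caller went through the `torch.ops.vllm` wrapper. This commit deletes the registration and has callers import and invoke the Python function directly, so torch.compile no longer treats the kernel entry point as an opaque custom op. A minimal before/after sketch of the call-site change (argument lists abbreviated as comments; the real kernel takes many more parameters, as the diff below shows):

from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe

# Before: routed through the registered custom-op wrapper.
# marlin_output = torch.ops.vllm.fused_marlin_moe(a, qweight1, qweight2, ...)

# After: a direct Python call to the same function.
# marlin_output = fused_marlin_moe(a, qweight1, qweight2, ...)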
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     int4_w4a16_moe_quant_config,
     int8_w8a16_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk,
     modular_triton_fused_moe,
@@ -724,7 +725,7 @@ def test_fused_marlin_moe(
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)

-    marlin_output = torch.ops.vllm.fused_marlin_moe(
+    marlin_output = fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -837,7 +838,7 @@ def test_fused_marlin_moe_with_bias(m):
     with set_current_vllm_config(vllm_config):
         torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)

-    marlin_output = torch.ops.vllm.fused_marlin_moe(
+    marlin_output = fused_marlin_moe(
         a,
         qweight1,
         qweight2,
@@ -51,7 +51,6 @@ __all__ = [

 if HAS_TRITON:
     # import to register the custom ops
-    import vllm.model_executor.layers.fused_moe.fused_marlin_moe  # noqa
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts,
     )
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     maybe_warn_marlin_atomic_add,
 )
 from vllm.scalar_type import ScalarType, scalar_types
-from vllm.utils import direct_register_custom_op


 def fused_marlin_moe(
@@ -241,44 +240,6 @@ def fused_marlin_moe(
     return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)


-def fused_marlin_moe_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    w1_scale: torch.Tensor,
-    w2_scale: torch.Tensor,
-    gating_output: torch.Tensor | None,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    quant_type_id: int,
-    apply_router_weight_on_input: bool = False,
-    global_num_experts: int = -1,
-    global_scale1: torch.Tensor | None = None,
-    global_scale2: torch.Tensor | None = None,
-    expert_map: torch.Tensor | None = None,
-    g_idx1: torch.Tensor | None = None,
-    g_idx2: torch.Tensor | None = None,
-    sort_indices1: torch.Tensor | None = None,
-    sort_indices2: torch.Tensor | None = None,
-    w1_zeros: torch.Tensor | None = None,
-    w2_zeros: torch.Tensor | None = None,
-    workspace: torch.Tensor | None = None,
-    intermediate_cache13: torch.Tensor | None = None,
-    intermediate_cache2: torch.Tensor | None = None,
-    is_k_full: bool = True,
-    output: torch.Tensor | None = None,
-    inplace: bool = False,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="fused_marlin_moe",
-    op_func=fused_marlin_moe,
-    fake_impl=fused_marlin_moe_fake,
-)
-
-
 class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         # TODO (varun) : Enable activation quantization
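The block removed above is the wrapper boilerplate: a registration call that exposed the kernel as torch.ops.vllm.fused_marlin_moe, plus a "fake" implementation that only allocates an output of the right shape so torch.compile could trace the op without running the real kernel. Ops registered this way appear as opaque nodes in the compiled graph; with the wrapper gone, call sites invoke the plain Python function and the compiler can trace into it. Below is a minimal, self-contained sketch of the same register-with-a-fake pattern using PyTorch's public torch.library API (PyTorch 2.4+); the op name mylib::toy_scale and its body are hypothetical illustrations, not vLLM's actual direct_register_custom_op helper.

import torch

# Hypothetical toy op (not vLLM code) showing the pattern the deleted
# direct_register_custom_op(...) call implemented for fused_marlin_moe.
@torch.library.custom_op("mylib::toy_scale", mutates_args=())
def toy_scale(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation: runs only when the op is actually executed.
    return x * scale

@toy_scale.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake (meta) implementation, analogous to fused_marlin_moe_fake above:
    # describe the output's shape/dtype without doing any compute.
    return torch.empty_like(x)

# The registered op is reachable as torch.ops.mylib.toy_scale, which mirrors
# how call sites reached torch.ops.vllm.fused_marlin_moe before this change.
out = torch.ops.mylib.toy_scale(torch.randn(4), 2.0)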
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -604,7 +605,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
             indices_type=self.topk_indices_dtype,
         )

-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
@@ -462,7 +463,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         #
         if self.use_marlin:
             assert self.fused_experts is None
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -1067,7 +1068,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         if self.use_marlin:
             assert activation == "silu", f"{activation} not supported for Marlin MoE."
             assert self.fused_experts is None
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -1654,7 +1655,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
             indices_type=self.topk_indices_dtype,
         )

-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_weight_packed,
             layer.w2_weight_packed,
@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     fp8_w8a8_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.linear import (
     LinearBase,
@@ -1196,7 +1197,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         elif self.use_marlin:
             assert activation == "silu", f"{activation} not supported for Marlin MoE."
             assert self.fused_experts is None
-            result = torch.ops.vllm.fused_marlin_moe(
+            result = fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -765,7 +766,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             indices_type=self.topk_indices_dtype,
         )

-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.fused_moe.config import (
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE,
     FusedMoEMethodBase,
@@ -1701,7 +1702,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         #
         if self.use_marlin:
             assert self.fused_experts is None
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -21,7 +21,10 @@ from vllm.model_executor.layers.fused_moe.config import (
     mxfp4_w4a16_moe_quant_config,
     ocp_mx_moe_quant_config,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
     OAITritonExperts,
 )
@@ -947,7 +950,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             e_score_correction_bias=e_score_correction_bias,
         )

-        return torch.ops.vllm.fused_marlin_moe(
+        return fused_marlin_moe(
             x,
             layer.w13_weight,
             layer.w2_weight,
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     fp8_w8a8_moe_quant_config,
     ocp_mx_moe_quant_config,
 )
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
     is_rocm_aiter_moe_enabled,
 )
@@ -402,7 +403,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         )
         if self.use_marlin:
             assert activation == "silu", f"{activation} not supported for Marlin MoE."
-            return torch.ops.vllm.fused_marlin_moe(
+            return fused_marlin_moe(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,