[torch.compile] Unwrap fused_marlin_moe custom op (#26739)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Author: Varun Sundar Rabindranath
Date: 2025-10-13 22:22:16 -04:00
Committed by: GitHub
Commit: 8ae169286f (parent: 8a0af6a561)
10 changed files with 22 additions and 52 deletions
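
In short, this change deletes the direct_register_custom_op wrapper around fused_marlin_moe and has every call site import and invoke the Python function directly instead of dispatching through torch.ops.vllm.fused_marlin_moe. The sketch below illustrates the two calling conventions with a hypothetical toy op rather than the real Marlin kernel; it assumes vllm.utils.direct_register_custom_op accepts the same kwargs as the registration removed further down, and the torch.compile rationale is inferred from the commit title.

import torch

from vllm.utils import direct_register_custom_op


def toy_silu_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Stand-in for fused_marlin_moe: an ordinary Python function.
    return torch.nn.functional.silu(x) * y


def toy_silu_mul_fake(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Shape-only fake (meta) implementation, mirroring the deleted
    # fused_marlin_moe_fake, which just returned torch.empty_like(hidden_states).
    return torch.empty_like(x)


# The pattern this commit removes: register the function as a custom op under
# the vllm namespace, together with its fake impl.
direct_register_custom_op(
    op_name="toy_silu_mul",
    op_func=toy_silu_mul,
    fake_impl=toy_silu_mul_fake,
)

x = torch.randn(4, 8)
y = torch.randn(4, 8)

# Old calling convention: dispatch through the registered custom op, which
# torch.compile treats as a single opaque node.
out_wrapped = torch.ops.vllm.toy_silu_mul(x, y)

# New calling convention: call the Python function directly, so the compiler
# can trace through its body down to the underlying kernel ops.
out_direct = toy_silu_mul(x, y)
assert torch.allclose(out_wrapped, out_direct)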


@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.config import (
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe,
@@ -724,7 +725,7 @@ def test_fused_marlin_moe(
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
-marlin_output = torch.ops.vllm.fused_marlin_moe(
+marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
@@ -837,7 +838,7 @@ def test_fused_marlin_moe_with_bias(m):
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)
-marlin_output = torch.ops.vllm.fused_marlin_moe(
+marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,


@@ -51,7 +51,6 @@ __all__ = [
if HAS_TRITON:
# import to register the custom ops
-import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)


@@ -19,7 +19,6 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
maybe_warn_marlin_atomic_add,
)
from vllm.scalar_type import ScalarType, scalar_types
-from vllm.utils import direct_register_custom_op
def fused_marlin_moe(
@@ -241,44 +240,6 @@ def fused_marlin_moe(
return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
-def fused_marlin_moe_fake(
-hidden_states: torch.Tensor,
-w1: torch.Tensor,
-w2: torch.Tensor,
-w1_scale: torch.Tensor,
-w2_scale: torch.Tensor,
-gating_output: torch.Tensor | None,
-topk_weights: torch.Tensor,
-topk_ids: torch.Tensor,
-quant_type_id: int,
-apply_router_weight_on_input: bool = False,
-global_num_experts: int = -1,
-global_scale1: torch.Tensor | None = None,
-global_scale2: torch.Tensor | None = None,
-expert_map: torch.Tensor | None = None,
-g_idx1: torch.Tensor | None = None,
-g_idx2: torch.Tensor | None = None,
-sort_indices1: torch.Tensor | None = None,
-sort_indices2: torch.Tensor | None = None,
-w1_zeros: torch.Tensor | None = None,
-w2_zeros: torch.Tensor | None = None,
-workspace: torch.Tensor | None = None,
-intermediate_cache13: torch.Tensor | None = None,
-intermediate_cache2: torch.Tensor | None = None,
-is_k_full: bool = True,
-output: torch.Tensor | None = None,
-inplace: bool = False,
-) -> torch.Tensor:
-return torch.empty_like(hidden_states)
-direct_register_custom_op(
-op_name="fused_marlin_moe",
-op_func=fused_marlin_moe,
-fake_impl=fused_marlin_moe_fake,
-)
class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig):
# TODO (varun) : Enable activation quantization
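
For reference, the deleted fused_marlin_moe_fake only returned torch.empty_like(hidden_states); a fake (meta) implementation like that exists so shape and dtype inference can run without executing the kernel while the custom op is being traced. With the op unwrapped, tracing follows the Python body of fused_marlin_moe down to the underlying kernel ops, so neither the fake implementation nor the direct_register_custom_op call is needed anymore.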


@@ -14,6 +14,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -604,7 +605,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
indices_type=self.topk_indices_dtype,
)
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,


@@ -34,6 +34,7 @@ from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
@@ -462,7 +463,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
#
if self.use_marlin:
assert self.fused_experts is None
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1067,7 +1068,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
if self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert self.fused_experts is None
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1654,7 +1655,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
indices_type=self.topk_indices_dtype,
)
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,


@@ -26,6 +26,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.linear import (
LinearBase,
@@ -1196,7 +1197,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
elif self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
assert self.fused_experts is None
-result = torch.ops.vllm.fused_marlin_moe(
+result = fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,


@@ -15,6 +15,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -765,7 +766,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
indices_type=self.topk_indices_dtype,
)
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_qweight,
layer.w2_qweight,


@@ -21,6 +21,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoEMethodBase,
@@ -1701,7 +1702,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
#
if self.use_marlin:
assert self.fused_experts is None
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,


@@ -21,7 +21,10 @@ from vllm.model_executor.layers.fused_moe.config import (
mxfp4_w4a16_moe_quant_config,
ocp_mx_moe_quant_config,
)
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+MarlinExperts,
+fused_marlin_moe,
+)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
OAITritonExperts,
)
@@ -947,7 +950,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
e_score_correction_bias=e_score_correction_bias,
)
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,


@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
ocp_mx_moe_quant_config,
)
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled,
)
@@ -402,7 +403,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
)
if self.use_marlin:
assert activation == "silu", f"{activation} not supported for Marlin MoE."
-return torch.ops.vllm.fused_marlin_moe(
+return fused_marlin_moe(
x,
layer.w13_weight,
layer.w2_weight,