From 0dc9532065c5f98952cb82d4c497e49ca09400bf Mon Sep 17 00:00:00 2001 From: yzds <41983536+youzhedian@users.noreply.github.com> Date: Sat, 30 Aug 2025 00:36:39 +0800 Subject: [PATCH] [BUGFIX ] fix undefined silu_and_mul_nvfp4_quant (#23929) Signed-off-by: hongchao Signed-off-by: Richard Zou Co-authored-by: hongchao Co-authored-by: Richard Zou Co-authored-by: Richard Zou --- csrc/ops.h | 4 ++-- csrc/torch_bindings.cpp | 3 ++- vllm/compilation/fix_functionalization.py | 4 +++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 78a487201b..7a176a5c00 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -130,8 +130,8 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, torch::Tensor& scale); -#ifndef USE_ROCM - +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) void silu_and_mul_nvfp4_quant(torch::Tensor& out, torch::Tensor& output_block_scale, torch::Tensor& input, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b769c09adc..56626a02c0 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -115,7 +115,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); -#ifndef USE_ROCM +#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ + (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) ops.def( "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, " "Tensor input, Tensor input_global_scale) -> ()"); diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index a36dd8b845..6bc721eec3 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -97,7 +97,9 @@ class FixFunctionalizationPass(VllmInductorPass): node, mutated_args, args=('result', 'input', 'scale')) - elif at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default: + elif hasattr( + torch.ops._C, "silu_and_mul_nvfp4_quant" + ) and at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default: mutated_args = {1: 'result', 2: 'result_block_scale'} self.defunctionalize(graph, node,