From 7c6c5d04fe3c82ec010ae7f636f35e359d13d226 Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Tue, 14 Oct 2025 14:36:57 +0000 Subject: [PATCH] Add scaled_grouped_mm_v2 and python API (#165154) Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton Pull Request resolved: https://github.com/pytorch/pytorch/pull/165154 Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre --- aten/src/ATen/native/cuda/Blas.cpp | 153 ++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 6 + docs/source/nn.functional.rst | 1 + ...asDecompTest.test_has_decomposition.expect | 1 + test/test_scaled_matmul_cuda.py | 99 ++++++++++-- torch/nn/functional.py | 87 ++++++++++ torch/overrides.py | 1 + 7 files changed, 334 insertions(+), 14 deletions(-) diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 48b49c3c597d..c95145f0dd1b 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -2578,7 +2578,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm( const Tensor& mat_a, const Tensor& mat_b, const Tensor& scale_a, + const SwizzleType& swizzle_a, const Tensor& scale_b, + const SwizzleType& swizzle_b, const std::optional& offs, Tensor& out) { const bool a_is_2d = mat_a.dim() == 2; @@ -2589,6 +2591,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm( TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases"); TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets"); TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm"); + // MXFP8 expects float8_e8m0fnu scales. + TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu, + "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors."); +#ifdef USE_ROCM + TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE, + "For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE"); +#else + TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4, + "For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4"); +#endif #if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM) fbgemm_gpu::mx8mx8bf16_grouped_mm( @@ -2673,6 +2685,9 @@ _f8_f8_bf16_rowwise_grouped_mm( const std::optional& bias, bool use_fast_accum, Tensor& out) { + // FP8 per-tensor and per-row scaling expect fp32 scales. 
+ TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, + "For grouped FP8 rowwise, both scales must be float32 tensors"); #ifndef USE_ROCM return _f8_f8_bf16_rowwise_grouped_mm_cuda( mat_a, @@ -2772,11 +2787,15 @@ _scaled_grouped_mm_cuda( #endif if (is_mx8mx8bf16) { + // Note: Passing implied SwizzleType here, correctness of scale previously checked + // in `check_scale` call return _mx8_mx8_bf16_grouped_mm_fbgemm( mat_a, mat_b, scale_a, + SwizzleType::SWIZZLE_32_4_4, scale_b, + SwizzleType::SWIZZLE_32_4_4, offs.value(), out); } @@ -2793,6 +2812,140 @@ _scaled_grouped_mm_cuda( out); } +namespace { + +std::array, 2> scale_grouped_kernel_dispatch = {{ + { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE}, + { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}}; + +} // anonymous namespace + +Tensor +_scaled_grouped_mm_cuda_v2( + const Tensor& mat_a, const Tensor& mat_b, + ArrayRef scale_a, + IntArrayRef scale_recipe_a, + IntArrayRef swizzle_a, + ArrayRef scale_b, + IntArrayRef scale_recipe_b, + IntArrayRef swizzle_b, + const std::optional& offs, + const std::optional& bias, + const std::optional out_dtype, + IntArrayRef contraction_dim, + bool use_fast_accum) { + bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true); + TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+"); + + TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed"); + TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed"); + TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d"); + TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d"); + const bool a_is_2d = mat_a.dim() == 2; + const bool b_is_2d = mat_b.dim() == 2; + + // NOTE(slayton): For sub-1B formats want contraction_dim argument? + if (!a_is_2d || !b_is_2d) { + if (contraction_dim.size() > 0) { + const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]); + TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b), + "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ", + mat_b.size(dim_b)); + // Note: only (-1, -2) is currently supported + TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only"); + } else { + TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match"); + } + } + TORCH_CHECK_VALUE( + mat_a.size(-1) % 16 == 0, + "Expected trailing dimension of mat_a to be divisible by 16 ", + "but got mat1 shape: (", + mat_a.sizes(), + ")."); + TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0, + "Expected mat_b shape to be divisible by 16 ", + "but got mat_b shape: (", + mat_b.sizes(), + ")."); + + TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet"); + TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix"); + + // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present. 
+ // for rowwise, no offsets implies 3d-3d and is handled by lower-level + // routines + if (offs.has_value()) { + TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D"); + TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32"); + } + + const auto out_dtype_ = out_dtype.value_or(kBFloat16); + TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm"); + + Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); + + // Conversion of implicitly-defined enums to explicit + auto scale_recipe_a_enum = convert_int_to_enum(scale_recipe_a); + auto swizzle_a_enum = convert_int_to_enum(swizzle_a); + auto scale_recipe_b_enum = convert_int_to_enum(scale_recipe_b); + auto swizzle_b_enum = convert_int_to_enum(swizzle_b); + + // at this point we can start working out what we want to be doing + // Try to do as few steps as possible. + // NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed. + // Do this via a list of defined (name, acceptance, concrete_impl) tuples. + ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE; + for (const auto& fn_entry : scale_grouped_kernel_dispatch) { + const auto [name, accept_fn, scaled_gemm_impl] = fn_entry; + bool ok = accept_fn(mat_a.scalar_type(), + scale_recipe_a_enum, + scale_a, + mat_b.scalar_type(), + scale_recipe_b_enum, + scale_b); + if (ok) { + gemm_impl = scaled_gemm_impl; + break; + } + } + TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE, + "No gemm implementation was found"); + + switch (gemm_impl) { + case ScaledGemmImplementation::ROWWISE_ROWWISE: { + const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1; + _check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier); + _check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier); + return _f8_f8_bf16_rowwise_grouped_mm( + mat_a, + mat_b, + scale_a[0], + scale_b[0], + offs, + bias, + use_fast_accum, + out); + } + case ScaledGemmImplementation::MXFP8_MXFP8: { + _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */); + _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */); + return _mx8_mx8_bf16_grouped_mm_fbgemm( + mat_a, + mat_b, + scale_a[0], + swizzle_a_enum[0], + scale_b[0], + swizzle_b_enum[0], + offs.value(), + out); + } + default: + TORCH_CHECK_NOT_IMPLEMENTED(false, + "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here"); + } +} + Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b, const std::optional& offs, const std::optional& bias, diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 9b3c75b13e9d..db788c6e3e66 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7183,6 +7183,12 @@ CUDA: _scaled_grouped_mm_cuda tags: needs_exact_strides +- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor + variants: function + dispatch: + CUDA: _scaled_grouped_mm_cuda_v2 + tags: needs_exact_strides + - func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? 
out_dtype=None) -> Tensor variants: function dispatch: diff --git a/docs/source/nn.functional.rst b/docs/source/nn.functional.rst index c34c351937b2..015d1d9ffda1 100644 --- a/docs/source/nn.functional.rst +++ b/docs/source/nn.functional.rst @@ -228,3 +228,4 @@ Low-Precision functions ScalingType SwizzleType scaled_mm + scaled_grouped_mm diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 936a90938292..42c63ad8706f 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -524,6 +524,7 @@ aten::_scaled_dot_product_flash_attention_for_cpu_backward aten::_scaled_dot_product_fused_attention_overrideable aten::_scaled_dot_product_fused_attention_overrideable_backward aten::_scaled_grouped_mm +aten::_scaled_grouped_mm_v2 aten::_scaled_mm aten::_scaled_mm.out aten::_scaled_mm_v2 diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index e694b836ede7..e58f3ea8d960 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -11,7 +11,7 @@ from typing import Optional import torch -from torch.nn.functional import scaled_mm, ScalingType, SwizzleType +from torch.nn.functional import scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType from torch.testing._internal.common_cuda import ( IS_SM90, _get_torch_cuda_version, @@ -215,6 +215,49 @@ def scaled_mm_wrap( ) return out +def scaled_grouped_mm_wrap( + a, + b, + scale_a, + scale_b, + scale_recipe_a, + scale_recipe_b, + swizzle_a=SwizzleType.NO_SWIZZLE, + swizzle_b=SwizzleType.NO_SWIZZLE, + scale_result=None, + out_dtype=torch.bfloat16, + use_fast_accum=False, + offs=None, + bias=None, + wrap_v2=True, +): + if not wrap_v2: + return torch._scaled_grouped_mm( + a, + b, + scale_a, + scale_b, + out_dtype=out_dtype, + bias=bias, + offs=offs, + use_fast_accum=use_fast_accum) + else: + return scaled_grouped_mm( + a, + b, + scale_a, + scale_recipe_a, + scale_b, + scale_recipe_b, + swizzle_a=swizzle_a, + swizzle_b=swizzle_b, + offs=offs, + bias=bias, + output_dtype=out_dtype, + use_fast_accum=use_fast_accum) + + + def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor: # naive implementation: dq -> op -> q x_fp32 = x.to(torch.float) / x_scale @@ -444,7 +487,8 @@ class TestFP8Matmul(TestCase): @parametrize("M", [2048, 2049]) @parametrize("N", [8192]) @parametrize("K", [16640]) - def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K): + @parametrize("wrap_v2", [True, False]) + def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2): torch.manual_seed(42) total_K = K # Alias for clarity, communicating this consists of several groups along this dim input_group_end_offsets = generate_jagged_offs( @@ -510,13 +554,18 @@ class TestFP8Matmul(TestCase): w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1) # Compute mxfp8 grouped mm output - y_mxfp8 = torch._scaled_grouped_mm( + y_mxfp8 = scaled_grouped_mm_wrap( xq, # (M, total_K) wq.transpose(-2, -1), # (total_K, N) x_blocked_scales, # to_blocked_per_group(M, total_K//32) w_blocked_scales, # to_blocked_per_group(N, total_K//32) + scale_recipe_a=ScalingType.BlockWise1x32, + scale_recipe_b=ScalingType.BlockWise1x32, + swizzle_a=SwizzleType.SWIZZLE_32_4_4, + swizzle_b=SwizzleType.SWIZZLE_32_4_4, offs=input_group_end_offsets, # (G,) out_dtype=torch.bfloat16, + wrap_v2=wrap_v2 ) # bf16 reference output @@ -535,7 +584,8 @@ class TestFP8Matmul(TestCase): @parametrize("M", [16640]) 
@parametrize("N", [8192]) @parametrize("K", [4096]) - def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K): + @parametrize("wrap_v2", [True, False]) + def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2): torch.manual_seed(42) # Simulate 2d-3d grouped gemm `out = input @ weight.t()` # 2D inputs with groups along M, 3D weights. @@ -579,14 +629,19 @@ class TestFP8Matmul(TestCase): xq = xq.view(-1, xq.shape[-1]) # Compute mxfp8 grouped gemm. - y_mxfp8 = torch._scaled_grouped_mm( + y_mxfp8 = scaled_grouped_mm_wrap( xq, wq.transpose(-2, -1), x_scale, w_scale, offs=input_group_end_offsets, out_dtype=torch.bfloat16, - ) + scale_recipe_a=ScalingType.BlockWise1x32, + scale_recipe_b=ScalingType.BlockWise1x32, + swizzle_a=SwizzleType.SWIZZLE_32_4_4, + swizzle_b=SwizzleType.SWIZZLE_32_4_4, + wrap_v2=wrap_v2) + # Compute reference bf16 grouped gemm. y_bf16 = torch._grouped_mm( @@ -1536,7 +1591,8 @@ class TestFP8Matmul(TestCase): @parametrize("fast_accum", [False, True]) # AMD does not support non-contiguous inputs yet @parametrize("strided", [False] + ([True] if torch.version.cuda else [])) - def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided): + @parametrize("wrap_v2", [True, False]) + def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2): device = "cuda" fp8_dtype = e4m3_type m, n, k, n_groups = 16, 32, 64, 4 @@ -1545,9 +1601,16 @@ class TestFP8Matmul(TestCase): scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32) scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32) offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32) - f = torch._scaled_grouped_mm - out = f(a, b.t(), scale_a, scale_b, offs=offs, - out_dtype=torch.bfloat16, use_fast_accum=fast_accum) + f = scaled_grouped_mm_wrap + out = f(a, b.t(), + scale_a, + scale_b, + scale_recipe_a=ScalingType.RowWise, + scale_recipe_b=ScalingType.RowWise, + offs=offs, + out_dtype=torch.bfloat16, + use_fast_accum=fast_accum, + wrap_v2=wrap_v2) offs_cpu = offs.cpu() alist, blist, ascalelist, bscalelist = [], [], [], [] start = 0 @@ -1564,7 +1627,8 @@ class TestFP8Matmul(TestCase): @parametrize("fast_accum", [False, True]) # AMD does not support non-contiguous inputs yet @parametrize("strided", [False] + ([True] if torch.version.cuda else [])) - def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided): + @parametrize("wrap_v2", [True, False]) + def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, wrap_v2): device = "cuda" fp8_dtype = e4m3_type m, n, k, n_groups = 16, 32, 64, 4 @@ -1582,9 +1646,16 @@ class TestFP8Matmul(TestCase): offs[0] = offs[1] scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32) scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n) - f = torch._scaled_grouped_mm - out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs, - out_dtype=torch.bfloat16, use_fast_accum=fast_accum) + f = scaled_grouped_mm_wrap + out = f(a, b.transpose(-2, -1), + scale_a, + scale_b, + scale_recipe_a=ScalingType.RowWise, + scale_recipe_b=ScalingType.RowWise, + offs=offs, + out_dtype=torch.bfloat16, + use_fast_accum=fast_accum, + wrap_v2=wrap_v2) offs_cpu = offs.cpu() alist, ascalelist, outlist = [], [], [] diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 8c4c958b7476..ef4ed35008cc 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -6711,3 +6711,90 @@ def scaled_mm( ) return out + + +def scaled_grouped_mm( + mat_a: Tensor, + mat_b: Tensor, + scale_a: Tensor | 
list[Tensor], + scale_recipe_a: ScalingType | list[ScalingType], + scale_b: Tensor | list[Tensor], + scale_recipe_b: ScalingType | list[ScalingType], + swizzle_a: SwizzleType | list[SwizzleType] | None = None, + swizzle_b: SwizzleType | list[SwizzleType] | None = None, + bias: Optional[Tensor] = None, + offs: Optional[Tensor] = None, + output_dtype: Optional[torch.dtype] = torch.bfloat16, + contraction_dim: list[int] | tuple[int] = (), + use_fast_accum: bool = False, +) -> Tensor: + r""" + scaled_grouped_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a, swizzle_b, bias, offs, + output_dtype, use_fast_accum) + + Applies a grouped scaled matrix-multiply, grouped_mm(mat_a, mat_b) where the scaling of mat_a and mat_b are described by + scale_recipe_a and scale_recipe_b respectively. + + Args: + scale_a: Tensor containing decoding scaling factors for mat_a + scale_recipe_a: Enum describing how mat_a has been scaled + scale_b: Tensor containing decoding scaling factors for mat_b + scale_recipe_b: Enum describing how mat_b has been scaled + swizzle_a: Enum describing the swizzling pattern (if any) of scale_a + swizzle_b: Enum describing the swizzling pattern (if any) of scale_b + bias: optional bias term to be added to the output + offs: optional offsets into the source tensors denoting group start indices + output_dtype: dtype used for the output tensor + contraction_dim: describe which dimensions are :math:`K` in the matmul. + use_fast_accum: enable/disable tensor-core fast accumulation (Hopper-GPUs only) + """ + + def expand_single_value(v: _Any | list[_Any] | None) -> list[_Any]: + if v is None: + return [] + elif not isinstance(v, (list)): + return [ + v, + ] + else: + return v + + scale_a = expand_single_value(scale_a) + scale_recipe_a = expand_single_value(scale_recipe_a) + scale_b = expand_single_value(scale_b) + scale_recipe_b = expand_single_value(scale_recipe_b) + swizzle_a = expand_single_value(swizzle_a) + swizzle_b = expand_single_value(swizzle_b) + + # native_functions has restrictions on what can be defined + # & passed through - std::optional> for instance + # *cannot* be passed, but an empty vector (list) can. + # So, we need to convert None arguments for lists in python + # explicitly into empty lists. + def list_or_empty(l: list[_Any] | None) -> list[_Any]: + return [] if not l else l + + def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]: + if not isinstance(l, list): + l = [ + l, + ] + return [li.value for li in l] + + out = torch._scaled_grouped_mm_v2( + mat_a, + mat_b, + scale_a, + enum_list_as_int_list(scale_recipe_a), + enum_list_as_int_list(list_or_empty(swizzle_a)), + scale_b, + enum_list_as_int_list(scale_recipe_b), + enum_list_as_int_list(list_or_empty(swizzle_b)), + offs, + bias, + output_dtype, + contraction_dim, + use_fast_accum, + ) + + return out diff --git a/torch/overrides.py b/torch/overrides.py index b02301db1f17..264edf07b918 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -249,6 +249,7 @@ def get_ignored_functions() -> set[Callable]: torch.nn.functional.has_torch_function_unary, torch.nn.functional.has_torch_function_variadic, torch.nn.functional.handle_torch_function, + torch.nn.functional.scaled_grouped_mm, torch.nn.functional.scaled_mm, torch.nn.functional.sigmoid, torch.nn.functional.hardsigmoid,
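
A minimal sketch of how the new public entrypoint is intended to be called, modelled on the rowwise-FP8 `test_scaled_grouped_gemm_2d_2d` case added above. The shapes, random data, and the `torch.float8_e4m3fn` dtype are illustrative assumptions only; per the device check in `_scaled_grouped_mm_cuda_v2`, this path runs on CUDA devices with compute capability 9.0/10.0 or ROCm MI300+.

```python
import torch
from torch.nn.functional import scaled_grouped_mm, ScalingType

device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4

# FP8 operands with the groups laid out along the shared K dimension (2d-2d case).
a = torch.randn(m, k * n_groups, device=device).to(torch.float8_e4m3fn)
b = torch.randn(n, k * n_groups, device=device).to(torch.float8_e4m3fn)

# Rowwise fp32 decoding scales: one scale per row of `a` (resp. per output column) per group.
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)

# Group-end offsets along K: 1D int32, required whenever either operand is 2D.
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)

out = scaled_grouped_mm(
    a,
    b.t(),                      # mat_b must be in "transposed" (column-major) layout
    scale_a,
    ScalingType.RowWise,
    scale_b,
    ScalingType.RowWise,
    offs=offs,
    output_dtype=torch.bfloat16,
)
# For the 2d-2d case the result is grouped: one (m, n) bf16 block per K-group.
```

The MXFP8 path exercised in `test_mxfp8_scaled_grouped_mm_2d_2d`/`_2d_3d` is called the same way, but with `ScalingType.BlockWise1x32` recipes, `float8_e8m0fnu` block scales, and (on CUDA) `SwizzleType.SWIZZLE_32_4_4` swizzles.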
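
Continuing the snippet above, this is roughly what the wrapper dispatches to under the hood: per the new `_scaled_grouped_mm_v2` schema, scales are passed as tensor lists and the recipe/swizzle enums are flattened to `int` lists. This is a sketch of what `torch.nn.functional.scaled_grouped_mm` does internally, not an API users need to call directly.

```python
# Equivalent lowering of the rowwise call above onto the private v2 ATen op.
out_v2 = torch._scaled_grouped_mm_v2(
    a, b.t(),
    [scale_a], [ScalingType.RowWise.value], [],   # scale_a / recipe_a / swizzle_a
    [scale_b], [ScalingType.RowWise.value], [],   # scale_b / recipe_b / swizzle_b
    offs,             # offs
    None,             # bias (not supported yet)
    torch.bfloat16,   # out_dtype
    [],               # contraction_dim
    False,            # use_fast_accum
)
```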