Add scaled_grouped_mm_v2 and python API (#165154)

Summary:

* Add `torch._scaled_grouped_mm_v2` with more functionality and
  extensibility for future formats
* Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint
* Test both original and v2 functionality

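For reference, a minimal usage sketch of the new public entrypoint with rowwise FP8 scaling (shapes and values are illustrative and mirror the 2d-2d test updated below):

```
# Illustrative sketch only; shapes/values mirror the rowwise 2d-2d test below.
import torch
from torch.nn.functional import scaled_grouped_mm, ScalingType

m, n, k, n_groups = 16, 32, 64, 4
a = torch.randn(m, k * n_groups, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k * n_groups, device="cuda").to(torch.float8_e4m3fn)
scale_a = torch.rand(m * n_groups, device="cuda", dtype=torch.float32)
scale_b = torch.rand(n * n_groups, device="cuda", dtype=torch.float32)
offs = torch.arange(k, n_groups * k + 1, k, device="cuda", dtype=torch.int32)

out = scaled_grouped_mm(
    a, b.t(),                      # mat_b is expected transposed
    scale_a, ScalingType.RowWise,
    scale_b, ScalingType.RowWise,
    offs=offs,
    output_dtype=torch.bfloat16,
)
```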
Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165154
Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
Author: Simon Layton
Date: 2025-10-14 14:36:57 +00:00
Committed by: PyTorch MergeBot
Commit: 7c6c5d04fe
Parent: b509fb9b5d
7 changed files with 334 additions and 14 deletions


@@ -2578,7 +2578,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
     const Tensor& mat_a,
     const Tensor& mat_b,
     const Tensor& scale_a,
+    const SwizzleType& swizzle_a,
     const Tensor& scale_b,
+    const SwizzleType& swizzle_b,
     const std::optional<at::Tensor>& offs,
     Tensor& out) {
   const bool a_is_2d = mat_a.dim() == 2;
@@ -2589,6 +2591,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
   TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
   TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
   TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
+  // MXFP8 expects float8_e8m0fnu scales.
+  TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
+      "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
+#ifdef USE_ROCM
+  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
+      "For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
+#else
+  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
+      "For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
+#endif

 #if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
   fbgemm_gpu::mx8mx8bf16_grouped_mm(
@@ -2673,6 +2685,9 @@ _f8_f8_bf16_rowwise_grouped_mm(
     const std::optional<Tensor>& bias,
     bool use_fast_accum,
     Tensor& out) {
+  // FP8 per-tensor and per-row scaling expect fp32 scales.
+  TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
+      "For grouped FP8 rowwise, both scales must be float32 tensors");
 #ifndef USE_ROCM
   return _f8_f8_bf16_rowwise_grouped_mm_cuda(
       mat_a,
@@ -2772,11 +2787,15 @@ _scaled_grouped_mm_cuda(
 #endif

   if (is_mx8mx8bf16) {
+    // Note: Passing implied SwizzleType here, correctness of scale previously checked
+    //       in `check_scale` call
     return _mx8_mx8_bf16_grouped_mm_fbgemm(
         mat_a,
         mat_b,
         scale_a,
+        SwizzleType::SWIZZLE_32_4_4,
         scale_b,
+        SwizzleType::SWIZZLE_32_4_4,
         offs.value(),
         out);
   }
@@ -2793,6 +2812,140 @@ _scaled_grouped_mm_cuda(
       out);
 }

+namespace {
+
+std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
+    { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
+    { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
+
+} // anonymous namespace
+
+Tensor
+_scaled_grouped_mm_cuda_v2(
+    const Tensor& mat_a, const Tensor& mat_b,
+    ArrayRef<Tensor> scale_a,
+    IntArrayRef scale_recipe_a,
+    IntArrayRef swizzle_a,
+    ArrayRef<Tensor> scale_b,
+    IntArrayRef scale_recipe_b,
+    IntArrayRef swizzle_b,
+    const std::optional<Tensor>& offs,
+    const std::optional<Tensor>& bias,
+    const std::optional<c10::ScalarType> out_dtype,
+    IntArrayRef contraction_dim,
+    bool use_fast_accum) {
+  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
+  TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
+
+  TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
+  TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
+  TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
+  TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
+  const bool a_is_2d = mat_a.dim() == 2;
+  const bool b_is_2d = mat_b.dim() == 2;
+
+  // NOTE(slayton): For sub-1B formats want contraction_dim argument?
+  if (!a_is_2d || !b_is_2d) {
+    if (contraction_dim.size() > 0) {
+      const int dim_a = contraction_dim[0], dim_b = contraction_dim[1];
+      TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
+          "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
+          mat_b.size(dim_b));
+      // Note: only (-1, -2) is currently supported
+      TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Currently contraction dims must be (-1, -2) only");
+    } else {
+      TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
+    }
+  }
+  TORCH_CHECK_VALUE(
+      mat_a.size(-1) % 16 == 0,
+      "Expected trailing dimension of mat_a to be divisible by 16 ",
+      "but got mat1 shape: (",
+      mat_a.sizes(),
+      ").");
+  TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
+      "Expected mat_b shape to be divisible by 16 ",
+      "but got mat_b shape: (",
+      mat_b.sizes(),
+      ").");
+
+  TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
+  TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
+
+  // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
+  //       for rowwise, no offsets implies 3d-3d and is handled by lower-level
+  //       routines
+  if (offs.has_value()) {
+    TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
+    TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
+  }
+
+  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
+  TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
+
+  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
+
+  // Conversion of implicitly-defined enums to explicit
+  auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
+  auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
+  auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
+  auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
+
+  // at this point we can start working out what we want to be doing
+  // Try to do as few steps as possible.
+  // NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
+  // Do this via a list of defined (name, acceptance, concrete_impl) tuples.
+  ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
+  for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
+    const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
+    bool ok = accept_fn(mat_a.scalar_type(),
+                        scale_recipe_a_enum,
+                        scale_a,
+                        mat_b.scalar_type(),
+                        scale_recipe_b_enum,
+                        scale_b);
+    if (ok) {
+      gemm_impl = scaled_gemm_impl;
+      break;
+    }
+  }
+  TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
+      "No gemm implementation was found");
+
+  switch (gemm_impl) {
+    case ScaledGemmImplementation::ROWWISE_ROWWISE: {
+      const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
+      _check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */, scale_multiplier);
+      _check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */, scale_multiplier);
+      return _f8_f8_bf16_rowwise_grouped_mm(
+          mat_a,
+          mat_b,
+          scale_a[0],
+          scale_b[0],
+          offs,
+          bias,
+          use_fast_accum,
+          out);
+    }
+    case ScaledGemmImplementation::MXFP8_MXFP8: {
+      _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
+      _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
+      return _mx8_mx8_bf16_grouped_mm_fbgemm(
+          mat_a,
+          mat_b,
+          scale_a[0],
+          swizzle_a_enum[0],
+          scale_b[0],
+          swizzle_b_enum[0],
+          offs.value(),
+          out);
+    }
+    default:
+      TORCH_CHECK_NOT_IMPLEMENTED(false,
+          "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
+  }
+}
+
 Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
     const std::optional<at::Tensor>& offs,
     const std::optional<at::Tensor>& bias,


@@ -7183,6 +7183,12 @@
     CUDA: _scaled_grouped_mm_cuda
   tags: needs_exact_strides

+- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _scaled_grouped_mm_cuda_v2
+  tags: needs_exact_strides
+
 - func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
   variants: function
   dispatch:

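The `int[]` recipe and swizzle arguments in this schema carry the enum values; `torch.nn.functional.scaled_grouped_mm` (added below) performs that conversion. A hedged sketch of an equivalent direct call to the private op, assuming the rowwise FP8 tensors from the example in the summary:

```
# Illustrative only: what the Python wrapper passes to the private v2 op for
# rowwise FP8 (a, b, scale_a, scale_b, offs prepared as in the summary example).
import torch
from torch.nn.functional import ScalingType

out = torch._scaled_grouped_mm_v2(
    a, b.t(),
    [scale_a], [ScalingType.RowWise.value], [],  # scales, recipes, swizzles for mat_a
    [scale_b], [ScalingType.RowWise.value], [],  # scales, recipes, swizzles for mat_b
    offs,             # Tensor? offs
    None,             # Tensor? bias (not supported yet)
    torch.bfloat16,   # ScalarType? out_dtype
    [],               # int[] contraction_dim
    False,            # bool use_fast_accum
)
```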

@@ -228,3 +228,4 @@ Low-Precision functions
     ScalingType
     SwizzleType
     scaled_mm
+    scaled_grouped_mm


@@ -524,6 +524,7 @@ aten::_scaled_dot_product_flash_attention_for_cpu_backward
 aten::_scaled_dot_product_fused_attention_overrideable
 aten::_scaled_dot_product_fused_attention_overrideable_backward
 aten::_scaled_grouped_mm
+aten::_scaled_grouped_mm_v2
 aten::_scaled_mm
 aten::_scaled_mm.out
 aten::_scaled_mm_v2


@@ -11,7 +11,7 @@ from typing import Optional

 import torch
-from torch.nn.functional import scaled_mm, ScalingType, SwizzleType
+from torch.nn.functional import scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
 from torch.testing._internal.common_cuda import (
     IS_SM90,
     _get_torch_cuda_version,
@@ -215,6 +215,49 @@ def scaled_mm_wrap(
     )
     return out

+
+def scaled_grouped_mm_wrap(
+    a,
+    b,
+    scale_a,
+    scale_b,
+    scale_recipe_a,
+    scale_recipe_b,
+    swizzle_a=SwizzleType.NO_SWIZZLE,
+    swizzle_b=SwizzleType.NO_SWIZZLE,
+    scale_result=None,
+    out_dtype=torch.bfloat16,
+    use_fast_accum=False,
+    offs=None,
+    bias=None,
+    wrap_v2=True,
+):
+    if not wrap_v2:
+        return torch._scaled_grouped_mm(
+            a,
+            b,
+            scale_a,
+            scale_b,
+            out_dtype=out_dtype,
+            bias=bias,
+            offs=offs,
+            use_fast_accum=use_fast_accum)
+    else:
+        return scaled_grouped_mm(
+            a,
+            b,
+            scale_a,
+            scale_recipe_a,
+            scale_b,
+            scale_recipe_b,
+            swizzle_a=swizzle_a,
+            swizzle_b=swizzle_b,
+            offs=offs,
+            bias=bias,
+            output_dtype=out_dtype,
+            use_fast_accum=use_fast_accum)
+
+
 def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
     # naive implementation: dq -> op -> q
     x_fp32 = x.to(torch.float) / x_scale
@@ -444,7 +487,8 @@ class TestFP8Matmul(TestCase):
     @parametrize("M", [2048, 2049])
     @parametrize("N", [8192])
     @parametrize("K", [16640])
-    def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K):
+    @parametrize("wrap_v2", [True, False])
+    def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2):
         torch.manual_seed(42)
         total_K = K  # Alias for clarity, communicating this consists of several groups along this dim
         input_group_end_offsets = generate_jagged_offs(
@@ -510,13 +554,18 @@ class TestFP8Matmul(TestCase):
         w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)

         # Compute mxfp8 grouped mm output
-        y_mxfp8 = torch._scaled_grouped_mm(
+        y_mxfp8 = scaled_grouped_mm_wrap(
             xq,  # (M, total_K)
             wq.transpose(-2, -1),  # (total_K, N)
             x_blocked_scales,  # to_blocked_per_group(M, total_K//32)
             w_blocked_scales,  # to_blocked_per_group(N, total_K//32)
+            scale_recipe_a=ScalingType.BlockWise1x32,
+            scale_recipe_b=ScalingType.BlockWise1x32,
+            swizzle_a=SwizzleType.SWIZZLE_32_4_4,
+            swizzle_b=SwizzleType.SWIZZLE_32_4_4,
             offs=input_group_end_offsets,  # (G,)
             out_dtype=torch.bfloat16,
+            wrap_v2=wrap_v2,
         )

         # bf16 reference output
@@ -535,7 +584,8 @@ class TestFP8Matmul(TestCase):
     @parametrize("M", [16640])
     @parametrize("N", [8192])
     @parametrize("K", [4096])
-    def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K):
+    @parametrize("wrap_v2", [True, False])
+    def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2):
         torch.manual_seed(42)
         # Simulate 2d-3d grouped gemm `out = input @ weight.t()`
         # 2D inputs with groups along M, 3D weights.
@@ -579,14 +629,19 @@ class TestFP8Matmul(TestCase):
         xq = xq.view(-1, xq.shape[-1])

         # Compute mxfp8 grouped gemm.
-        y_mxfp8 = torch._scaled_grouped_mm(
+        y_mxfp8 = scaled_grouped_mm_wrap(
             xq,
             wq.transpose(-2, -1),
             x_scale,
             w_scale,
             offs=input_group_end_offsets,
             out_dtype=torch.bfloat16,
-        )
+            scale_recipe_a=ScalingType.BlockWise1x32,
+            scale_recipe_b=ScalingType.BlockWise1x32,
+            swizzle_a=SwizzleType.SWIZZLE_32_4_4,
+            swizzle_b=SwizzleType.SWIZZLE_32_4_4,
+            wrap_v2=wrap_v2)

         # Compute reference bf16 grouped gemm.
         y_bf16 = torch._grouped_mm(
@@ -1536,7 +1591,8 @@ class TestFP8Matmul(TestCase):
     @parametrize("fast_accum", [False, True])
     # AMD does not support non-contiguous inputs yet
     @parametrize("strided", [False] + ([True] if torch.version.cuda else []))
-    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided):
+    @parametrize("wrap_v2", [True, False])
+    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2):
         device = "cuda"
         fp8_dtype = e4m3_type
         m, n, k, n_groups = 16, 32, 64, 4
@@ -1545,9 +1601,16 @@ class TestFP8Matmul(TestCase):
         scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
         scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
         offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
-        f = torch._scaled_grouped_mm
-        out = f(a, b.t(), scale_a, scale_b, offs=offs,
-                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = scaled_grouped_mm_wrap
+        out = f(a, b.t(),
+                scale_a,
+                scale_b,
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
+                offs=offs,
+                out_dtype=torch.bfloat16,
+                use_fast_accum=fast_accum,
+                wrap_v2=wrap_v2)

         offs_cpu = offs.cpu()
         alist, blist, ascalelist, bscalelist = [], [], [], []
         start = 0
@@ -1564,7 +1627,8 @@ class TestFP8Matmul(TestCase):
     @parametrize("fast_accum", [False, True])
     # AMD does not support non-contiguous inputs yet
     @parametrize("strided", [False] + ([True] if torch.version.cuda else []))
-    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
+    @parametrize("wrap_v2", [True, False])
+    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, wrap_v2):
         device = "cuda"
         fp8_dtype = e4m3_type
         m, n, k, n_groups = 16, 32, 64, 4
@@ -1582,9 +1646,16 @@ class TestFP8Matmul(TestCase):
         offs[0] = offs[1]
         scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32)
         scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
-        f = torch._scaled_grouped_mm
-        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = scaled_grouped_mm_wrap
+        out = f(a, b.transpose(-2, -1),
+                scale_a,
+                scale_b,
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
+                offs=offs,
+                out_dtype=torch.bfloat16,
+                use_fast_accum=fast_accum,
+                wrap_v2=wrap_v2)

         offs_cpu = offs.cpu()
         alist, ascalelist, outlist = [], [], []

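The `scaled_grouped_mm_wrap` helper above treats the legacy op and the new functional API as interchangeable for the rowwise FP8 cases. A sketch of the equivalence it exercises, assuming tensors prepared as in the summary example:

```
# Illustrative: the two call forms scaled_grouped_mm_wrap switches between
# (wrap_v2=False vs. True) for rowwise FP8; a, b, scale_a, scale_b, offs as
# in the summary example. Results are intended to match.
import torch
from torch.nn.functional import scaled_grouped_mm, ScalingType

y_v1 = torch._scaled_grouped_mm(
    a, b.t(), scale_a, scale_b,
    offs=offs, out_dtype=torch.bfloat16, use_fast_accum=False,
)
y_v2 = scaled_grouped_mm(
    a, b.t(),
    scale_a, ScalingType.RowWise,
    scale_b, ScalingType.RowWise,
    offs=offs, output_dtype=torch.bfloat16, use_fast_accum=False,
)
torch.testing.assert_close(y_v1, y_v2)
```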

@@ -6711,3 +6711,90 @@ def scaled_mm(
     )
     return out

+
+def scaled_grouped_mm(
+    mat_a: Tensor,
+    mat_b: Tensor,
+    scale_a: Tensor | list[Tensor],
+    scale_recipe_a: ScalingType | list[ScalingType],
+    scale_b: Tensor | list[Tensor],
+    scale_recipe_b: ScalingType | list[ScalingType],
+    swizzle_a: SwizzleType | list[SwizzleType] | None = None,
+    swizzle_b: SwizzleType | list[SwizzleType] | None = None,
+    bias: Optional[Tensor] = None,
+    offs: Optional[Tensor] = None,
+    output_dtype: Optional[torch.dtype] = torch.bfloat16,
+    contraction_dim: list[int] | tuple[int] = (),
+    use_fast_accum: bool = False,
+) -> Tensor:
+    r"""
+    scaled_grouped_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a, swizzle_b, bias, offs,
+                      output_dtype, use_fast_accum)
+
+    Applies a grouped scaled matrix multiply, grouped_mm(mat_a, mat_b), where the scaling of mat_a and mat_b is described
+    by scale_recipe_a and scale_recipe_b respectively.
+
+    Args:
+        scale_a: Tensor containing decoding scaling factors for mat_a
+        scale_recipe_a: Enum describing how mat_a has been scaled
+        scale_b: Tensor containing decoding scaling factors for mat_b
+        scale_recipe_b: Enum describing how mat_b has been scaled
+        swizzle_a: Enum describing the swizzling pattern (if any) of scale_a
+        swizzle_b: Enum describing the swizzling pattern (if any) of scale_b
+        bias: optional bias term to be added to the output
+        offs: optional offsets into the source tensors denoting group start indices
+        output_dtype: dtype used for the output tensor
+        contraction_dim: describes which dimensions are :math:`K` in the matmul
+        use_fast_accum: enable/disable tensor-core fast accumulation (Hopper GPUs only)
+    """
+
+    def expand_single_value(v: _Any | list[_Any] | None) -> list[_Any]:
+        if v is None:
+            return []
+        elif not isinstance(v, list):
+            return [
+                v,
+            ]
+        else:
+            return v
+
+    scale_a = expand_single_value(scale_a)
+    scale_recipe_a = expand_single_value(scale_recipe_a)
+    scale_b = expand_single_value(scale_b)
+    scale_recipe_b = expand_single_value(scale_recipe_b)
+    swizzle_a = expand_single_value(swizzle_a)
+    swizzle_b = expand_single_value(swizzle_b)
+
+    # native_functions has restrictions on what can be defined
+    # & passed through - std::optional<ArrayRef<Tensor>> for instance
+    # *cannot* be passed, but an empty vector (list) can.
+    # So, we need to convert None arguments for lists in python
+    # explicitly into empty lists.
+    def list_or_empty(l: list[_Any] | None) -> list[_Any]:
+        return [] if not l else l
+
+    def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
+        if not isinstance(l, list):
+            l = [
+                l,
+            ]
+        return [li.value for li in l]
+
+    out = torch._scaled_grouped_mm_v2(
+        mat_a,
+        mat_b,
+        scale_a,
+        enum_list_as_int_list(scale_recipe_a),
+        enum_list_as_int_list(list_or_empty(swizzle_a)),
+        scale_b,
+        enum_list_as_int_list(scale_recipe_b),
+        enum_list_as_int_list(list_or_empty(swizzle_b)),
+        offs,
+        bias,
+        output_dtype,
+        contraction_dim,
+        use_fast_accum,
+    )
+
+    return out


@@ -249,6 +249,7 @@ def get_ignored_functions() -> set[Callable]:
         torch.nn.functional.has_torch_function_unary,
         torch.nn.functional.has_torch_function_variadic,
         torch.nn.functional.handle_torch_function,
+        torch.nn.functional.scaled_grouped_mm,
         torch.nn.functional.scaled_mm,
         torch.nn.functional.sigmoid,
         torch.nn.functional.hardsigmoid,