mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add scaled_grouped_mm_v2 and python API (#165154)
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165154 Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
This commit is contained in:
committed by
PyTorch MergeBot
parent
b509fb9b5d
commit
7c6c5d04fe
@ -2578,7 +2578,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
@ -2589,6 +2591,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
|
||||
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
|
||||
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
|
||||
#else
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
|
||||
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
|
||||
#endif
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
|
||||
fbgemm_gpu::mx8mx8bf16_grouped_mm(
|
||||
@ -2673,6 +2685,9 @@ _f8_f8_bf16_rowwise_grouped_mm(
|
||||
const std::optional<Tensor>& bias,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"For grouped FP8 rowwise, both scales must be float32 tensors");
|
||||
#ifndef USE_ROCM
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
mat_a,
|
||||
@ -2772,11 +2787,15 @@ _scaled_grouped_mm_cuda(
|
||||
#endif
|
||||
|
||||
if (is_mx8mx8bf16) {
|
||||
// Note: Passing implied SwizzleType here, correctness of scale previously checked
|
||||
// in `check_scale` call
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
scale_b,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
@ -2793,6 +2812,140 @@ _scaled_grouped_mm_cuda(
|
||||
out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda_v2(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
if (contraction_dim.size() > 0) {
|
||||
const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]);
|
||||
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
|
||||
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
|
||||
mat_b.size(dim_b));
|
||||
// Note: only (-1, -2) is currently supported
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
|
||||
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
|
||||
"No gemm implementation was found");
|
||||
|
||||
switch (gemm_impl) {
|
||||
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
|
||||
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
swizzle_a_enum[0],
|
||||
scale_b[0],
|
||||
swizzle_b_enum[0],
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
|
@ -7183,6 +7183,12 @@
|
||||
CUDA: _scaled_grouped_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_grouped_mm_cuda_v2
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
|
@ -228,3 +228,4 @@ Low-Precision functions
|
||||
ScalingType
|
||||
SwizzleType
|
||||
scaled_mm
|
||||
scaled_grouped_mm
|
||||
|
@ -524,6 +524,7 @@ aten::_scaled_dot_product_flash_attention_for_cpu_backward
|
||||
aten::_scaled_dot_product_fused_attention_overrideable
|
||||
aten::_scaled_dot_product_fused_attention_overrideable_backward
|
||||
aten::_scaled_grouped_mm
|
||||
aten::_scaled_grouped_mm_v2
|
||||
aten::_scaled_mm
|
||||
aten::_scaled_mm.out
|
||||
aten::_scaled_mm_v2
|
||||
|
@ -11,7 +11,7 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
from torch.nn.functional import scaled_mm, ScalingType, SwizzleType
|
||||
from torch.nn.functional import scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
|
||||
from torch.testing._internal.common_cuda import (
|
||||
IS_SM90,
|
||||
_get_torch_cuda_version,
|
||||
@ -215,6 +215,49 @@ def scaled_mm_wrap(
|
||||
)
|
||||
return out
|
||||
|
||||
def scaled_grouped_mm_wrap(
|
||||
a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
scale_recipe_a,
|
||||
scale_recipe_b,
|
||||
swizzle_a=SwizzleType.NO_SWIZZLE,
|
||||
swizzle_b=SwizzleType.NO_SWIZZLE,
|
||||
scale_result=None,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=False,
|
||||
offs=None,
|
||||
bias=None,
|
||||
wrap_v2=True,
|
||||
):
|
||||
if not wrap_v2:
|
||||
return torch._scaled_grouped_mm(
|
||||
a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_b,
|
||||
out_dtype=out_dtype,
|
||||
bias=bias,
|
||||
offs=offs,
|
||||
use_fast_accum=use_fast_accum)
|
||||
else:
|
||||
return scaled_grouped_mm(
|
||||
a,
|
||||
b,
|
||||
scale_a,
|
||||
scale_recipe_a,
|
||||
scale_b,
|
||||
scale_recipe_b,
|
||||
swizzle_a=swizzle_a,
|
||||
swizzle_b=swizzle_b,
|
||||
offs=offs,
|
||||
bias=bias,
|
||||
output_dtype=out_dtype,
|
||||
use_fast_accum=use_fast_accum)
|
||||
|
||||
|
||||
|
||||
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
|
||||
# naive implementation: dq -> op -> q
|
||||
x_fp32 = x.to(torch.float) / x_scale
|
||||
@ -444,7 +487,8 @@ class TestFP8Matmul(TestCase):
|
||||
@parametrize("M", [2048, 2049])
|
||||
@parametrize("N", [8192])
|
||||
@parametrize("K", [16640])
|
||||
def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K):
|
||||
@parametrize("wrap_v2", [True, False])
|
||||
def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2):
|
||||
torch.manual_seed(42)
|
||||
total_K = K # Alias for clarity, communicating this consists of several groups along this dim
|
||||
input_group_end_offsets = generate_jagged_offs(
|
||||
@ -510,13 +554,18 @@ class TestFP8Matmul(TestCase):
|
||||
w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)
|
||||
|
||||
# Compute mxfp8 grouped mm output
|
||||
y_mxfp8 = torch._scaled_grouped_mm(
|
||||
y_mxfp8 = scaled_grouped_mm_wrap(
|
||||
xq, # (M, total_K)
|
||||
wq.transpose(-2, -1), # (total_K, N)
|
||||
x_blocked_scales, # to_blocked_per_group(M, total_K//32)
|
||||
w_blocked_scales, # to_blocked_per_group(N, total_K//32)
|
||||
scale_recipe_a=ScalingType.BlockWise1x32,
|
||||
scale_recipe_b=ScalingType.BlockWise1x32,
|
||||
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
|
||||
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
|
||||
offs=input_group_end_offsets, # (G,)
|
||||
out_dtype=torch.bfloat16,
|
||||
wrap_v2=wrap_v2
|
||||
)
|
||||
|
||||
# bf16 reference output
|
||||
@ -535,7 +584,8 @@ class TestFP8Matmul(TestCase):
|
||||
@parametrize("M", [16640])
|
||||
@parametrize("N", [8192])
|
||||
@parametrize("K", [4096])
|
||||
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K):
|
||||
@parametrize("wrap_v2", [True, False])
|
||||
def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2):
|
||||
torch.manual_seed(42)
|
||||
# Simulate 2d-3d grouped gemm `out = input @ weight.t()`
|
||||
# 2D inputs with groups along M, 3D weights.
|
||||
@ -579,14 +629,19 @@ class TestFP8Matmul(TestCase):
|
||||
xq = xq.view(-1, xq.shape[-1])
|
||||
|
||||
# Compute mxfp8 grouped gemm.
|
||||
y_mxfp8 = torch._scaled_grouped_mm(
|
||||
y_mxfp8 = scaled_grouped_mm_wrap(
|
||||
xq,
|
||||
wq.transpose(-2, -1),
|
||||
x_scale,
|
||||
w_scale,
|
||||
offs=input_group_end_offsets,
|
||||
out_dtype=torch.bfloat16,
|
||||
)
|
||||
scale_recipe_a=ScalingType.BlockWise1x32,
|
||||
scale_recipe_b=ScalingType.BlockWise1x32,
|
||||
swizzle_a=SwizzleType.SWIZZLE_32_4_4,
|
||||
swizzle_b=SwizzleType.SWIZZLE_32_4_4,
|
||||
wrap_v2=wrap_v2)
|
||||
|
||||
|
||||
# Compute reference bf16 grouped gemm.
|
||||
y_bf16 = torch._grouped_mm(
|
||||
@ -1536,7 +1591,8 @@ class TestFP8Matmul(TestCase):
|
||||
@parametrize("fast_accum", [False, True])
|
||||
# AMD does not support non-contiguous inputs yet
|
||||
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
|
||||
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided):
|
||||
@parametrize("wrap_v2", [True, False])
|
||||
def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2):
|
||||
device = "cuda"
|
||||
fp8_dtype = e4m3_type
|
||||
m, n, k, n_groups = 16, 32, 64, 4
|
||||
@ -1545,9 +1601,16 @@ class TestFP8Matmul(TestCase):
|
||||
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
|
||||
scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
|
||||
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
|
||||
f = torch._scaled_grouped_mm
|
||||
out = f(a, b.t(), scale_a, scale_b, offs=offs,
|
||||
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
|
||||
f = scaled_grouped_mm_wrap
|
||||
out = f(a, b.t(),
|
||||
scale_a,
|
||||
scale_b,
|
||||
scale_recipe_a=ScalingType.RowWise,
|
||||
scale_recipe_b=ScalingType.RowWise,
|
||||
offs=offs,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=fast_accum,
|
||||
wrap_v2=wrap_v2)
|
||||
offs_cpu = offs.cpu()
|
||||
alist, blist, ascalelist, bscalelist = [], [], [], []
|
||||
start = 0
|
||||
@ -1564,7 +1627,8 @@ class TestFP8Matmul(TestCase):
|
||||
@parametrize("fast_accum", [False, True])
|
||||
# AMD does not support non-contiguous inputs yet
|
||||
@parametrize("strided", [False] + ([True] if torch.version.cuda else []))
|
||||
def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
|
||||
@parametrize("wrap_v2", [True, False])
|
||||
def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, wrap_v2):
|
||||
device = "cuda"
|
||||
fp8_dtype = e4m3_type
|
||||
m, n, k, n_groups = 16, 32, 64, 4
|
||||
@ -1582,9 +1646,16 @@ class TestFP8Matmul(TestCase):
|
||||
offs[0] = offs[1]
|
||||
scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
|
||||
f = torch._scaled_grouped_mm
|
||||
out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
|
||||
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
|
||||
f = scaled_grouped_mm_wrap
|
||||
out = f(a, b.transpose(-2, -1),
|
||||
scale_a,
|
||||
scale_b,
|
||||
scale_recipe_a=ScalingType.RowWise,
|
||||
scale_recipe_b=ScalingType.RowWise,
|
||||
offs=offs,
|
||||
out_dtype=torch.bfloat16,
|
||||
use_fast_accum=fast_accum,
|
||||
wrap_v2=wrap_v2)
|
||||
|
||||
offs_cpu = offs.cpu()
|
||||
alist, ascalelist, outlist = [], [], []
|
||||
|
@ -6711,3 +6711,90 @@ def scaled_mm(
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def scaled_grouped_mm(
|
||||
mat_a: Tensor,
|
||||
mat_b: Tensor,
|
||||
scale_a: Tensor | list[Tensor],
|
||||
scale_recipe_a: ScalingType | list[ScalingType],
|
||||
scale_b: Tensor | list[Tensor],
|
||||
scale_recipe_b: ScalingType | list[ScalingType],
|
||||
swizzle_a: SwizzleType | list[SwizzleType] | None = None,
|
||||
swizzle_b: SwizzleType | list[SwizzleType] | None = None,
|
||||
bias: Optional[Tensor] = None,
|
||||
offs: Optional[Tensor] = None,
|
||||
output_dtype: Optional[torch.dtype] = torch.bfloat16,
|
||||
contraction_dim: list[int] | tuple[int] = (),
|
||||
use_fast_accum: bool = False,
|
||||
) -> Tensor:
|
||||
r"""
|
||||
scaled_grouped_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a, swizzle_b, bias, offs,
|
||||
output_dtype, use_fast_accum)
|
||||
|
||||
Applies a grouped scaled matrix-multiply, grouped_mm(mat_a, mat_b) where the scaling of mat_a and mat_b are described by
|
||||
scale_recipe_a and scale_recipe_b respectively.
|
||||
|
||||
Args:
|
||||
scale_a: Tensor containing decoding scaling factors for mat_a
|
||||
scale_recipe_a: Enum describing how mat_a has been scaled
|
||||
scale_b: Tensor containing decoding scaling factors for mat_b
|
||||
scale_recipe_b: Enum describing how mat_b has been scaled
|
||||
swizzle_a: Enum describing the swizzling pattern (if any) of scale_a
|
||||
swizzle_b: Enum describing the swizzling pattern (if any) of scale_b
|
||||
bias: optional bias term to be added to the output
|
||||
offs: optional offsets into the source tensors denoting group start indices
|
||||
output_dtype: dtype used for the output tensor
|
||||
contraction_dim: describe which dimensions are :math:`K` in the matmul.
|
||||
use_fast_accum: enable/disable tensor-core fast accumulation (Hopper-GPUs only)
|
||||
"""
|
||||
|
||||
def expand_single_value(v: _Any | list[_Any] | None) -> list[_Any]:
|
||||
if v is None:
|
||||
return []
|
||||
elif not isinstance(v, (list)):
|
||||
return [
|
||||
v,
|
||||
]
|
||||
else:
|
||||
return v
|
||||
|
||||
scale_a = expand_single_value(scale_a)
|
||||
scale_recipe_a = expand_single_value(scale_recipe_a)
|
||||
scale_b = expand_single_value(scale_b)
|
||||
scale_recipe_b = expand_single_value(scale_recipe_b)
|
||||
swizzle_a = expand_single_value(swizzle_a)
|
||||
swizzle_b = expand_single_value(swizzle_b)
|
||||
|
||||
# native_functions has restrictions on what can be defined
|
||||
# & passed through - std::optional<ArrayRef<Tensor>> for instance
|
||||
# *cannot* be passed, but an empty vector (list) can.
|
||||
# So, we need to convert None arguments for lists in python
|
||||
# explicitly into empty lists.
|
||||
def list_or_empty(l: list[_Any] | None) -> list[_Any]:
|
||||
return [] if not l else l
|
||||
|
||||
def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
|
||||
if not isinstance(l, list):
|
||||
l = [
|
||||
l,
|
||||
]
|
||||
return [li.value for li in l]
|
||||
|
||||
out = torch._scaled_grouped_mm_v2(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
enum_list_as_int_list(scale_recipe_a),
|
||||
enum_list_as_int_list(list_or_empty(swizzle_a)),
|
||||
scale_b,
|
||||
enum_list_as_int_list(scale_recipe_b),
|
||||
enum_list_as_int_list(list_or_empty(swizzle_b)),
|
||||
offs,
|
||||
bias,
|
||||
output_dtype,
|
||||
contraction_dim,
|
||||
use_fast_accum,
|
||||
)
|
||||
|
||||
return out
|
||||
|
@ -249,6 +249,7 @@ def get_ignored_functions() -> set[Callable]:
|
||||
torch.nn.functional.has_torch_function_unary,
|
||||
torch.nn.functional.has_torch_function_variadic,
|
||||
torch.nn.functional.handle_torch_function,
|
||||
torch.nn.functional.scaled_grouped_mm,
|
||||
torch.nn.functional.scaled_mm,
|
||||
torch.nn.functional.sigmoid,
|
||||
torch.nn.functional.hardsigmoid,
|
||||
|
Reference in New Issue
Block a user