Add scaled_grouped_mm_v2 and python API (#165154)
Summary:

* Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats
* Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint
* Test both original and v2 functionality

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:

Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165154
Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
Commit 7c6c5d04fe (parent b509fb9b5d), committed by PyTorch MergeBot.
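Before the diff itself, a quick orientation example (not part of the commit): a minimal sketch of the new public entrypoint, using the rowwise-FP8 shapes from the updated `test_scaled_grouped_gemm_2d_2d` test further down. It assumes a CUDA device that satisfies the compute-capability check added in `_scaled_grouped_mm_cuda_v2`; the tensor values are arbitrary placeholders.

```python
import torch
from torch.nn.functional import scaled_grouped_mm, ScalingType

device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4
# fp8 operands; mat_b is passed transposed, as the kernel expects
a = torch.randn(m, k * n_groups, device=device).to(torch.float8_e4m3fn)
b = torch.randn(n, k * n_groups, device=device).to(torch.float8_e4m3fn)
# one fp32 decoding scale per row, per group (rowwise recipe)
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
# int32 group-end offsets along the shared K dimension
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)

out = scaled_grouped_mm(
    a, b.t(),
    scale_a, ScalingType.RowWise,
    scale_b, ScalingType.RowWise,
    offs=offs,
    output_dtype=torch.bfloat16,
)
```

Internally this routes through the new `torch._scaled_grouped_mm_v2` op added in this commit.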
@@ -2578,7 +2578,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
     const Tensor& mat_a,
     const Tensor& mat_b,
     const Tensor& scale_a,
+    const SwizzleType& swizzle_a,
     const Tensor& scale_b,
+    const SwizzleType& swizzle_b,
     const std::optional<at::Tensor>& offs,
     Tensor& out) {
   const bool a_is_2d = mat_a.dim() == 2;
@@ -2589,6 +2591,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
   TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
   TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
   TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
+  // MXFP8 expects float8_e8m0fnu scales.
+  TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
+      "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
+#ifdef USE_ROCM
+  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
+      "For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
+#else
+  TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
+      "For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
+#endif

 #if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
   fbgemm_gpu::mx8mx8bf16_grouped_mm(
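The checks above pin down the caller-side contract for the MXFP8 path. As a hedged sketch (not code from this commit), this is what a conforming call looks like through the new Python API on CUDA; producing real MXFP8 data and float8_e8m0fnu block scales requires a quantization/swizzling step that is not shown here, so the helper simply assumes its inputs are already in the layout the kernel expects.

```python
import torch
from torch.nn.functional import scaled_grouped_mm, ScalingType, SwizzleType


def mxfp8_grouped_mm(xq, wq, x_scales, w_scales, offs):
    """Call the new API the way the CUDA MXFP8 checks above require.

    xq, wq: float8_e4m3fn operands; x_scales/w_scales: float8_e8m0fnu block
    scales already laid out per the 32x4x4 swizzle; offs: int32 group offsets.
    """
    assert x_scales.dtype == torch.float8_e8m0fnu
    assert w_scales.dtype == torch.float8_e8m0fnu
    return scaled_grouped_mm(
        xq, wq.transpose(-2, -1),
        x_scales, ScalingType.BlockWise1x32,   # one scale per 1x32 block
        w_scales, ScalingType.BlockWise1x32,
        swizzle_a=SwizzleType.SWIZZLE_32_4_4,  # required on CUDA by the check above
        swizzle_b=SwizzleType.SWIZZLE_32_4_4,  # ROCm requires NO_SWIZZLE instead
        offs=offs,                             # offsets are mandatory for MXFP8
        output_dtype=torch.bfloat16,           # only bf16 output is accepted
    )
```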
@@ -2673,6 +2685,9 @@ _f8_f8_bf16_rowwise_grouped_mm(
     const std::optional<Tensor>& bias,
     bool use_fast_accum,
     Tensor& out) {
+  // FP8 per-tensor and per-row scaling expect fp32 scales.
+  TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
+      "For grouped FP8 rowwise, both scales must be float32 tensors");
 #ifndef USE_ROCM
   return _f8_f8_bf16_rowwise_grouped_mm_cuda(
       mat_a,
@@ -2772,11 +2787,15 @@ _scaled_grouped_mm_cuda(
 #endif

   if (is_mx8mx8bf16) {
+    // Note: Passing implied SwizzleType here, correctness of scale previously checked
+    // in `check_scale` call
     return _mx8_mx8_bf16_grouped_mm_fbgemm(
         mat_a,
         mat_b,
         scale_a,
+        SwizzleType::SWIZZLE_32_4_4,
         scale_b,
+        SwizzleType::SWIZZLE_32_4_4,
         offs.value(),
         out);
   }
@@ -2793,6 +2812,140 @@ _scaled_grouped_mm_cuda(
       out);
 }

+namespace {
+
+std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
+  { "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
+  { "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
+
+} // anonymous namespace
+
+Tensor
+_scaled_grouped_mm_cuda_v2(
+    const Tensor& mat_a, const Tensor& mat_b,
+    ArrayRef<Tensor> scale_a,
+    IntArrayRef scale_recipe_a,
+    IntArrayRef swizzle_a,
+    ArrayRef<Tensor> scale_b,
+    IntArrayRef scale_recipe_b,
+    IntArrayRef swizzle_b,
+    const std::optional<Tensor>& offs,
+    const std::optional<Tensor>& bias,
+    const std::optional<c10::ScalarType> out_dtype,
+    IntArrayRef contraction_dim,
+    bool use_fast_accum) {
+  bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
+  TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
+
+  TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
+  TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
+  TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
+  TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
+  const bool a_is_2d = mat_a.dim() == 2;
+  const bool b_is_2d = mat_b.dim() == 2;
+
+  // NOTE(slayton): For sub-1B formats want contraction_dim argument?
+  if (!a_is_2d || !b_is_2d) {
+    if (contraction_dim.size() > 0) {
+      const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]);
+      TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
+          "Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
+          mat_b.size(dim_b));
+      // Note: only (-1, -2) is currently supported
+      TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
+    } else {
+      TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
+    }
+  }
+  TORCH_CHECK_VALUE(
+      mat_a.size(-1) % 16 == 0,
+      "Expected trailing dimension of mat_a to be divisible by 16 ",
+      "but got mat1 shape: (",
+      mat_a.sizes(),
+      ").");
+  TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
+      "Expected mat_b shape to be divisible by 16 ",
+      "but got mat_b shape: (",
+      mat_b.sizes(),
+      ").");
+
+  TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
+  TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
+
+  // NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
+  //       for rowwise, no offsets implies 3d-3d and is handled by lower-level
+  //       routines
+  if (offs.has_value()) {
+    TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
+    TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
+  }
+
+  const auto out_dtype_ = out_dtype.value_or(kBFloat16);
+  TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
+
+  Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
+
+  // Conversion of implicitly-defined enums to explicit
+  auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
+  auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
+  auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
+  auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
+
+  // at this point we can start working out what we want to be doing
+  // Try to do as few steps as possible.
+  // NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
+  // Do this via a list of defined (name, acceptance, concrete_impl) tuples.
+  ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
+  for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
+    const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
+    bool ok = accept_fn(mat_a.scalar_type(),
+                        scale_recipe_a_enum,
+                        scale_a,
+                        mat_b.scalar_type(),
+                        scale_recipe_b_enum,
+                        scale_b);
+    if (ok) {
+      gemm_impl = scaled_gemm_impl;
+      break;
+    }
+  }
+  TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
+      "No gemm implementation was found");
+
+  switch (gemm_impl) {
+    case ScaledGemmImplementation::ROWWISE_ROWWISE: {
+      const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
+      _check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
+      _check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
+      return _f8_f8_bf16_rowwise_grouped_mm(
+          mat_a,
+          mat_b,
+          scale_a[0],
+          scale_b[0],
+          offs,
+          bias,
+          use_fast_accum,
+          out);
+    }
+    case ScaledGemmImplementation::MXFP8_MXFP8: {
+      _check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
+      _check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
+      return _mx8_mx8_bf16_grouped_mm_fbgemm(
+          mat_a,
+          mat_b,
+          scale_a[0],
+          swizzle_a_enum[0],
+          scale_b[0],
+          swizzle_b_enum[0],
+          offs.value(),
+          out);
+    }
+    default:
+      TORCH_CHECK_NOT_IMPLEMENTED(false,
+          "_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
+  }
+}
+
 Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
                         const std::optional<at::Tensor>& offs,
                         const std::optional<at::Tensor>& bias,
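The v2 entrypoint above selects an implementation by walking a static table of (name, acceptance function, implementation) tuples and taking the first acceptance function that matches the inputs. A Python-flavored sketch of that selection pattern follows; the acceptance functions here are stand-ins, not the real `check_rowwise_recipe` / `check_mxfp8_recipe`.

```python
from typing import Callable, NamedTuple


class Entry(NamedTuple):
    name: str
    accept: Callable[[dict], bool]
    impl: str


# Illustrative analogue of scale_grouped_kernel_dispatch.
DISPATCH = [
    Entry("rowwise_rowwise", lambda args: args["recipe"] == "rowwise", "ROWWISE_ROWWISE"),
    Entry("mxfp8_mxfp8", lambda args: args["recipe"] == "mxfp8", "MXFP8_MXFP8"),
]


def select_impl(args: dict) -> str:
    # First acceptance function that matches wins; otherwise fail loudly,
    # mirroring TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE, ...).
    for name, accept, impl in DISPATCH:
        if accept(args):
            return impl
    raise ValueError("No gemm implementation was found")


print(select_impl({"recipe": "mxfp8"}))  # -> MXFP8_MXFP8
```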
@@ -7183,6 +7183,12 @@
     CUDA: _scaled_grouped_mm_cuda
   tags: needs_exact_strides

+- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
+  variants: function
+  dispatch:
+    CUDA: _scaled_grouped_mm_cuda_v2
+  tags: needs_exact_strides
+
 - func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
   variants: function
   dispatch:
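This schema is what `torch.nn.functional.scaled_grouped_mm` ultimately calls. Continuing the rowwise sketch from the top of the page (with `a`, `b`, `scale_a`, `scale_b`, `offs`, and `ScalingType` defined there), a direct call to the private op would look roughly like this; normally the public wrapper performs the enum-to-int and None-to-empty-list conversion for you.

```python
# Recipes and swizzles are passed as lists of ints, matching the int[] schema above;
# an empty swizzle list is what the wrapper sends when no swizzle is requested.
out = torch._scaled_grouped_mm_v2(
    a, b.t(),
    [scale_a], [ScalingType.RowWise.value], [],  # scales, recipes, swizzles for A
    [scale_b], [ScalingType.RowWise.value], [],  # scales, recipes, swizzles for B
    offs,            # Tensor? offs
    None,            # Tensor? bias (must be None for now)
    torch.bfloat16,  # ScalarType? out_dtype
    [],              # int[] contraction_dim
    False,           # bool use_fast_accum
)
```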
@@ -228,3 +228,4 @@ Low-Precision functions
     ScalingType
     SwizzleType
     scaled_mm
+    scaled_grouped_mm
@@ -524,6 +524,7 @@ aten::_scaled_dot_product_flash_attention_for_cpu_backward
 aten::_scaled_dot_product_fused_attention_overrideable
 aten::_scaled_dot_product_fused_attention_overrideable_backward
 aten::_scaled_grouped_mm
+aten::_scaled_grouped_mm_v2
 aten::_scaled_mm
 aten::_scaled_mm.out
 aten::_scaled_mm_v2
@@ -11,7 +11,7 @@ from typing import Optional

 import torch


-from torch.nn.functional import scaled_mm, ScalingType, SwizzleType
+from torch.nn.functional import scaled_mm, scaled_grouped_mm, ScalingType, SwizzleType
 from torch.testing._internal.common_cuda import (
     IS_SM90,
     _get_torch_cuda_version,
@@ -215,6 +215,49 @@ def scaled_mm_wrap(
     )
     return out


+def scaled_grouped_mm_wrap(
+    a,
+    b,
+    scale_a,
+    scale_b,
+    scale_recipe_a,
+    scale_recipe_b,
+    swizzle_a=SwizzleType.NO_SWIZZLE,
+    swizzle_b=SwizzleType.NO_SWIZZLE,
+    scale_result=None,
+    out_dtype=torch.bfloat16,
+    use_fast_accum=False,
+    offs=None,
+    bias=None,
+    wrap_v2=True,
+):
+    if not wrap_v2:
+        return torch._scaled_grouped_mm(
+            a,
+            b,
+            scale_a,
+            scale_b,
+            out_dtype=out_dtype,
+            bias=bias,
+            offs=offs,
+            use_fast_accum=use_fast_accum)
+    else:
+        return scaled_grouped_mm(
+            a,
+            b,
+            scale_a,
+            scale_recipe_a,
+            scale_b,
+            scale_recipe_b,
+            swizzle_a=swizzle_a,
+            swizzle_b=swizzle_b,
+            offs=offs,
+            bias=bias,
+            output_dtype=out_dtype,
+            use_fast_accum=use_fast_accum)
+
+
 def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
     # naive implementation: dq -> op -> q
     x_fp32 = x.to(torch.float) / x_scale
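The wrapper lets each grouped-GEMM test run against both entrypoints: `wrap_v2=False` forwards to the legacy `torch._scaled_grouped_mm` (the recipe/swizzle arguments are ignored on that path), while `wrap_v2=True` routes through the new public `scaled_grouped_mm`. A rough usage sketch, with rowwise inputs shaped as in `test_scaled_grouped_gemm_2d_2d` (names assumed from that test, not defined here):

```python
# Same inputs exercised both ways; rowwise FP8 recipe on both operands.
y_old = scaled_grouped_mm_wrap(
    a, b.t(), scale_a, scale_b,
    ScalingType.RowWise, ScalingType.RowWise,
    offs=offs, wrap_v2=False,   # legacy torch._scaled_grouped_mm
)
y_new = scaled_grouped_mm_wrap(
    a, b.t(), scale_a, scale_b,
    ScalingType.RowWise, ScalingType.RowWise,
    offs=offs, wrap_v2=True,    # new torch.nn.functional.scaled_grouped_mm
)
```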
@@ -444,7 +487,8 @@ class TestFP8Matmul(TestCase):
     @parametrize("M", [2048, 2049])
     @parametrize("N", [8192])
     @parametrize("K", [16640])
-    def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K):
+    @parametrize("wrap_v2", [True, False])
+    def test_mxfp8_scaled_grouped_mm_2d_2d(self, G, M, N, K, wrap_v2):
         torch.manual_seed(42)
         total_K = K  # Alias for clarity, communicating this consists of several groups along this dim
         input_group_end_offsets = generate_jagged_offs(
@@ -510,13 +554,18 @@ class TestFP8Matmul(TestCase):
         w_blocked_scales = w_blocked_scales.reshape(N_rounded, -1)

         # Compute mxfp8 grouped mm output
-        y_mxfp8 = torch._scaled_grouped_mm(
+        y_mxfp8 = scaled_grouped_mm_wrap(
             xq,  # (M, total_K)
             wq.transpose(-2, -1),  # (total_K, N)
             x_blocked_scales,  # to_blocked_per_group(M, total_K//32)
             w_blocked_scales,  # to_blocked_per_group(N, total_K//32)
+            scale_recipe_a=ScalingType.BlockWise1x32,
+            scale_recipe_b=ScalingType.BlockWise1x32,
+            swizzle_a=SwizzleType.SWIZZLE_32_4_4,
+            swizzle_b=SwizzleType.SWIZZLE_32_4_4,
             offs=input_group_end_offsets,  # (G,)
             out_dtype=torch.bfloat16,
+            wrap_v2=wrap_v2
         )

         # bf16 reference output
@@ -535,7 +584,8 @@ class TestFP8Matmul(TestCase):
     @parametrize("M", [16640])
     @parametrize("N", [8192])
     @parametrize("K", [4096])
-    def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K):
+    @parametrize("wrap_v2", [True, False])
+    def test_mxfp8_scaled_grouped_mm_2d_3d(self, G, M, N, K, wrap_v2):
         torch.manual_seed(42)
         # Simulate 2d-3d grouped gemm `out = input @ weight.t()`
         # 2D inputs with groups along M, 3D weights.
@@ -579,14 +629,19 @@ class TestFP8Matmul(TestCase):
         xq = xq.view(-1, xq.shape[-1])

         # Compute mxfp8 grouped gemm.
-        y_mxfp8 = torch._scaled_grouped_mm(
+        y_mxfp8 = scaled_grouped_mm_wrap(
             xq,
             wq.transpose(-2, -1),
             x_scale,
             w_scale,
             offs=input_group_end_offsets,
             out_dtype=torch.bfloat16,
-        )
+            scale_recipe_a=ScalingType.BlockWise1x32,
+            scale_recipe_b=ScalingType.BlockWise1x32,
+            swizzle_a=SwizzleType.SWIZZLE_32_4_4,
+            swizzle_b=SwizzleType.SWIZZLE_32_4_4,
+            wrap_v2=wrap_v2)
+

         # Compute reference bf16 grouped gemm.
         y_bf16 = torch._grouped_mm(
@@ -1536,7 +1591,8 @@ class TestFP8Matmul(TestCase):
     @parametrize("fast_accum", [False, True])
     # AMD does not support non-contiguous inputs yet
     @parametrize("strided", [False] + ([True] if torch.version.cuda else []))
-    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided):
+    @parametrize("wrap_v2", [True, False])
+    def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, wrap_v2):
         device = "cuda"
         fp8_dtype = e4m3_type
         m, n, k, n_groups = 16, 32, 64, 4
@@ -1545,9 +1601,16 @@ class TestFP8Matmul(TestCase):
         scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
         scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
         offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
-        f = torch._scaled_grouped_mm
-        out = f(a, b.t(), scale_a, scale_b, offs=offs,
-                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = scaled_grouped_mm_wrap
+        out = f(a, b.t(),
+                scale_a,
+                scale_b,
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
+                offs=offs,
+                out_dtype=torch.bfloat16,
+                use_fast_accum=fast_accum,
+                wrap_v2=wrap_v2)
         offs_cpu = offs.cpu()
         alist, blist, ascalelist, bscalelist = [], [], [], []
         start = 0
@@ -1564,7 +1627,8 @@ class TestFP8Matmul(TestCase):
     @parametrize("fast_accum", [False, True])
     # AMD does not support non-contiguous inputs yet
     @parametrize("strided", [False] + ([True] if torch.version.cuda else []))
-    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided):
+    @parametrize("wrap_v2", [True, False])
+    def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, wrap_v2):
         device = "cuda"
         fp8_dtype = e4m3_type
         m, n, k, n_groups = 16, 32, 64, 4
@@ -1582,9 +1646,16 @@ class TestFP8Matmul(TestCase):
         offs[0] = offs[1]
         scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32)
         scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
-        f = torch._scaled_grouped_mm
-        out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
-                out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
+        f = scaled_grouped_mm_wrap
+        out = f(a, b.transpose(-2, -1),
+                scale_a,
+                scale_b,
+                scale_recipe_a=ScalingType.RowWise,
+                scale_recipe_b=ScalingType.RowWise,
+                offs=offs,
+                out_dtype=torch.bfloat16,
+                use_fast_accum=fast_accum,
+                wrap_v2=wrap_v2)

         offs_cpu = offs.cpu()
         alist, ascalelist, outlist = [], [], []
@@ -6711,3 +6711,90 @@ def scaled_mm(
     )

     return out
+
+
+def scaled_grouped_mm(
+    mat_a: Tensor,
+    mat_b: Tensor,
+    scale_a: Tensor | list[Tensor],
+    scale_recipe_a: ScalingType | list[ScalingType],
+    scale_b: Tensor | list[Tensor],
+    scale_recipe_b: ScalingType | list[ScalingType],
+    swizzle_a: SwizzleType | list[SwizzleType] | None = None,
+    swizzle_b: SwizzleType | list[SwizzleType] | None = None,
+    bias: Optional[Tensor] = None,
+    offs: Optional[Tensor] = None,
+    output_dtype: Optional[torch.dtype] = torch.bfloat16,
+    contraction_dim: list[int] | tuple[int] = (),
+    use_fast_accum: bool = False,
+) -> Tensor:
+    r"""
+    scaled_grouped_mm(mat_a, mat_b, scale_a, scale_recipe_a, scale_b, scale_recipe_b, swizzle_a, swizzle_b, bias, offs,
+        output_dtype, use_fast_accum)
+
+    Applies a grouped scaled matrix-multiply, grouped_mm(mat_a, mat_b) where the scaling of mat_a and mat_b are described by
+    scale_recipe_a and scale_recipe_b respectively.
+
+    Args:
+        scale_a: Tensor containing decoding scaling factors for mat_a
+        scale_recipe_a: Enum describing how mat_a has been scaled
+        scale_b: Tensor containing decoding scaling factors for mat_b
+        scale_recipe_b: Enum describing how mat_b has been scaled
+        swizzle_a: Enum describing the swizzling pattern (if any) of scale_a
+        swizzle_b: Enum describing the swizzling pattern (if any) of scale_b
+        bias: optional bias term to be added to the output
+        offs: optional offsets into the source tensors denoting group start indices
+        output_dtype: dtype used for the output tensor
+        contraction_dim: describe which dimensions are :math:`K` in the matmul.
+        use_fast_accum: enable/disable tensor-core fast accumulation (Hopper-GPUs only)
+    """
+
+    def expand_single_value(v: _Any | list[_Any] | None) -> list[_Any]:
+        if v is None:
+            return []
+        elif not isinstance(v, (list)):
+            return [
+                v,
+            ]
+        else:
+            return v
+
+    scale_a = expand_single_value(scale_a)
+    scale_recipe_a = expand_single_value(scale_recipe_a)
+    scale_b = expand_single_value(scale_b)
+    scale_recipe_b = expand_single_value(scale_recipe_b)
+    swizzle_a = expand_single_value(swizzle_a)
+    swizzle_b = expand_single_value(swizzle_b)
+
+    # native_functions has restrictions on what can be defined
+    # & passed through - std::optional<ArrayRef<Tensor>> for instance
+    # *cannot* be passed, but an empty vector (list) can.
+    # So, we need to convert None arguments for lists in python
+    # explicitly into empty lists.
+    def list_or_empty(l: list[_Any] | None) -> list[_Any]:
+        return [] if not l else l
+
+    def enum_list_as_int_list(l: _Any | list[_Any]) -> list[_Any]:
+        if not isinstance(l, list):
+            l = [
+                l,
+            ]
+        return [li.value for li in l]
+
+    out = torch._scaled_grouped_mm_v2(
+        mat_a,
+        mat_b,
+        scale_a,
+        enum_list_as_int_list(scale_recipe_a),
+        enum_list_as_int_list(list_or_empty(swizzle_a)),
+        scale_b,
+        enum_list_as_int_list(scale_recipe_b),
+        enum_list_as_int_list(list_or_empty(swizzle_b)),
+        offs,
+        bias,
+        output_dtype,
+        contraction_dim,
+        use_fast_accum,
+    )
+
+    return out
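One consequence of `expand_single_value` above: scalar and single-element-list arguments are interchangeable. A brief illustration, reusing the rowwise names (`a`, `b`, `scale_a`, `scale_b`, `offs`) from the first sketch:

```python
# These two calls are equivalent: single tensors/enums are wrapped into
# one-element lists before being handed to torch._scaled_grouped_mm_v2.
y1 = scaled_grouped_mm(a, b.t(), scale_a, ScalingType.RowWise,
                       scale_b, ScalingType.RowWise, offs=offs)
y2 = scaled_grouped_mm(a, b.t(), [scale_a], [ScalingType.RowWise],
                       [scale_b], [ScalingType.RowWise], offs=offs)
```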
@@ -249,6 +249,7 @@ def get_ignored_functions() -> set[Callable]:
         torch.nn.functional.has_torch_function_unary,
         torch.nn.functional.has_torch_function_variadic,
         torch.nn.functional.handle_torch_function,
+        torch.nn.functional.scaled_grouped_mm,
         torch.nn.functional.scaled_mm,
         torch.nn.functional.sigmoid,
         torch.nn.functional.hardsigmoid,