mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Add scaled_grouped_mm_v2 and python API (#165154)
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/165154 Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
This commit is contained in:
committed by
PyTorch MergeBot
parent
b509fb9b5d
commit
7c6c5d04fe
@ -2578,7 +2578,9 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
const Tensor& mat_a,
|
||||
const Tensor& mat_b,
|
||||
const Tensor& scale_a,
|
||||
const SwizzleType& swizzle_a,
|
||||
const Tensor& scale_b,
|
||||
const SwizzleType& swizzle_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
Tensor& out) {
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
@ -2589,6 +2591,16 @@ _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
TORCH_CHECK_VALUE(is_2d_2d || is_2d_3d, "MXFP8 grouped GEMM currently only supports 2d-2d and 2d-3d cases");
|
||||
TORCH_CHECK_VALUE(offs.has_value(), "MXFP8 2d-2d and 2d-3d grouped GEMMs requires offsets");
|
||||
TORCH_CHECK_VALUE(out.scalar_type() == at::kBFloat16, "Only bf16 out_dtype is supported for MXFP8 grouped gemm");
|
||||
// MXFP8 expects float8_e8m0fnu scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == at::kFloat8_e8m0fnu && scale_b.scalar_type() == at::kFloat8_e8m0fnu,
|
||||
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
|
||||
#ifdef USE_ROCM
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE && swizzle_b == SwizzleType::NO_SWIZZLE,
|
||||
"For ROCM MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_NONE");
|
||||
#else
|
||||
TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4 && swizzle_b == SwizzleType::SWIZZLE_32_4_4,
|
||||
"For CUDA MXFP8 grouped gemm, both scale swizzle types must be SWIZZLE_32_4_4");
|
||||
#endif
|
||||
|
||||
#if defined(USE_FBGEMM_GENAI) and !defined(USE_ROCM)
|
||||
fbgemm_gpu::mx8mx8bf16_grouped_mm(
|
||||
@ -2673,6 +2685,9 @@ _f8_f8_bf16_rowwise_grouped_mm(
|
||||
const std::optional<Tensor>& bias,
|
||||
bool use_fast_accum,
|
||||
Tensor& out) {
|
||||
// FP8 per-tensor and per-row scaling expect fp32 scales.
|
||||
TORCH_CHECK_VALUE(scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
|
||||
"For grouped FP8 rowwise, both scales must be float32 tensors");
|
||||
#ifndef USE_ROCM
|
||||
return _f8_f8_bf16_rowwise_grouped_mm_cuda(
|
||||
mat_a,
|
||||
@ -2772,11 +2787,15 @@ _scaled_grouped_mm_cuda(
|
||||
#endif
|
||||
|
||||
if (is_mx8mx8bf16) {
|
||||
// Note: Passing implied SwizzleType here, correctness of scale previously checked
|
||||
// in `check_scale` call
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
scale_b,
|
||||
SwizzleType::SWIZZLE_32_4_4,
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
@ -2793,6 +2812,140 @@ _scaled_grouped_mm_cuda(
|
||||
out);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
std::array<std::tuple<std::string, acceptance_fn, ScaledGemmImplementation>, 2> scale_grouped_kernel_dispatch = {{
|
||||
{ "rowwise_rowwise", check_rowwise_recipe, ScaledGemmImplementation::ROWWISE_ROWWISE},
|
||||
{ "mxfp8_mxfp8", check_mxfp8_recipe, ScaledGemmImplementation::MXFP8_MXFP8}}};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
Tensor
|
||||
_scaled_grouped_mm_cuda_v2(
|
||||
const Tensor& mat_a, const Tensor& mat_b,
|
||||
ArrayRef<Tensor> scale_a,
|
||||
IntArrayRef scale_recipe_a,
|
||||
IntArrayRef swizzle_a,
|
||||
ArrayRef<Tensor> scale_b,
|
||||
IntArrayRef scale_recipe_b,
|
||||
IntArrayRef swizzle_b,
|
||||
const std::optional<Tensor>& offs,
|
||||
const std::optional<Tensor>& bias,
|
||||
const std::optional<c10::ScalarType> out_dtype,
|
||||
IntArrayRef contraction_dim,
|
||||
bool use_fast_accum) {
|
||||
bool allowed_device = _scaled_mm_allowed_device(/*sm90_only*/true, /*sm100_only*/true);
|
||||
TORCH_CHECK_VALUE(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability = [9.0, 10.0], or ROCm MI300+");
|
||||
|
||||
TORCH_CHECK_VALUE(!check_valid_strides_and_return_transposed(mat_a), "Expected mat1 to not be transposed");
|
||||
TORCH_CHECK_VALUE(check_valid_strides_and_return_transposed(mat_b), "Expected mat2 to be transposed");
|
||||
TORCH_CHECK_VALUE(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
|
||||
TORCH_CHECK_VALUE(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
|
||||
const bool a_is_2d = mat_a.dim() == 2;
|
||||
const bool b_is_2d = mat_b.dim() == 2;
|
||||
|
||||
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
|
||||
if (!a_is_2d || !b_is_2d) {
|
||||
if (contraction_dim.size() > 0) {
|
||||
const int dim_a = contraction_dim[0], dim_b = mat_b.size(contraction_dim[1]);
|
||||
TORCH_CHECK_VALUE(mat_a.size(dim_a) == mat_b.size(dim_b),
|
||||
"Contraction dimensions (", dim_a, ",", dim_b, ") of mat_a and mat_b must match, got: ", mat_a.size(dim_a), " and ",
|
||||
mat_b.size(dim_b));
|
||||
// Note: only (-1, -2) is currently supported
|
||||
TORCH_CHECK_VALUE(dim_a == -1 && dim_b == -2, "Curently contraction dims must be (-1, -2) only");
|
||||
} else {
|
||||
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(
|
||||
mat_a.size(-1) % 16 == 0,
|
||||
"Expected trailing dimension of mat_a to be divisible by 16 ",
|
||||
"but got mat1 shape: (",
|
||||
mat_a.sizes(),
|
||||
").");
|
||||
TORCH_CHECK_VALUE(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
|
||||
"Expected mat_b shape to be divisible by 16 ",
|
||||
"but got mat_b shape: (",
|
||||
mat_b.sizes(),
|
||||
").");
|
||||
|
||||
TORCH_CHECK_VALUE(!bias.has_value(), "Bias not supported yet");
|
||||
TORCH_CHECK_VALUE(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
|
||||
|
||||
// NOTE: mxfp8 x mxfp8 requires (and asserts later) that offsets is present.
|
||||
// for rowwise, no offsets implies 3d-3d and is handled by lower-level
|
||||
// routines
|
||||
if (offs.has_value()) {
|
||||
TORCH_CHECK_VALUE(offs->dim() == 1, "offs has to be 1D");
|
||||
TORCH_CHECK_VALUE(offs->dtype() == at::kInt, "Offsets have to be int32");
|
||||
}
|
||||
|
||||
const auto out_dtype_ = out_dtype.value_or(kBFloat16);
|
||||
TORCH_CHECK_VALUE(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
|
||||
|
||||
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
|
||||
|
||||
// Conversion of implicitly-defined enums to explicit
|
||||
auto scale_recipe_a_enum = convert_int_to_enum<ScalingType>(scale_recipe_a);
|
||||
auto swizzle_a_enum = convert_int_to_enum<SwizzleType>(swizzle_a);
|
||||
auto scale_recipe_b_enum = convert_int_to_enum<ScalingType>(scale_recipe_b);
|
||||
auto swizzle_b_enum = convert_int_to_enum<SwizzleType>(swizzle_b);
|
||||
|
||||
// at this point we can start working out what we want to be doing
|
||||
// Try to do as few steps as possible.
|
||||
// NOTE: support is deliberately sparse, can explicitly enumerate all combinations allowed.
|
||||
// Do this via a list of defined (name, acceptance, concrete_impl) tuples.
|
||||
ScaledGemmImplementation gemm_impl = ScaledGemmImplementation::NONE;
|
||||
for (const auto& fn_entry : scale_grouped_kernel_dispatch) {
|
||||
const auto [name, accept_fn, scaled_gemm_impl] = fn_entry;
|
||||
bool ok = accept_fn(mat_a.scalar_type(),
|
||||
scale_recipe_a_enum,
|
||||
scale_a,
|
||||
mat_b.scalar_type(),
|
||||
scale_recipe_b_enum,
|
||||
scale_b);
|
||||
if (ok) {
|
||||
gemm_impl = scaled_gemm_impl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK_VALUE(gemm_impl != ScaledGemmImplementation::NONE,
|
||||
"No gemm implementation was found");
|
||||
|
||||
switch (gemm_impl) {
|
||||
case ScaledGemmImplementation::ROWWISE_ROWWISE: {
|
||||
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
|
||||
_check_scales_fp8_rowwise(mat_a, scale_a[0], 0 /* dim */ , 0 /* arg_idx */, scale_multiplier);
|
||||
_check_scales_fp8_rowwise(mat_b, scale_b[0], 1 /* dim */ , 1 /* arg_idx */, scale_multiplier);
|
||||
return _f8_f8_bf16_rowwise_grouped_mm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
scale_b[0],
|
||||
offs,
|
||||
bias,
|
||||
use_fast_accum,
|
||||
out);
|
||||
}
|
||||
case ScaledGemmImplementation::MXFP8_MXFP8: {
|
||||
_check_scales_mxfp8(mat_a, scale_a[0], 0 /* dim */, 0 /* arg_idx */);
|
||||
_check_scales_mxfp8(mat_b, scale_b[0], 1 /* dim */, 1 /* arg_idx */);
|
||||
return _mx8_mx8_bf16_grouped_mm_fbgemm(
|
||||
mat_a,
|
||||
mat_b,
|
||||
scale_a[0],
|
||||
swizzle_a_enum[0],
|
||||
scale_b[0],
|
||||
swizzle_b_enum[0],
|
||||
offs.value(),
|
||||
out);
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"_scaled_grouped_mm_cuda_v2 is in an inconsistent state - should never reach here");
|
||||
}
|
||||
}
|
||||
|
||||
Tensor _grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
|
||||
const std::optional<at::Tensor>& offs,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
|
@ -7183,6 +7183,12 @@
|
||||
CUDA: _scaled_grouped_mm_cuda
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _scaled_grouped_mm_v2(Tensor self, Tensor mat2, Tensor[] scale_a, int[] recipe_a, int[] swizzle_a, Tensor[] scale_b, int[] recipe_b, int[] swizzle_b, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None, int[] contraction_dim=[], bool use_fast_accum=False) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _scaled_grouped_mm_cuda_v2
|
||||
tags: needs_exact_strides
|
||||
|
||||
- func: _grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor
|
||||
variants: function
|
||||
dispatch:
|
||||
|
Reference in New Issue
Block a user