[WIP] Initial implementation of Grouped Gemm API (#148531)

This PR provides an initial CUTLASS implementation of the grouped GEMM API as described in this [document](https://docs.google.com/document/d/1985La6wUUVH1AGBkNhaGKUXzx-9ybtbUp567-vYVOM4/edit?tab=t.0#heading=h.g8lzbjnyzzx9). Any combination of 2D and 3D inputs is supported, with a 2D input being jagged and the offsets of the jagged input given by the device tensor `offs`. Only H100 is supported, and only fp8_e4m3 inputs with bf16 output and rowwise scaling. All dimensions of each individual GEMM have to be a multiple of 16; that's a CUTLASS limitation.
I still need to add those checks; for dynamic dimensions the checks will unfortunately have to be a device assert.
I had to copy-paste CUTLASS's `Sm90RowBroadcast` and `Sm90ColBroadcast` structs with minor changes to enable scales given as pointer arrays; ideally those changes should be part of CUTLASS itself.
I copied the schedules from the similar grouped GEMM in FBGEMM, but there's a lot of room to improve performance, especially for `fast_accum=False`.
Next steps are performance tuning and extending coverage to B100; I don't know yet how the CUTLASS grouped GEMM example handles blockwise scaling on B100.
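For reference, here is a minimal usage sketch of the new op for the 2D x 3D (MoE-style) case, mirroring the tests added below; the shapes, random scale values, and `use_fast_accum=True` are illustrative assumptions, and it needs an H100 with all GEMM dimensions multiples of 16:

```python
import torch

device = "cuda"
m, n, k, n_groups = 16, 32, 16, 4

# 2D "jagged" activations: rows [0, offs[0]) belong to group 0, [offs[0], offs[1]) to group 1, etc.
a = torch.randn(m * n_groups, k, device=device).to(torch.float8_e4m3fn)
# 3D expert weights, one (n, k) matrix per group, passed transposed so mat_b is column-major.
b = torch.randn(n_groups, n, k, device=device).to(torch.float8_e4m3fn)
# Cumulative boundaries of the jagged dimension of `a` (1D, int32, on device).
offs = torch.arange(m, n_groups * m + 1, m, device=device, dtype=torch.int32)

# Rowwise scales: 1D (one per row of `a`) and 2D (n_groups, n) for the 3D operand.
scale_a = torch.rand(n_groups * m, device=device, dtype=torch.float32)
scale_b = torch.rand(n_groups, n, device=device, dtype=torch.float32)

out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                               out_dtype=torch.bfloat16, use_fast_accum=True)
print(out.shape)  # torch.Size([64, 32]) == (n_groups * m, n)
```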

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148531
Approved by: https://github.com/drisspg
Natalia Gimelshein
2025-03-11 21:49:46 +00:00
committed by PyTorch MergeBot
parent b98af95401
commit 53a1a022a9
9 changed files with 1610 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
#include <cstdint>
#include <c10/util/typeid.h>
#include <c10/util/Exception.h>
#include <c10/util/SmallVector.h>
#include <c10/core/Scalar.h>
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
@@ -16,6 +17,7 @@
#include <ATen/native/Resize.h>
#include <c10/util/MaybeOwned.h>
#include <ATen/native/cuda/RowwiseScaledMM.h>
#include <ATen/native/cuda/ScaledGroupMM.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -1363,6 +1365,84 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
return out;
}
namespace {
c10::SmallVector<int64_t, 3> compute_grouped_gemm_output_size(const Tensor& mat_a,
const Tensor& mat_b,
const std::optional<at::Tensor>& offs
) {
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
if (a_is_2d) {
if (b_is_2d) {
return {offs->size(0), mat_a.size(0), mat_b.size(1)};
} else {
TORCH_CHECK(offs->size(0) == mat_b.size(0), "matrix batch sizes have to match");
return {mat_a.size(0), mat_b.size(-1)};
}
} else {
if (b_is_2d) {
// this case is not actually encountered for MoE gemms
TORCH_CHECK(offs->size(0) == mat_a.size(0), "matrix batch sizes have to match");
return {mat_a.size(1), mat_b.size(1)};
} else { // regular bmm
TORCH_CHECK(mat_a.size(0) == mat_b.size(0), "batched dimension has to match");
return {mat_a.size(0), mat_a.size(1), mat_b.size(-1)};
}
}
}
bool transposed(const Tensor& mat) {
IntArrayRef tensor_strides = mat.strides();
IntArrayRef tensor_sizes = mat.sizes();
int end_dim = mat.dim() - 1;
if ((tensor_strides[end_dim - 1] == 1) && (tensor_strides[end_dim] >= std::max<int64_t>(1, tensor_sizes[end_dim - 1]))) {
return true;
} else if ((tensor_strides[end_dim] == 1) && (tensor_strides[end_dim - 1] >= std::max<int64_t>(1, tensor_sizes[end_dim]))) {
return false;
} else {
TORCH_CHECK(false, "Tensor should not be self-overlapping");
}
}
void check_scale(const Tensor& mat, const Tensor& scale, const int dim, const int arg_idx, const int scale_multiplier=1) {
if (mat.dim() == 2) {
TORCH_CHECK(
scale.dim() == 1,
"scale must be a 1D tensor, but got ",
scale.dim(),
"D, arg ",
arg_idx);
TORCH_CHECK(
scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(dim) * scale_multiplier,
"scale must have the same length as mat for arg ",
arg_idx);
} else {
TORCH_CHECK(
scale.dim() == 2,
"scale must be a 2D tensor, but got ",
scale.dim(),
"D for arg ",
arg_idx);
TORCH_CHECK(
scale.stride(1) == 1,
"scale must be contiguous in the last dimension for arg ",
arg_idx);
TORCH_CHECK(
scale.size(0) == mat.size(0),
"scale must have the same batch dimension as mat for arg ",
arg_idx);
TORCH_CHECK(
scale.size(1) == mat.size(1 + dim),
"scale must have the same first dimension as mat for arg ",
arg_idx);
}
}
}
Tensor
_scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a,
@@ -1376,4 +1456,82 @@ _scaled_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
return _scaled_mm_out_cuda(mat_a, mat_b, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out);
}
Tensor
_scaled_grouped_mm_cuda(const Tensor& mat_a, const Tensor& mat_b,
const Tensor& scale_a, const Tensor& scale_b,
const std::optional<at::Tensor>& offs,
const std::optional<at::Tensor>& bias,
const std::optional<at::Tensor>& scale_result,
std::optional<c10::ScalarType> out_dtype,
bool use_fast_accum) {
#ifndef USE_ROCM
bool allowed_device = _scaled_mm_allowed_device();
TORCH_CHECK(allowed_device, "torch._scaled_grouped_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
TORCH_CHECK(mat_a.dtype() == at::kFloat8_e4m3fn, "Expected mat_a to be Float8_e4m3 matrix got ", mat_a.scalar_type());
TORCH_CHECK(mat_b.dtype() == at::kFloat8_e4m3fn, "Expected mat_b to be Float8_e4m3 matrix got ", mat_b.scalar_type());
TORCH_CHECK(!transposed(mat_a), "Expected mat1 to not be transposed");
TORCH_CHECK(transposed(mat_b), "Expected mat2 to be transposed");
TORCH_CHECK(mat_a.dim() == 2 || mat_a.dim() == 3, "mat_a has to be 2 or 3d");
TORCH_CHECK(mat_b.dim() == 2 || mat_b.dim() == 3, "mat_b has to be 2 or 3d");
const bool a_is_2d = mat_a.dim() == 2;
const bool b_is_2d = mat_b.dim() == 2;
TORCH_CHECK(
mat_a.size(-1) % 16 == 0,
"Expected trailing dimension of mat_a to be divisible by 16 ",
"but got mat1 shape: (",
mat_a.sizes(),
").");
TORCH_CHECK(mat_b.size(-2) % 16 == 0 && mat_b.size(-1) % 16 == 0,
"Expected mat_b shape to be divisible by 16 ",
"but got mat_b shape: (",
mat_b.sizes(),
").");
TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
if (offs.has_value()) {
TORCH_CHECK(offs->dim() == 1, "offs has to be 1D");
TORCH_CHECK(offs->dtype() == at::kInt, "Offsets have to be int32");
}
// Both Per-Tensor and Row-wise scaling expect fp32 tensors
TORCH_CHECK(
scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat,
"Both scale_a and scale_b must be float (fp32) tensors.");
const int scale_multiplier = (mat_a.dim() == 2 && mat_b.dim() == 2) ? offs->size(0) : 1;
check_scale(mat_a, scale_a, 0 ,0, scale_multiplier);
check_scale(mat_b, scale_b, 1, 1, scale_multiplier);
const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
TORCH_CHECK(out_dtype_ == kBFloat16, "Only bf16 high precision output types are supported for grouped gemm");
const auto out_size = compute_grouped_gemm_output_size(mat_a, mat_b, offs);
Tensor out = at::empty(out_size, mat_a.options().dtype(out_dtype_));
at::cuda::detail::f8f8bf16_grouped_mm(
mat_a,
mat_b,
scale_a,
scale_b,
offs,
bias,
use_fast_accum,
out);
return out;
#else
TORCH_CHECK(false, "grouped gemm is not supported on ROCM")
#endif
}
} // namespace at::native

View File

@@ -946,7 +946,6 @@ void dispatch_fp8_rowwise_kernel_on_input_dtypes(
}
}
template <typename... Types>
void dispatch_fp8_rowwise_kernel_on_bias_dtype(
at::Tensor XQ,
at::Tensor WQ,
@@ -957,12 +956,13 @@ void dispatch_fp8_rowwise_kernel_on_bias_dtype(
at::Tensor out) {
if (bias.has_value() && bias->dtype() == at::kBFloat16) {
dispatch_fp8_rowwise_kernel_on_input_dtypes<
cutlass::bfloat16_t,
Types...>(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
cutlass::bfloat16_t>
(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
} else {
dispatch_fp8_rowwise_kernel_on_input_dtypes<
float,
Types...>(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
float>
//Types...>
(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
}
}

View File

@@ -0,0 +1,640 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/macros/Macros.h>
#include <c10/util/irange.h>
// Two warnings in Cutlass included header files
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wset-but-not-used")
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-but-set-parameter")
// Determine if the architecture supports rowwise scaled mm
// Currently failing on windows with:
// https://github.com/NVIDIA/cutlass/issues/1571
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && \
CUDA_VERSION >= 12000
#define BUILD_ROWWISE_FP8_KERNEL
#endif
#if defined(BUILD_ROWWISE_FP8_KERNEL)
#include <ATen/native/cuda/cutlass_utils.cuh>
#include <cute/tensor.hpp>
#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/half.h>
#include <cutlass/numeric_types.h>
#include <cutlass/trace.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/version.h>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/gemm/kernel/default_gemm_universal_with_visitor.h>
#include <cute/atom/mma_atom.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()
namespace {
using Strides = std::array<int64_t, 3>;
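// Launched as a single block with group_count threads: thread `tid` fills in
// the per-group GEMM arguments below (data pointers, scale pointers, problem
// shape and cute strides), resolving the dynamic (jagged) dimension from `offs`.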
template <
typename DtypeA,
typename DtypeB,
typename DtypeOutput,
typename DtypeScale,
typename ProblemShape,
typename StrideA,
typename StrideB,
typename StrideOutput>
__global__ void prepare_gemm_data(
DtypeA* A,
DtypeB* B,
DtypeOutput* output,
DtypeScale* scale_A,
DtypeScale* scale_B,
DtypeA** A_ptrs,
DtypeB** B_ptrs,
DtypeOutput** output_ptrs,
DtypeScale** inputA_scale_ptrs,
DtypeScale** inputB_scale_ptrs,
ProblemShape* problem_sizes,
// Strides for cutlass, cute::Stride
StrideA* stride_A,
StrideB* stride_B,
StrideOutput* stride_output,
const int32_t* offs,
int32_t M,
int32_t N,
int32_t K,
// Original strides of the input tensors
Strides tensor_StrideA,
Strides tensor_StrideB,
Strides tensor_StrideOutput,
int64_t a_scale_stride,
int64_t b_scale_stride) {
int32_t tid = threadIdx.x;
int32_t delta = 0;
if (offs != nullptr) {
int32_t start = tid == 0 ? 0 : offs[tid - 1];
delta = offs[tid] - start;
CUDA_KERNEL_ASSERT(delta % 16 == 0 && "expected dynamic dimension to be multiple of 16\n");
}
int64_t lda, ldb, ldoutput;
if (M < 0) {
// A and output are 2d
M = delta;
lda = tensor_StrideA[0];
ldb = tensor_StrideB[2]; // B is transposed
ldoutput = tensor_StrideOutput[0];
A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1] * lda;
inputA_scale_ptrs[tid] = tid == 0 ? scale_A : scale_A + offs[tid - 1];
output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1] * ldoutput;
B_ptrs[tid] = B + tid * tensor_StrideB[0];
inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride;
} else if (N < 0) {
N = delta;
lda = tensor_StrideA[1];
ldb = tensor_StrideB[1]; // B is transposed
ldoutput = tensor_StrideOutput[0];
A_ptrs[tid] = A + tid * tensor_StrideA[0];
inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride;
output_ptrs[tid] = tid == 0 ? output : output + offs[tid - 1];
B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1] * ldb;
inputB_scale_ptrs[tid] = tid == 0 ? scale_B : scale_B + offs[tid - 1];
} else if (K < 0) {
// A and B are 2d, output is 3d
K = delta;
lda = tensor_StrideA[0];
ldb = tensor_StrideB[1]; // B is transposed
ldoutput = tensor_StrideOutput[1];
A_ptrs[tid] = tid == 0 ? A : A + offs[tid - 1];
B_ptrs[tid] = tid == 0 ? B : B + offs[tid - 1];
inputA_scale_ptrs[tid] = scale_A + tid * M;
inputB_scale_ptrs[tid] = scale_B + tid * N;
output_ptrs[tid] = output + tid * tensor_StrideOutput[0];
} else {
// A, B, output are 3D
lda = tensor_StrideA[1];
ldb = tensor_StrideB[2];
ldoutput = tensor_StrideOutput[1];
A_ptrs[tid] = A + tid * tensor_StrideA[0];
B_ptrs[tid] = B + tid * tensor_StrideB[0];
inputA_scale_ptrs[tid] = scale_A + tid * a_scale_stride;
inputB_scale_ptrs[tid] = scale_B + tid * b_scale_stride;
output_ptrs[tid] = output + tid * tensor_StrideOutput[0];
}
problem_sizes[tid] = ProblemShape(M, N, K);
stride_A[tid] = cutlass::make_cute_packed_stride(StrideA{}, {M, lda, 1});
stride_B[tid] = cutlass::make_cute_packed_stride(StrideB{}, {N, ldb, 1});
stride_output[tid] =
cutlass::make_cute_packed_stride(StrideOutput{}, {M, ldoutput, 1});
}
using DtypeScale = float;
using DtypeAccum = float;
using DtypeEpilogue = float;
using DtypeOutput = cutlass::bfloat16_t;
using Multiply = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies,
DtypeEpilogue,
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using Add = cutlass::epilogue::fusion::Sm90Compute<
cutlass::plus,
DtypeEpilogue,
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using Cast = cutlass::epilogue::fusion::Sm90Compute<
cutlass::epilogue::thread::Identity,
DtypeOutput,
DtypeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using ProblemShape = cutlass::gemm::GroupProblemShape<
cute::Shape<int32_t, int32_t, int32_t>>; // <M,N,K> per
// group
template <
bool FastAccum,
bool PONG,
typename TB_M,
typename TB_N,
typename TB_K>
struct Schedule {
using FastCooperativeSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum;
using CooperativeSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative;
using FastPongSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
using PongSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong;
using CooperativeEpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
using PongEpilogueSchedule =
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
using KernelSchedule = cute::conditional_t<
PONG,
cute::conditional_t<FastAccum, FastPongSchedule, PongSchedule>,
cute::conditional_t<
FastAccum,
FastCooperativeSchedule,
CooperativeSchedule>>;
using EpilogueSchedule = cute::
conditional_t<PONG, PongEpilogueSchedule, CooperativeEpilogueSchedule>;
using TileShape = cute::Shape<TB_M, TB_N, TB_K>;
using ClusterShape = cute::Shape<cute::_2, cute::_2, cute::_1>;
};
int ceildiv(int a, int b) {
return (a + b - 1) / b;
}
int round_up_to_nearest_multiple(int a, int b) {
return ceildiv(a, b) * b;
}
template <
typename FastAccum,
typename BiasType,
typename Pong,
typename TB_M,
typename TB_N,
typename TB_K>
void f8f8bf16_grouped_gemm_impl_sm90(
at::Tensor mat_a, // FP8
at::Tensor mat_b, // FP8
at::Tensor scale_a, // FP32
at::Tensor scale_b, // FP32
std::optional<at::Tensor> offs,
std::optional<at::Tensor> bias, // BF16
bool use_fast_accum,
at::Tensor& out) {
using DtypeA = cutlass::float_e4m3_t;
using DtypeB = cutlass::float_e4m3_t;
using DtypeOutput = cutlass::bfloat16_t;
using LayoutA = cutlass::layout::RowMajor;
constexpr int AlignmentA = 16 / sizeof(DtypeA);
using LayoutB = cutlass::layout::ColumnMajor;
constexpr int AlignmentB = 16 / sizeof(DtypeB);
using LayoutOutput = cutlass::layout::RowMajor;
constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);
// Tag indicating the minimum SM that supports the intended feature
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using TileShape =
typename Schedule<FastAccum::value, Pong::value, TB_M, TB_N, TB_K>::
TileShape;
using ClusterShape =
typename Schedule<FastAccum::value, Pong::value, TB_M, TB_N, TB_K>::
ClusterShape;
using KernelSchedule =
typename Schedule<FastAccum::value, Pong::value, TB_M, TB_N, TB_K>::
KernelSchedule;
using EpilogueSchedule =
typename Schedule<FastAccum::value, Pong::value, TB_M, TB_N, TB_K>::
EpilogueSchedule;
// TODO remove *BroadcastPtrArrays and replace with just Broadcast
// when https://github.com/NVIDIA/cutlass/pull/2120/ is in the tagged cutlass version
// Implement rowwise scaling epilogue.
using ScaleA = cutlass::epilogue::fusion::Sm90ColBroadcastPtrArray<
0,
TileShape,
DtypeScale,
DtypeScale,
cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowBroadcastPtrArray<
0,
TileShape,
DtypeScale,
DtypeScale,
cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using AccumScale = cutlass::epilogue::fusion::Sm90EVT<
Multiply,
ScaleB,
cutlass::epilogue::fusion::Sm90EVT<Multiply, ScaleA, Accum>>;
using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<Cast, AccumScale>;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
TileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
DtypeAccum,
DtypeAccum,
DtypeOutput,
LayoutOutput*,
AlignmentOutput,
DtypeOutput,
LayoutOutput*,
AlignmentOutput,
EpilogueSchedule,
EpilogueEVT>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
DtypeA,
LayoutA*,
AlignmentA,
DtypeB,
LayoutB*,
AlignmentB,
DtypeAccum,
TileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::
GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideOutput = typename Gemm::GemmKernel::InternalStrideD;
int32_t M, N, K, group_count;
M = mat_a.size(-2);
K = mat_a.size(-1);
N = mat_b.size(-1);
if (mat_a.dim() == 2 && mat_b.dim() == 2) {
// if both inputs are ragged, K is dynamic, M and N come from inputs
group_count = offs->size(0);
K = -1;
} else if (mat_a.dim() == 2) {
group_count = mat_b.size(0);
M = -1;
} else if (mat_b.dim() == 2) {
group_count = mat_a.size(0);
N = -1;
} else {
// regular bmm
group_count = mat_a.size(0);
}
const int64_t problem_shape_size =
group_count * ((int64_t)sizeof(ProblemShape::UnderlyingProblemShape));
const int64_t stride_size = 3 * group_count * ((int64_t)sizeof(StrideA));
// dummy TMAs are created based on these pointer-to-pointer arrays;
// the initial values are never used and are later replaced by real addresses,
// but due to a bug in CUDA < 12.4 the pointers have to be aligned to 128 bits
// for dummy TMA creation to succeed
const int group_alignment = 16 / sizeof(void*);
const int aligned_group_count =
round_up_to_nearest_multiple(group_count, group_alignment);
int64_t input_args_size = aligned_group_count * 5 * sizeof(void*) +
problem_shape_size + stride_size;
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
auto input_buf = allocator.allocate(input_args_size);
void* buf_ptr = input_buf.get();
DtypeA** inputA_ptrs = reinterpret_cast<DtypeA**>(buf_ptr);
DtypeB** inputB_ptrs =
reinterpret_cast<DtypeB**>(inputA_ptrs + aligned_group_count);
DtypeOutput** output_ptrs =
reinterpret_cast<DtypeOutput**>(inputB_ptrs + aligned_group_count);
DtypeScale** inputA_scale_ptrs =
reinterpret_cast<DtypeScale**>(output_ptrs + aligned_group_count);
DtypeScale** inputB_scale_ptrs =
reinterpret_cast<DtypeScale**>(inputA_scale_ptrs + aligned_group_count);
static_assert(
sizeof(StrideA) == 8, "expected StrideA to be 8 bytes for alignment");
StrideA* stride_A =
reinterpret_cast<StrideA*>(inputB_scale_ptrs + aligned_group_count);
StrideB* stride_B = reinterpret_cast<StrideB*>(stride_A + group_count);
StrideOutput* stride_output =
reinterpret_cast<StrideOutput*>(stride_B + group_count);
ProblemShape::UnderlyingProblemShape* problem_sizes =
reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
stride_output + group_count);
TORCH_CHECK(group_count < 1024, "Can't process more than 1024 groups");
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto make_strides = [](at::IntArrayRef strides) -> Strides {
Strides out;
std::copy(strides.begin(), strides.end(), out.begin());
return out;
};
Strides tensor_StrideA = make_strides(mat_a.strides());
Strides tensor_StrideB = make_strides(mat_b.strides());
Strides tensor_StrideOutput = make_strides(out.strides());
// scale stride will be used inside the kernel only if needed,
// so for 1d scales the "1" assigned here won't be used
int64_t a_scale_stride = scale_a.stride(0);
int64_t b_scale_stride = scale_b.stride(0);
prepare_gemm_data<<<1, group_count, 0, stream>>>(
reinterpret_cast<DtypeA*>(mat_a.data_ptr()),
reinterpret_cast<DtypeB*>(mat_b.data_ptr()),
reinterpret_cast<DtypeOutput*>(out.data_ptr()),
scale_a.data_ptr<DtypeScale>(),
scale_b.data_ptr<DtypeScale>(),
inputA_ptrs,
inputB_ptrs,
output_ptrs,
inputA_scale_ptrs,
inputB_scale_ptrs,
problem_sizes,
stride_A,
stride_B,
stride_output,
offs.has_value() ? offs->const_data_ptr<int32_t>() : nullptr,
M,
N,
K,
tensor_StrideA,
tensor_StrideB,
tensor_StrideOutput,
a_scale_stride,
b_scale_stride);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// auto buf_cpu = mat_a.new_empty(
// input_args_size, at::TensorOptions().dtype(at::kByte).device(at::kCPU));
// AT_CUDA_CHECK(cudaMemcpy(
// (char*)buf_cpu.data_ptr(),
// buf_ptr,
// input_args_size,
// cudaMemcpyDeviceToHost));
// char* buf_ptr_cpu = (char*)buf_cpu.data_ptr();
// DtypeA** inputA_ptrs_h = reinterpret_cast<DtypeA**>(buf_ptr_cpu);
// DtypeB** inputB_ptrs_h =
// reinterpret_cast<DtypeB**>(inputA_ptrs_h + aligned_group_count);
// DtypeOutput** output_ptrs_h =
// reinterpret_cast<DtypeOutput**>(inputB_ptrs_h + aligned_group_count);
// DtypeScale** inputA_scale_ptrs_h =
// reinterpret_cast<DtypeScale**>(output_ptrs_h + aligned_group_count);
// DtypeScale** inputB_scale_ptrs_h =
// reinterpret_cast<DtypeScale**>(inputA_scale_ptrs_h + aligned_group_count);
// StrideA* stride_A_h =
// reinterpret_cast<StrideA*>(inputB_scale_ptrs_h + aligned_group_count);
// StrideB* stride_B_h = reinterpret_cast<StrideB*>(stride_A_h + group_count);
// StrideOutput* stride_output_h =
// reinterpret_cast<StrideOutput*>(stride_B_h + group_count);
// ProblemShape::UnderlyingProblemShape* problem_sizes_h =
// reinterpret_cast<ProblemShape::UnderlyingProblemShape*>(
// stride_output_h + group_count);
// std::cout << "PTRS " << mat_a.data_ptr() << " " << mat_b.data_ptr() << " "
// << out.data_ptr() << " " << scale_a.data_ptr() << " "
// << scale_b.data_ptr() << "\n";
// for (int i = 0; i < group_count; i++) {
// std::cout << "A " << (void*)inputA_ptrs_h[i] << "\n";
// std::cout << "B " << (void*)inputB_ptrs_h[i] << "\n";
// std::cout << "O " << (void*)output_ptrs_h[i] << "\n";
// std::cout << "A_scale " << (void*)inputA_scale_ptrs_h[i] << "\n";
// std::cout << "B_scale " << (void*)inputB_scale_ptrs_h[i] << "\n";
// std::cout << "sizes " << problem_sizes_h[i] << "\n";
// std::cout << "strideA" << stride_A_h[i] << "\n";
// std::cout << "strideB" << stride_B_h[i] << "\n";
// std::cout << "stride_output" << stride_output_h[i] << "\n";
// }
// int device_id = 0;
// cutlass::KernelHardwareInfo kernel_hw_info =
// cutlass::KernelHardwareInfo::make_kernel_hardware_info<Gemm::GemmKernel>(device_id);
typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{group_count, problem_sizes, nullptr},
{(const DtypeA**)inputA_ptrs,
stride_A,
(const DtypeB**)inputB_ptrs,
stride_B},
{{{{inputB_scale_ptrs}, {inputA_scale_ptrs}}},
(const DtypeOutput**)output_ptrs,
stride_output,
output_ptrs,
stride_output}};
int sm_count = at::cuda::getDeviceProperties(out.device().index())->multiProcessorCount;
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
sm_count -= at::globalContext()._SMCarveout_EXPERIMENTAL().value();
}
arguments.hw_info.sm_count = sm_count;
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace = allocator.allocate(workspace_size);
Gemm gemm;
TORCH_CHECK(
gemm.can_implement(arguments) == cutlass::Status::kSuccess,
"cutlass cannot implement");
TORCH_CHECK(
gemm.initialize(arguments, workspace.get()) == cutlass::Status::kSuccess,
"cutlass cannot initialize");
auto status = gemm(at::cuda::getCurrentCUDAStream());
TORCH_CHECK(
status == cutlass::Status::kSuccess,
"cutlass cannot run, error ",
int(status));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename FastAccum, typename BiasType>
void dispatch_fp8_grouped_gemm_on_tile_size(
at::Tensor mat_a, // FP8
at::Tensor mat_b, // FP8
at::Tensor scale_a, // FP32
at::Tensor scale_b, // FP32
std::optional<at::Tensor> offs,
std::optional<at::Tensor> bias, // BF16
bool use_fast_accum,
at::Tensor& out) {
int32_t M, N, K, group_count;
M = mat_a.size(-2);
K = mat_a.size(-1);
N = mat_b.size(-1);
// below we assume that all gemms in the group are approximately the same size
if (mat_a.dim() == 2 && mat_b.dim() == 2) {
// if both inputs are ragged, K is dynamic, M and N come from inputs
group_count = offs->size(0);
K = K / group_count;
} else if (mat_a.dim() == 2) {
group_count = mat_b.size(0);
M = M / group_count;
} else if (mat_b.dim() == 2) {
group_count = mat_a.size(0);
N = N / group_count;
}
bool large =
((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) ||
(K >= 2048 && N >= 2048));
bool small = (M <= 128 || N <= 128);
if (small) {
f8f8bf16_grouped_gemm_impl_sm90<
FastAccum,
BiasType,
/*Pong*/ std::true_type,
cute::_64,
cute::_128,
cute::_128>(
mat_a, mat_b, scale_a, scale_b, offs, bias, use_fast_accum, out);
} else if (large && FastAccum::value) {
f8f8bf16_grouped_gemm_impl_sm90<
FastAccum,
BiasType,
/*Pong*/ std::false_type,
cute::_256,
cute::_128,
cute::_128>(
mat_a, mat_b, scale_a, scale_b, offs, bias, use_fast_accum, out);
} else if (large) { // use smaller tile for slow accum to avoid spilling
f8f8bf16_grouped_gemm_impl_sm90<
FastAccum,
BiasType,
/*Pong*/ std::false_type,
cute::_128,
cute::_128,
cute::_128>(
mat_a, mat_b, scale_a, scale_b, offs, bias, use_fast_accum, out);
} else
f8f8bf16_grouped_gemm_impl_sm90<
FastAccum,
BiasType,
/*Pong*/ std::false_type,
cute::_128,
cute::_256,
cute::_64>(
mat_a, mat_b, scale_a, scale_b, offs, bias, use_fast_accum, out);
}
template <typename BiasType>
void dispatch_fp8_grouped_gemm_on_fast_accum(
at::Tensor mat_a, // FP8
at::Tensor mat_b, // FP8
at::Tensor scale_a, // FP32
at::Tensor scale_b, // FP32
std::optional<at::Tensor> offs,
std::optional<at::Tensor> bias, // BF16
bool use_fast_accum,
at::Tensor& out) {
if (use_fast_accum) {
dispatch_fp8_grouped_gemm_on_tile_size<std::true_type, BiasType>(
mat_a, mat_b, scale_a, scale_b, offs, bias, use_fast_accum, out);
} else {
dispatch_fp8_grouped_gemm_on_tile_size<std::false_type, BiasType>(
mat_a, mat_b, scale_a, scale_b, offs, bias, use_fast_accum, out);
}
}
void dispatch_fp8_grouped_gemm_on_bias_dtype(
at::Tensor mat_a, // FP8
at::Tensor mat_b, // FP8
at::Tensor scale_a, // FP32
at::Tensor scale_b, // FP32
std::optional<at::Tensor> offs,
std::optional<at::Tensor> bias, // BF16
bool use_fast_accum,
at::Tensor& out) {
if (bias.has_value() && bias->dtype() == at::kBFloat16) {
dispatch_fp8_grouped_gemm_on_fast_accum<cutlass::bfloat16_t>(
mat_a, mat_b, scale_a, scale_b, offs, bias, use_fast_accum, out);
} else {
dispatch_fp8_grouped_gemm_on_fast_accum<float>(
mat_a, mat_b, scale_a, scale_b, offs, bias, use_fast_accum, out);
}
}
} // namespace
#endif
namespace at::cuda::detail {
void f8f8bf16_grouped_mm(
at::Tensor mat_a, // FP8
at::Tensor mat_b, // FP8
at::Tensor scale_a, // FP32
at::Tensor scale_b, // FP32
std::optional<at::Tensor> offs,
std::optional<at::Tensor> bias, // BF16
bool use_fast_accum,
at::Tensor& out) {
#if defined(BUILD_ROWWISE_FP8_KERNEL)
dispatch_fp8_grouped_gemm_on_bias_dtype(
mat_a, mat_b, scale_a, scale_b, offs, bias, use_fast_accum, out);
#else
TORCH_CHECK(false, "grouped mm is not supported on your system");
#endif
}
} // namespace at::cuda::detail

View File

@@ -0,0 +1,15 @@
#pragma once
#include <ATen/core/TensorBase.h>
#include <optional>
namespace at::cuda::detail {
TORCH_API void f8f8bf16_grouped_mm(
at::Tensor mat_a, // FP8
at::Tensor mat_b, // FP8
at::Tensor scale_a, // FP32
at::Tensor scale_b, // FP32
std::optional<at::Tensor> offs,
std::optional<at::Tensor> bias, // BF16
bool use_fast_accum,
at::Tensor& out);
} // namespace at::cuda::detail

View File

@@ -0,0 +1,654 @@
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/arch/barrier.h>
#include <cutlass/epilogue/collective/detail.hpp>
#include <cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp>
// TODO remove *BroadcastPtrArrays and replace with just Broadcast
// when https://github.com/NVIDIA/cutlass/pull/2120/ is in the tagged cutlass version
namespace cutlass::epilogue::fusion {
using namespace cute;
using namespace detail;
// Row vector broadcast with grouping.
template<
int Stages,
class CtaTileShapeMNK,
class ElementInput,
class ElementCompute = ElementInput,
class StrideMNL_ = Stride<_0,_1,_0>,
int Alignment = 128 / sizeof_bits_v<ElementInput>,
bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
>
struct Sm90RowBroadcastPtrArray {
using StrideMNL = StrideMNL_;
static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining");
static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<1>(StrideMNL{}))>, bool>; // row vector or scalar broadcast
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast);
struct SharedStorage {
array_aligned<ElementInput, size<1>(CtaTileShapeMNK{})> smem;
};
struct Arguments {
ElementInput const* const* ptr_row_array = nullptr;
ElementInput null_default = ElementInput(0);
StrideMNL dRow = {};
};
struct Params {
ElementInput const* const* ptr_row_array = nullptr;
ElementCompute null_default = ElementCompute(0);
StrideMNL dRow = {};
};
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return {args.ptr_row_array, ElementCompute(args.null_default), args.dRow};
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_HOST_DEVICE
Sm90RowBroadcastPtrArray() { }
CUTLASS_HOST_DEVICE
Sm90RowBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage)
: params(params), is_zero_(false),
smem(const_cast<ElementInput*>(shared_storage.smem.data())) {
auto const& [stride_M, stride_N, stride_L] = params.dRow;
// Nullptr default
if (EnableNullptr && params.ptr_row_array == nullptr) {
is_zero_ = params.null_default == ElementCompute(0);
}
// Dynamic non-batched scalar broadcast
else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) {
is_zero_ = params.ptr_row_array[0][0] == ElementInput(0);
}
}
Params params;
bool is_zero_ = false;
ElementInput *smem = nullptr;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return is_zero_;
}
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class Residue, class ThrNum>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
Residue residue_cRow_, ThrNum thr_num_, Params const& params_)
: tGS_gRow(tGS_gRow_)
, tGS_sRow(tGS_sRow_)
, tGS_cRow(tGS_cRow_)
, tiled_G2S(tiled_g2s_)
, tSR_sRow(tSR_sRow_)
, tSR_rRow(tSR_rRow_)
, residue_cRow(residue_cRow_)
, params(params_)
, is_nullptr(EnableNullptr && params_.ptr_row_array == nullptr) {
if (is_nullptr) {
fill(tSR_rRow, params.null_default);
}
}
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Residue residue_cRow; // (m, n)
ThrNum thr_num;
Params const& params;
bool is_nullptr;
CUTLASS_DEVICE void
begin() {
if (is_nullptr) {
return;
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = filter_zeros(tGS_cRow, tGS_gRow.stride());
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM,
}
if (elem_less(tGS_cRow_flt(i), residue_cRow)) {
tGS_sRow_flt(i) = tGS_gRow_flt(i);
}
else {
tGS_sRow_flt(i) = ElementInput(0); // Set to Zero when OOB so LDS can be issued without any preds.
}
}
synchronize();
}
CUTLASS_DEVICE void
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0 and not is_nullptr) { // Assumes M-major subtile loop
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = make_tensor_like<ElementInput>(tSR_sRow_flt);
copy_aligned(tSR_sRow_flt, tSR_rRow_flt);
constexpr int FrgSize = size(tSR_rRow_flt);
using FrgInput = Array<ElementInput, FrgSize>;
using FrgCompute = Array<ElementCompute, FrgSize>;
using ConvertInput = NumericArrayConverter<ElementCompute, ElementInput, FrgSize>;
Tensor tSR_rRow_input_frg = recast<FrgInput>(coalesce(tSR_rRow_flt));
Tensor tSR_rRow_compute_frg = recast<FrgCompute>(filter(tSR_rRow));
ConvertInput convert_input{};
tSR_rRow_compute_frg(_0{}) = convert_input(tSR_rRow_input_frg(_0{}));
}
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<ElementCompute, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<ElementCompute, FragmentSize> frg_row;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
}
return frg_row;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
auto layout_N = [&] () {
auto shape_N = get<1>(args.problem_shape_mnkl);
if constexpr (IsDynamicBroadcast) {
auto stride_N = repeat_like(shape_N, int(0));
if (get<1>(params.dRow) == bool(1)) {
stride_N = transform_leaf(compact_major<LayoutLeft>(shape_N),
[] (auto const& stride) { return static_cast<int>(stride); }
);
}
return make_layout(shape_N, stride_N);
}
else {
return make_layout(shape_N);
}
}();
auto layout_M = make_layout(M, repeat_like(M, _0{}));
auto layout_L = make_layout(L, get<2>(params.dRow));
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_layout(layout_M,layout_N,layout_L));
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, ElementInput>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
Tensor tGS_cRow = thr_g2s.partition_S(args.cD);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like<ElementCompute>(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks(
tGS_gRow,
tGS_sRow,
tGS_cRow, tiled_g2s,
tSR_sRow,
tSR_rRow,
args.residue_cD,
ThreadCount{},
params);
}
};
// Column vector broadcast with support for grouping.
template<
int Stages,
class CtaTileShapeMNK,
class ElementInput,
class ElementCompute = ElementInput,
class StrideMNL_ = Stride<_1,_0,_0>,
int Alignment = 128 / sizeof_bits_v<ElementInput>,
bool EnableNullptr = true // Fallback scalar broadcast for nullptr params
>
struct Sm90ColBroadcastPtrArray {
using StrideMNL = StrideMNL_;
static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining");
static constexpr bool IsDynamicBroadcast = is_same_v<remove_cvref_t<decltype(get<0>(StrideMNL{}))>, bool>; // Column vector or scalar broadcast
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{} || IsDynamicBroadcast);
// Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };
struct Arguments {
ElementInput const* const* ptr_col_array = nullptr;
ElementInput null_default = ElementInput(0);
StrideMNL dCol = {};
};
struct Params {
ElementInput const* const* ptr_col_array = nullptr;
ElementCompute null_default = ElementCompute(0);
StrideMNL dCol = {};
};
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return {args.ptr_col_array, ElementCompute(args.null_default), args.dCol};
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return is_zero_;
}
CUTLASS_HOST_DEVICE
Sm90ColBroadcastPtrArray() { }
CUTLASS_HOST_DEVICE
Sm90ColBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage)
: params(params), is_zero_(false) {
auto const& [stride_M, stride_N, stride_L] = params.dCol;
// Nullptr default
if (EnableNullptr && params.ptr_col_array == nullptr) {
is_zero_ = params.null_default == ElementCompute(0);
}
// Dynamic non-batched scalar broadcast
else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) {
is_zero_ = params.ptr_col_array[0][0] == ElementInput(0);
}
}
Params params;
bool is_zero_;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<class GTensor, class RTensor, class CTensor, class ThrResidue>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(GTensor tCgCol_, RTensor tCrCol_, CTensor tCcCol_, ThrResidue residue_tCcCol_, Params const& params_)
: tCgCol(tCgCol_),
tCrCol(tCrCol_),
tCcCol(tCcCol_),
residue_tCcCol(residue_tCcCol_),
params(params_) {
if (EnableNullptr && params.ptr_col_array == nullptr) {
fill(tCrCol, params.null_default);
}
}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue residue_tCcCol;
Params const& params;
CUTLASS_DEVICE void
begin() {
if (EnableNullptr && params.ptr_col_array == nullptr) {
return;
}
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
Tensor tCgCol_flt = filter_zeros(tCgCol);
Tensor tCrCol_flt = make_tensor_like<ElementInput>(filter_zeros(tCrCol));
Tensor tCcCol_flt = filter_zeros(tCcCol, tCgCol.stride());
constexpr auto MCL = decltype(max_common_layout(tCgCol_flt, tCrCol_flt)){};
constexpr int V = cute::min(Alignment, size(MCL));
if constexpr (V > 1) {
using VecType = uint_bit_t<V * sizeof_bits_v<ElementInput>>;
Tensor tCgCol_vec = recast<VecType>(coalesce(tCgCol_flt));
Tensor tCrCol_vec = recast<VecType>(coalesce(tCrCol_flt));
Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int<V>{})));
auto pred_fn = [&] (auto const&... coords) { return elem_less(tCcCol_vec(coords...), residue_tCcCol); };
copy_if(pred_fn, tCgCol_vec, tCrCol_vec);
}
else {
auto pred_fn = [&] (auto const&... coords) { return elem_less(tCcCol_flt(coords...), residue_tCcCol); };
copy_if(pred_fn, tCgCol_flt, tCrCol_flt);
}
constexpr int FrgSize = size(tCrCol_flt);
using FrgInput = Array<ElementInput, FrgSize>;
using FrgCompute = Array<ElementCompute, FrgSize>;
using ConvertInput = NumericArrayConverter<ElementCompute, ElementInput, FrgSize>;
Tensor tCrCol_input_frg = recast<FrgInput>(coalesce(tCrCol_flt));
Tensor tCrCol_compute_frg = recast<FrgCompute>(filter(tCrCol));
ConvertInput convert_input{};
tCrCol_compute_frg(_0{}) = convert_input(tCrCol_input_frg(_0{}));
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<ElementCompute, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<ElementCompute, FragmentSize> frg_col;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i);
}
return frg_col;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
auto layout_M = [&] () {
auto shape_M = get<0>(args.problem_shape_mnkl);
if constexpr (IsDynamicBroadcast) {
auto stride_M = repeat_like(shape_M, int(0));
if (get<0>(params.dCol) == bool(1)) {
stride_M = transform_leaf(compact_major<LayoutLeft>(shape_M),
[] (auto const& stride) { return static_cast<int>(stride); }
);
}
return make_layout(shape_M, stride_M);
}
else {
return make_layout(shape_M);
}
}();
auto layout_N = make_layout(N, repeat_like(N, _0{}));
auto layout_L = make_layout(L, get<2>(params.dCol));
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_layout(layout_M,layout_N,layout_L));
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor mCol_static = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_layout(make_layout(M),layout_N,layout_L));
Tensor tCgCol_static = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like<ElementCompute>(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return ConsumerStoreCallbacks(tCgCol, tCrCol, args.tCcD, args.residue_tCcD, params);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Do outer product from the column and row loaded
//
template<
int Stages,
class CtaTileShapeMNK,
class ElementScalar,
class StrideColMNL_ = Stride<_1,_0,int64_t>, /// NOTE: Batched scaling untested for now
class StrideRowMNL_ = Stride<_0,_1,int64_t>,
int Alignment = 128 / sizeof_bits_v<ElementScalar>,
bool EnableNullptr = false // Fallback scalar broadcast for nullptr params
>
struct Sm90OuterProduct {
using StrideColMNL = StrideColMNL_;
using StrideRowMNL = StrideRowMNL_;
static_assert(Stages == 0, "OuterProduct doesn't support smem usage");
static_assert(Alignment * sizeof_bits_v<ElementScalar> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(!EnableNullptr, "Nullptr fallback not implemented");
static_assert(is_static_v<decltype(take<0,2>(StrideColMNL{}))> &&
is_static_v<decltype(take<0,2>(StrideRowMNL{}))>, "Only batch stride can be dynamic");
static_assert(take<0,2>(StrideColMNL{}) == Stride<_1,_0>{} &&
take<0,2>(StrideRowMNL{}) == Stride<_0,_1>{}, "Row and column incorrectly formatted");
// Accumulator distributes col/row elements evenly amongst threads so we can just directly load from gmem
struct SharedStorage { };
struct Arguments {
ElementScalar const* ptr_col = nullptr;
ElementScalar const* ptr_row = nullptr;
StrideColMNL dCol = {};
StrideRowMNL dRow = {};
};
using Params = Arguments;
template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
return 0;
}
template <class ProblemShape>
static cutlass::Status
initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
CudaHostAdapter* cuda_adapter = nullptr) {
return cutlass::Status::kSuccess;
}
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_C_load_needed() const {
return false;
}
CUTLASS_DEVICE bool
is_zero() const {
return false;
}
CUTLASS_HOST_DEVICE
Sm90OuterProduct() { }
CUTLASS_HOST_DEVICE
Sm90OuterProduct(Params const& params, SharedStorage const& shared_storage)
: params(params) { }
Params params;
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
return EmptyProducerLoadCallbacks{};
}
template<
class GTensorCol, class RTensorCol,
class GTensorRow, class RTensorRow
>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(GTensorCol&& tCgCol, RTensorCol&& tCrCol,
GTensorRow&& tCgRow, RTensorRow&& tCrRow,
Params const& params)
: tCgCol(cute::forward<GTensorCol>(tCgCol))
, tCrCol(cute::forward<RTensorCol>(tCrCol))
, tCgRow(cute::forward<GTensorRow>(tCgRow))
, tCrRow(cute::forward<RTensorRow>(tCrRow))
, params(params) {}
GTensorCol tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensorCol tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
GTensorRow tCgRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensorRow tCrRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;
CUTLASS_DEVICE void
begin() {
// Filter so we don't issue redundant copies over stride-0 modes
copy(filter(tCgCol), filter(tCrCol));
copy(filter(tCgRow), filter(tCrRow));
}
template <typename ElementAccumulator, int FragmentSize>
CUTLASS_DEVICE Array<ElementScalar, FragmentSize>
visit(Array<ElementAccumulator, FragmentSize> const& frg_acc, int epi_v, int epi_m, int epi_n) {
Array<ElementScalar, FragmentSize> frg_colrow;
Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n);
Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_colrow[i] = static_cast<ElementScalar>(tCrCol_mn(epi_v * FragmentSize + i) * tCrRow_mn(epi_v * FragmentSize + i));
}
return frg_colrow;
}
};
template <
bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy
class... Args
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol);
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor tCgCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCgRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
mRow, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Tensor tCrRow = make_tensor_like(tCgRow); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return ConsumerStoreCallbacks<
decltype(tCgCol), decltype(tCrCol),
decltype(tCgRow), decltype(tCrRow)
>(
cute::move(tCgCol), cute::move(tCrCol),
cute::move(tCgRow), cute::move(tCrRow),
params
);
}
};
} // namespace cutlass::epilogue::fusion

View File

@@ -7073,6 +7073,12 @@
dispatch:
CUDA: _scaled_mm_out_cuda
- func: _scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor
variants: function
dispatch:
CUDA: _scaled_grouped_mm_cuda
# NOTE [ Sparse: autograd and API ]
#
#

View File

@@ -98,6 +98,26 @@ if(INTERN_BUILD_ATEN_OPS)
endif()
list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
set(ROWWISE_SCALED_MM_FILE "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu")
# Get existing arch flags
torch_cuda_get_nvcc_gencode_flag(EXISTING_ARCH_FLAGS)
# Check NVCC version and existing arch flags
set(ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "")
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.0)
if(EXISTING_ARCH_FLAGS MATCHES ".*compute_86.*")
list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_89,code=sm_89")
endif()
if(EXISTING_ARCH_FLAGS MATCHES ".*compute_90.*")
list(APPEND ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
endif()
endif()
list(JOIN ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS " " ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS)
set_source_files_properties(${ROWWISE_SCALED_MM_FILE} PROPERTIES COMPILE_FLAGS "${ROWWISE_SCALED_MM_FILE_COMPILE_FLAGS}")
endif()
set(GEN_ROCM_FLAG)

View File

@@ -518,6 +518,7 @@ aten::_scaled_dot_product_flash_attention_backward
aten::_scaled_dot_product_flash_attention_for_cpu_backward
aten::_scaled_dot_product_fused_attention_overrideable
aten::_scaled_dot_product_fused_attention_overrideable_backward
aten::_scaled_grouped_mm
aten::_scaled_mm
aten::_scaled_mm.out
aten::_segment_reduce_backward

View File

@@ -1165,6 +1165,117 @@ class TestFP8MatmulCuda(TestCase):
out_dtype=torch.bfloat16,
)
def grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist, use_fast_accum):
for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
self.assertEqual(out, out_ref)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
def test_grouped_gemm_2d_2d(self, fast_accum, strided):
device = "cuda"
m, n, k, n_groups = 16, 16, 16, 4 # all sizes have to be divisible by 16
a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
out = torch._scaled_grouped_mm(a, b.t(), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
offs_cpu = offs.cpu()
alist, blist, ascalelist, bscalelist = [], [], [], []
start = 0
for i in range(n_groups):
alist.append(a[:, start:offs_cpu[i]])
blist.append(b[:, start:offs_cpu[i]])
ascalelist.append(scale_a[i * m : (i + 1) * m])
bscalelist.append(scale_b[i * n : (i + 1) * n])
start = offs_cpu[i]
self.grouped_mm_helper(alist, blist, ascalelist, bscalelist, out, fast_accum)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
def test_grouped_gemm_2d_3d(self, fast_accum, strided):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 16, 4
a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
offs_cpu = offs.cpu()
alist, ascalelist, outlist = [], [], []
start = 0
for i in range(n_groups):
alist.append(a[start:offs_cpu[i]])
ascalelist.append(scale_a[start:offs_cpu[i]])
outlist.append(out[start:offs_cpu[i]])
start = offs_cpu[i]
self.grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
def test_grouped_gemm_3d_3d(self, fast_accum, strided):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 16, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
self.grouped_mm_helper(a, b, scale_a, scale_b, out, fast_accum)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(not SM90OrLater, "Grouped gemm supported on SM90")
@parametrize("fast_accum", [False, True])
@parametrize("strided", [False, True])
def test_grouped_gemm_3d_2d(self, fast_accum, strided):
device = "cuda"
s_int = int(strided)
m, n, k, n_groups = 16, 32, 16, 4
a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
self.assertTrue(a.is_contiguous() is not strided)
self.assertTrue(b.is_contiguous() is not strided)
scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
out = torch._scaled_grouped_mm(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
offs_cpu = offs.cpu()
blist, bscalelist, outlist = [], [], []
start = 0
for i in range(n_groups):
blist.append(b[start:offs_cpu[i]])
bscalelist.append(scale_b[start:offs_cpu[i]])
outlist.append(out[:, start:offs_cpu[i]])
start = offs_cpu[i]
self.grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")