Update cutlass from 3.3.0 to 3.4.1 (#120434)

### COPY OF https://github.com/pytorch/pytorch/pull/120010

### Update
I have rolled the two blocking changes into this PR. I also imported it to fbcode to verify that nothing breaks:
D53870253

This copy was generated by merging all the internal-only changes into one atomic commit and re-exporting it to GitHub.

### Current Status
- [PR](https://github.com/pytorch/pytorch/pull/118935) aims to update the flash attention kernels to a more recent version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120434
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
Author: Driss Guessous
Date: 2024-02-23 03:57:26 +00:00
Committed by: PyTorch MergeBot
Parent: f609f2050f
Commit: 36c1cc962a
6 changed files with 266 additions and 149 deletions

View File

@@ -3,31 +3,36 @@
 #include <ATen/cuda/CUDAUtils.h>
 #include <ATen/Dispatch.h>
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+#else
 #include <cuda_runtime.h>
 #include <cutlass/cutlass.h>
 #include <cutlass/layout/layout.h>
 #include <cutlass/tensor_ref.h>
-#include <cutlass/epilogue/thread/linear_combination.h>
-#include <cutlass/epilogue/thread/linear_combination_relu.h>
-#include <cutlass/epilogue/thread/linear_combination_silu.h>
-#include <cutlass/gemm/device/gemm_sparse_row_broadcast.h>
+#include <cutlass/gemm/device/gemm_sparse_with_visitor.h>
+#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
 #endif
 #include <type_traits>
 #include <tuple>
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+#else
 #define CUTLASS_STATUS_CHECK(status)                                    \
   {                                                                     \
     TORCH_CHECK(status == cutlass::Status::kSuccess,                    \
                 "Got CUTLASS error: ", cutlassGetStatusString(status)); \
   }
+namespace {
+  enum class Activation{NONE, RELU, SILU};
+}
 #endif
 namespace at::native {
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+#else
 // Wrapper function for CUTLASS sparse GEMM implementation, used
 // solely to simplify dispatching from
 // _sparse_semi_structured_linear() function below.
@@ -36,14 +41,14 @@ template <
     typename ElementInputB,
     typename ElementOutput,
     typename ElementAccumulator,
-    typename ElementComputeEpilogue,
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
-    typename EpilogueOp,
     typename LayoutInputA,
-    typename LayoutInputB>
-Tensor two_four_sgemm_cutlass(
+    typename LayoutInputB,
+    bool use_bias,
+    Activation activation>
+Tensor two_four_sgemm(
     const Tensor& tensor_a,
     const at::IntArrayRef::value_type& tensor_a_stride,
     const Tensor& tensor_b,
@@ -57,22 +62,107 @@ Tensor two_four_sgemm_cutlass(
     using SmArch = cutlass::arch::Sm80;  // Only CC 8.x devices are supported at the moment.
     using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // This choice provides good performance across wide range of operand sizes.
     constexpr int NumStages = 3;  // This choice provides good performance across wide range of operand sizes.
-    using Gemm = cutlass::gemm::device::SparseGemmRowBroadcast<
+    using Operator = cutlass::arch::OpMultiplyAdd;
+    constexpr int NumEVTEpilogueStages = 1;
+    constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
+    constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
+    constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
+    using ElementComputeEpilogue = ElementAccumulator;
+    constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
+    using ElementC = ElementOutput;
+    using LayoutC = LayoutOutput;
+    constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+    using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
+        ThreadblockShape,
+        WarpShape,
+        ElementC,
+        AlignmentC,
+        NumEVTEpilogueStages>;
+    using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
+        ThreadblockShape,
+        WarpShape,
+        ElementOutput,
+        AlignmentOutput,
+        NumEVTEpilogueStages>;
+    using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+    using BiasScalar =
+        cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
+    using BiasTensor =
+        cutlass::epilogue::threadblock::VisitorColBroadcast<
+            BiasTileThreadMap,
+            ElementC,
+            cute::Stride<cute::_1, cute::_0, int64_t>>;
+    using Bias = std::conditional_t<use_bias, BiasTensor, BiasScalar>;
+    using BiasArguments = typename Bias::Arguments;
+    using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute<
+        cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
+        cutlass::FloatRoundStyle::round_to_nearest>;
+    using EVTApplyBias = cutlass::epilogue::threadblock::Sm80EVT<
+        ApplyBias,
+        Accum,
+        Bias>;
+    using ApplyActivationNone = cutlass::epilogue::threadblock::VisitorCompute<
+        cutlass::epilogue::thread::Identity,
+        ElementComputeEpilogue,
+        ElementComputeEpilogue,
+        cutlass::FloatRoundStyle::round_to_nearest>;
+    using ApplyActivationReLu = cutlass::epilogue::threadblock::VisitorCompute<
+        cutlass::epilogue::thread::ReLu,
+        ElementComputeEpilogue,
+        ElementComputeEpilogue,
+        cutlass::FloatRoundStyle::round_to_nearest>;
+    using ApplyActivationSiLu = cutlass::epilogue::threadblock::VisitorCompute<
+        cutlass::epilogue::thread::SiLu,
+        ElementComputeEpilogue,
+        ElementComputeEpilogue,
+        cutlass::FloatRoundStyle::round_to_nearest>;
+    using ApplyActivation =
+        std::conditional_t<
+            activation == Activation::NONE,
+            ApplyActivationNone,
+            std::conditional_t<
+                activation == Activation::RELU,
+                ApplyActivationReLu,
+                ApplyActivationSiLu>>;
+    using EVTApplyActivation = cutlass::epilogue::threadblock::Sm80EVT<
+        ApplyActivation,
+        EVTApplyBias>;
+    using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
+        OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
+        cute::Stride<int64_t, cute::_1, int64_t>>;
+    using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
+        Output,
+        EVTApplyActivation>;
+    using Gemm = cutlass::gemm::device::SparseGemmWithVisitor<
         ElementInputA,
         LayoutInputA,
         ElementInputB,
         LayoutInputB,
-        ElementOutput,
-        LayoutOutput,
+        ElementC,
+        LayoutC,
         ElementAccumulator,
         MMAOp,
         SmArch,
         ThreadblockShape,
         WarpShape,
         InstructionShape,
-        EpilogueOp,
+        EVTOutput,
         SwizzleThreadBlock,
-        NumStages>;
+        NumStages,
+        AlignmentInputA,
+        AlignmentInputB,
+        Operator,
+        NumEVTEpilogueStages>;
     // Datatype and layout of metadata matrix are inferred from sparse
     // GEMM template.
@@ -105,11 +195,11 @@ Tensor two_four_sgemm_cutlass(
         meta_dtype = at::kInt;
         break;
     default:
-        AT_ERROR("two_four_sgemm_cutlass: invalid size of meta tensor datatype "
+        AT_ERROR("two_four_sgemm: invalid size of meta tensor datatype "
                  "encountered");
     }
     TORCH_CHECK(meta.dtype() == meta_dtype,
-                "two_four_sgemm_cutlass: Expected meta datatype ", meta_dtype,
+                "two_four_sgemm: Expected meta datatype ", meta_dtype,
                 ", but got ", meta.dtype());
     // Determine PyTorch datatype for the output matrix.
@@ -125,64 +215,69 @@ Tensor two_four_sgemm_cutlass(
     } else if constexpr (std::is_same_v<ElementOutput, float>) {
         tensor_d_dtype = at::kFloat;
     } else {
-        AT_ERROR("two_four_sgemm_cutlass: invalid datatype for sparse GEMM ",
-                 " output encountered");
+        AT_ERROR("two_four_sgemm: invalid datatype for sparse GEMM output ",
+                 "encountered");
     }
-    if (tensor_c.numel() != 0) {
+    if constexpr (use_bias) {
         TORCH_CHECK(tensor_c.dtype() == tensor_d_dtype,
-                    "two_four_sgemm_cutlass: Expected sparse GEMM bias "
-                    "datatype ", tensor_d_dtype, ", but got ",
-                    tensor_c.dtype());
+                    "two_four_sgemm: Expected sparse GEMM bias datatype ",
+                    tensor_d_dtype, ", but got ", tensor_c.dtype());
     }
     // Create output matrix.
-    Tensor tensor_d;
-    if (tensor_c.numel() != 0) {
-        tensor_d = tensor_c.new_empty({length_m, length_n});
-    } else {
-        tensor_d =
-            tensor_a.new_empty({length_m, length_n},
-                               at::TensorOptions().dtype(tensor_d_dtype));
-    }
+    Tensor tensor_d =
+        tensor_a.new_empty({length_m, length_n},
+                           at::TensorOptions().dtype(tensor_d_dtype));
     // Prepare arguments for CUTLASS sparse GEMM kernel.
     cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
     LayoutInputA layout_a(tensor_a_stride);
     LayoutInputB layout_b(tensor_b_stride);
-    LayoutOutput layout_c(tensor_c.numel() != 0 ? tensor_c.stride(0) : 0);
-    LayoutOutput layout_d(tensor_d.stride(0));
     auto tensor_a_device_ref =
         cutlass::TensorRef<ElementInputA, LayoutInputA>(
             (ElementInputA*)tensor_a.data_ptr(), layout_a);
     auto tensor_b_device_ref =
         cutlass::TensorRef<ElementInputB, LayoutInputB>(
             (ElementInputB*)tensor_b.data_ptr(), layout_b);
-    auto tensor_c_device_ref =
-        cutlass::TensorRef<ElementOutput, LayoutOutput>(
-            (ElementOutput*)(tensor_c.numel() != 0 ?
-                             tensor_c.data_ptr() : tensor_d.data_ptr()),
-            layout_c);
-    auto tensor_d_device_ref =
-        cutlass::TensorRef<ElementOutput, LayoutOutput>(
-            (ElementOutput*)tensor_d.data_ptr(), layout_d);
     auto tensor_e_reordered_device_ref =
         cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
             (ElementInputE*)meta.data_ptr(),
             ReorderedLayoutInputE::packed({length_m, meta_ncols}));
-    ElementComputeEpilogue alpha(1);
-    ElementComputeEpilogue beta(tensor_c.numel() != 0 ? 1 : 0);
-    constexpr int split_k_slices = 1;
+    BiasArguments bias_arguments{
+        [&]() -> BiasArguments {
+            if constexpr (use_bias) {
+                return {(ElementC*)tensor_c.data_ptr(),
+                        ElementC(0),
+                        {cute::_1{}, cute::_0{}, problem_size.m()}};
+            } else {
+                return {ElementC(0)};
+            }
+        }()
+    };
+    typename Output::Arguments output_arguments{
+        (ElementOutput*)tensor_d.data_ptr(),
+        {problem_size.n(), cute::_1{}, problem_size.mn().product()}
+    };
+    typename EVTOutput::Arguments callback_arguments{
+        {
+            {
+                {},              // Accum
+                bias_arguments,  // Bias
+                {}               // ApplyBias
+            },                   // EVTApplyBias
+            {}                   // ApplyActivation
+        },                       // EVTApplyActivation
+        output_arguments,        // Output
+    };                           // EVTOutput
     // Create a tuple of CUTLASS sparse GEMM kernel arguments.
     typename Gemm::Arguments arguments{
         problem_size,
         tensor_a_device_ref,
         tensor_b_device_ref,
-        tensor_c_device_ref,
-        tensor_d_device_ref,
         tensor_e_reordered_device_ref,
-        {alpha, beta},
-        split_k_slices};
+        callback_arguments};
     cutlass::Status status;
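The `bias_arguments` initializer above uses an immediately invoked lambda so a single aggregate can be built from an `if constexpr` branch (a real bias pointer when `use_bias` is true, a zero scalar otherwise). A small sketch of that idiom with hypothetical names, assuming nothing about the real `BiasArguments` layout:

```cpp
// Illustrative only: initializing an aggregate from an if-constexpr branch via
// an immediately invoked lambda, as done for BiasArguments in the diff.
#include <cstdio>

struct BiasArgs {
    const float* ptr;  // column-broadcast bias pointer (may be null)
    float scalar;      // scalar broadcast fallback
};

template <bool use_bias>
BiasArgs make_bias_args(const float* bias) {
    const BiasArgs args = [&]() -> BiasArgs {
        if constexpr (use_bias) {
            return {bias, 0.0f};     // broadcast the bias tensor
        } else {
            return {nullptr, 0.0f};  // broadcast a zero scalar instead
        }
    }();
    return args;
}

int main() {
    float bias = 0.5f;
    std::printf("%d %d\n",
                make_bias_args<true>(&bias).ptr != nullptr,     // 1
                make_bias_args<false>(nullptr).ptr != nullptr); // 0
}
```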
@@ -219,16 +314,16 @@ template <
     typename ElementInputB,
     typename ElementOutput,
     typename ElementAccumulator,
-    typename ElementComputeEpilogue,
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
-    typename EpilogueOp,
     bool EnableRowMajorRowMajorLayouts,
     bool EnableRowMajorColumnMajorLayouts,
     bool EnableColumnMajorRowMajorLayouts,
-    bool EnableColumnMajorColumnMajorLayouts>
-Tensor two_four_sgemm_cutlass_dispatch_layouts(
+    bool EnableColumnMajorColumnMajorLayouts,
+    bool use_bias,
+    Activation activation>
+Tensor two_four_sgemm_dispatch_layouts(
     const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
     const Tensor& meta) {
     // Determine layouts (row-major or column-major) of input tensors.
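The dispatch function first classifies each operand as row- or column-major before picking one of the four `Enable*Layouts` branches shown in the next hunks. A rough sketch of how such a check can be derived from a 2-D tensor's sizes and strides (this is not the exact helper used in the file):

```cpp
// Illustrative only: deciding row- vs. column-major from the sizes and strides
// of a 2-D operand, which is what drives the layout dispatch branches below.
#include <array>
#include <cstdio>

bool is_row_major(std::array<long, 2> sizes, std::array<long, 2> strides) {
    // Row-major: the last dimension is contiguous and consecutive rows are at
    // least one full row apart.
    return strides[1] == 1 && strides[0] >= sizes[1];
}

int main() {
    std::printf("%d %d\n",
                is_row_major({4, 8}, {8, 1}),   // contiguous row-major -> 1
                is_row_major({4, 8}, {1, 4}));  // column-major -> 0
}
```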
@@ -242,18 +337,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
     // Perform dispatching.
     if constexpr (EnableRowMajorRowMajorLayouts) {
         if (tensor_a_row_major && tensor_b_row_major) {
-            return two_four_sgemm_cutlass<
+            return two_four_sgemm<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 cutlass::layout::RowMajor,
-                cutlass::layout::RowMajor>(
+                cutlass::layout::RowMajor,
+                use_bias,
+                activation>(
                 tensor_a,
                 tensor_a_stride,
                 tensor_b,
@@ -264,18 +359,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
     }
     if constexpr (EnableRowMajorColumnMajorLayouts) {
         if (tensor_a_row_major && !tensor_b_row_major) {
-            return two_four_sgemm_cutlass<
+            return two_four_sgemm<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 cutlass::layout::RowMajor,
-                cutlass::layout::ColumnMajor>(
+                cutlass::layout::ColumnMajor,
+                use_bias,
+                activation>(
                 tensor_a,
                 tensor_a_stride,
                 tensor_b,
@@ -286,18 +381,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
     }
     if constexpr (EnableColumnMajorRowMajorLayouts) {
        if (!tensor_a_row_major && tensor_b_row_major) {
-            return two_four_sgemm_cutlass<
+            return two_four_sgemm<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 cutlass::layout::ColumnMajor,
-                cutlass::layout::RowMajor>(
+                cutlass::layout::RowMajor,
+                use_bias,
+                activation>(
                 tensor_a,
                 tensor_a_stride,
                 tensor_b,
@@ -308,18 +403,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
     }
     if constexpr (EnableColumnMajorColumnMajorLayouts) {
        if (!tensor_a_row_major && !tensor_b_row_major) {
-            return two_four_sgemm_cutlass<
+            return two_four_sgemm<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 cutlass::layout::ColumnMajor,
-                cutlass::layout::ColumnMajor>(
+                cutlass::layout::ColumnMajor,
+                use_bias,
+                activation>(
                 tensor_a,
                 tensor_a_stride,
                 tensor_b,
@@ -329,20 +424,77 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
         }
     }
-    AT_ERROR("two_four_sgemm_cutlass_dispatch_layouts: Combination of ",
+    AT_ERROR("two_four_sgemm_dispatch_layouts: Combination of ",
             tensor_a_row_major ? "row-major" : "column_major", " and ",
             tensor_b_row_major ? "row-major" : "column_major",
             " layouts for input tensors is not supported");
     return Tensor{};
 }
+// Dispatch according to the bias tensor being provided or not.
+template <
+    typename ElementInputA,
+    typename ElementInputB,
+    typename ElementOutput,
+    typename ElementAccumulator,
+    typename ThreadblockShape,
+    typename WarpShape,
+    typename InstructionShape,
+    bool EnableRowMajorRowMajorLayouts,
+    bool EnableRowMajorColumnMajorLayouts,
+    bool EnableColumnMajorRowMajorLayouts,
+    bool EnableColumnMajorColumnMajorLayouts,
+    Activation activation>
+Tensor two_four_sgemm_dispatch_layouts_bias(
+    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
+    const Tensor& meta) {
+    if (tensor_c.numel() > 0) {
+        return two_four_sgemm_dispatch_layouts<
+            ElementInputA,
+            ElementInputB,
+            ElementOutput,
+            ElementAccumulator,
+            ThreadblockShape,
+            WarpShape,
+            InstructionShape,
+            EnableRowMajorRowMajorLayouts,
+            EnableRowMajorColumnMajorLayouts,
+            EnableColumnMajorRowMajorLayouts,
+            EnableColumnMajorColumnMajorLayouts,
+            true,
+            activation>(
+            tensor_a,
+            tensor_b,
+            tensor_c,
+            meta);
+    } else {
+        return two_four_sgemm_dispatch_layouts<
+            ElementInputA,
+            ElementInputB,
+            ElementOutput,
+            ElementAccumulator,
+            ThreadblockShape,
+            WarpShape,
+            InstructionShape,
+            EnableRowMajorRowMajorLayouts,
+            EnableRowMajorColumnMajorLayouts,
+            EnableColumnMajorRowMajorLayouts,
+            EnableColumnMajorColumnMajorLayouts,
+            false,
+            activation>(
+            tensor_a,
+            tensor_b,
+            tensor_c,
+            meta);
+    }
+}
 // Dispatch according to the activation functions enabled.
 template <
     typename ElementInputA,
     typename ElementInputB,
     typename ElementOutput,
     typename ElementAccumulator,
-    typename ElementComputeEpilogue,
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
@@ -353,32 +505,25 @@ template <
     bool EnableActivationNone,
     bool EnableActivationReLU,
     bool EnableActivationSiLU>
-Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
+Tensor two_four_sgemm_dispatch_layouts_bias_activation(
     const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
     const Tensor& meta, const c10::string_view& activation) {
     // Perform dispatching.
     if constexpr (EnableActivationNone) {
         if (activation == "none") {
-            using EpilogueOp =
-                cutlass::epilogue::thread::LinearCombination<
-                    ElementOutput,
-                    128 / cutlass::sizeof_bits<ElementOutput>::value,
-                    ElementAccumulator,
-                    ElementComputeEpilogue>;
-            return two_four_sgemm_cutlass_dispatch_layouts<
+            return two_four_sgemm_dispatch_layouts_bias<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 EnableRowMajorRowMajorLayouts,
                 EnableRowMajorColumnMajorLayouts,
                 EnableColumnMajorRowMajorLayouts,
-                EnableColumnMajorColumnMajorLayouts>(
+                EnableColumnMajorColumnMajorLayouts,
+                Activation::NONE>(
                 tensor_a,
                 tensor_b,
                 tensor_c,
@@ -387,26 +532,19 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
     }
     if constexpr (EnableActivationReLU) {
         if (activation == "relu") {
-            using EpilogueOp =
-                cutlass::epilogue::thread::LinearCombinationRelu<
-                    ElementOutput,
-                    128 / cutlass::sizeof_bits<ElementOutput>::value,
-                    ElementAccumulator,
-                    ElementComputeEpilogue>;
-            return two_four_sgemm_cutlass_dispatch_layouts<
+            return two_four_sgemm_dispatch_layouts_bias<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 EnableRowMajorRowMajorLayouts,
                 EnableRowMajorColumnMajorLayouts,
                 EnableColumnMajorRowMajorLayouts,
-                EnableColumnMajorColumnMajorLayouts>(
+                EnableColumnMajorColumnMajorLayouts,
+                Activation::RELU>(
                 tensor_a,
                 tensor_b,
                 tensor_c,
@@ -415,26 +553,19 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
     }
     if constexpr (EnableActivationSiLU) {
         if (activation == "silu") {
-            using EpilogueOp =
-                cutlass::epilogue::thread::LinearCombinationSilu<
-                    ElementOutput,
-                    128 / cutlass::sizeof_bits<ElementOutput>::value,
-                    ElementAccumulator,
-                    ElementComputeEpilogue>;
-            return two_four_sgemm_cutlass_dispatch_layouts<
+            return two_four_sgemm_dispatch_layouts_bias<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 EnableRowMajorRowMajorLayouts,
                 EnableRowMajorColumnMajorLayouts,
                 EnableColumnMajorRowMajorLayouts,
-                EnableColumnMajorColumnMajorLayouts>(
+                EnableColumnMajorColumnMajorLayouts,
+                Activation::SILU>(
                 tensor_a,
                 tensor_b,
                 tensor_c,
@@ -442,8 +573,8 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
         }
     }
-    AT_ERROR("two_four_sgemm_cutlass_dispatch_layouts: Activation \"",
-             activation, "\" is not supported for given input tensors");
+    AT_ERROR("two_four_sgemm_dispatch_layouts: Activation \"", activation,
+             "\" is not supported for given input tensors");
     return Tensor{};
 }
 #endif
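The chain of dispatch helpers shown above (`..._dispatch_layouts` → `..._dispatch_layouts_bias` → `..._dispatch_layouts_bias_activation`) converts runtime properties (operand layouts, presence of a bias tensor, activation string) into compile-time template parameters one layer at a time. A minimal sketch of the bias step, with hypothetical names:

```cpp
// Illustrative only: turning a runtime condition (is a bias present?) into a
// compile-time bool template parameter, the pattern used by
// two_four_sgemm_dispatch_layouts_bias in the diff.
#include <cstdio>

template <bool use_bias>
float linear(float x, float w, float b) {
    if constexpr (use_bias) {
        return x * w + b;
    } else {
        return x * w;  // the bias path is compiled out entirely
    }
}

float dispatch_linear(float x, float w, const float* bias) {
    // Branch once at runtime; everything downstream is then specialized.
    return bias != nullptr ? linear<true>(x, w, *bias)
                           : linear<false>(x, w, 0.0f);
}

int main() {
    float b = 1.0f;
    std::printf("%g %g\n",
                dispatch_linear(2.0f, 3.0f, &b),       // 7
                dispatch_linear(2.0f, 3.0f, nullptr)); // 6
}
```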
@@ -472,7 +603,10 @@ Tensor _sparse_semi_structured_linear(
       const Tensor& meta, const c10::optional<Tensor>& bias_opt,
       const c10::optional<c10::string_view> activation_opt,
       const c10::optional<c10::ScalarType> out_dtype_opt) {
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+    AT_ERROR("_sparse_semi_structured_linear: CUTLASS not supported");
+    return Tensor{};
+#else
     // No need to check that all tensors are on CUDA device, as this
     // is provided by dispatch.
@@ -575,7 +709,6 @@ Tensor _sparse_semi_structured_linear(
         using ElementInputA = int8_t;
         using ElementInputB = int8_t;
         using ElementAccumulator = int32_t;
-        using ElementComputeEpilogue = int32_t;
         using ThreadblockShape =
             cutlass::gemm::GemmShape<128, 128, 128>;
         using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
@@ -589,12 +722,11 @@ Tensor _sparse_semi_structured_linear(
         const auto EnableActivationSiLU = false;
         if (out_dtype_opt.has_value()) {
             using ElementOutput = int32_t;
-            output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+            output = two_four_sgemm_dispatch_layouts_bias_activation<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
@@ -612,12 +744,11 @@ Tensor _sparse_semi_structured_linear(
                 activation);
         } else {
             using ElementOutput = int8_t;
-            output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+            output = two_four_sgemm_dispatch_layouts_bias_activation<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
@@ -643,7 +774,6 @@ Tensor _sparse_semi_structured_linear(
         using ElementInputB = cutlass::half_t;
         using ElementOutput = cutlass::half_t;
         using ElementAccumulator = float;
-        using ElementComputeEpilogue = float;
         using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
         using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
         using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
@@ -654,12 +784,11 @@ Tensor _sparse_semi_structured_linear(
         const auto EnableActivationNone = true;
         const auto EnableActivationReLU = true;
         const auto EnableActivationSiLU = true;
-        output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+        output = two_four_sgemm_dispatch_layouts_bias_activation<
             ElementInputA,
             ElementInputB,
             ElementOutput,
             ElementAccumulator,
-            ElementComputeEpilogue,
             ThreadblockShape,
             WarpShape,
             InstructionShape,
@@ -684,7 +813,6 @@ Tensor _sparse_semi_structured_linear(
         using ElementInputB = cutlass::bfloat16_t;
         using ElementOutput = cutlass::bfloat16_t;
         using ElementAccumulator = float;
-        using ElementComputeEpilogue = float;
         using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
         using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
         using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
@@ -695,12 +823,11 @@ Tensor _sparse_semi_structured_linear(
         const auto EnableActivationNone = true;
         const auto EnableActivationReLU = true;
         const auto EnableActivationSiLU = true;
-        output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+        output = two_four_sgemm_dispatch_layouts_bias_activation<
             ElementInputA,
             ElementInputB,
             ElementOutput,
             ElementAccumulator,
-            ElementComputeEpilogue,
             ThreadblockShape,
             WarpShape,
             InstructionShape,
@@ -725,7 +852,6 @@ Tensor _sparse_semi_structured_linear(
         using ElementInputB = float;
         using ElementOutput = float;
         using ElementAccumulator = float;
-        using ElementComputeEpilogue = float;
         using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
         using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
         using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@@ -736,12 +862,11 @@ Tensor _sparse_semi_structured_linear(
         const auto EnableActivationNone = true;
         const auto EnableActivationReLU = true;
         const auto EnableActivationSiLU = true;
-        output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+        output = two_four_sgemm_dispatch_layouts_bias_activation<
            ElementInputA,
             ElementInputB,
             ElementOutput,
             ElementAccumulator,
-            ElementComputeEpilogue,
             ThreadblockShape,
             WarpShape,
             InstructionShape,
@@ -764,9 +889,6 @@ Tensor _sparse_semi_structured_linear(
     auto output_sizes = input_sizes;
     output_sizes.back() = weight.size(0);
     return output.transpose(-1, -2).reshape(output_sizes);
-#else
-    AT_ERROR("_sparse_semi_structured_linear: ROCm doesn't support CUTLASS");
-    return Tensor{};
 #endif
 }
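With this change, `_sparse_semi_structured_linear` is guarded up front: on ROCm, MSVC, or CUDA older than 11.8 the whole body compiles to an error stub, instead of the old trailing `#else` branch that only handled ROCm. A schematic of that guard pattern (placeholder names, not the ATen API):

```cpp
// Illustrative only: compiling an operator as an error stub when the CUTLASS
// backend cannot be built, mirroring the guard added in this file.
#include <stdexcept>

#if defined(USE_ROCM) || defined(_MSC_VER) || \
    (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
float sparse_linear(float /*x*/) {
    // Keep the symbol defined so callers still link; fail loudly at runtime.
    throw std::runtime_error("sparse_linear: CUTLASS not supported");
}
#else
float sparse_linear(float x) {
    return x;  // stand-in for the CUTLASS-backed implementation
}
#endif

int main() {
    try {
        return static_cast<int>(sparse_linear(0.0f));
    } catch (const std::exception&) {
        return 1;
    }
}
```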
@@ -775,7 +897,8 @@ Tensor _sparse_semi_structured_linear(
 // Following is just for testing purposes.
 namespace at::native {
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+#else
 // Copied from tools/util/include/host_reorder.h, from CUTLASS source
 // tree. This is for simplicity - namely, this file is not under
 // include/cutlass in this tree, as other CUTLASS include files
@@ -812,7 +935,10 @@ static void reorder_meta(cutlass::TensorRef<Element, LayoutDest> dest,
 std::tuple<Tensor, Tensor>
 _to_sparse_semi_structured(const Tensor& dense) {
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+    AT_ERROR("_to_sparse_semi_structured: CUTLASS not supported");
+    return std::make_tuple(Tensor{}, Tensor{});
+#else
     // Check dimensions of the dense matrix.
     TORCH_CHECK(dense.dim() == 2,
                 "_to_sparse_semi_structured: Expected dense argument to be 2D "
@@ -937,9 +1063,6 @@ _to_sparse_semi_structured(const Tensor& dense) {
     return std::make_tuple(sparse_cpu.to(dense.device()),
                            meta_reordered_cpu.to(dense.device()));
-#else
-    AT_ERROR("_to_sparse_semi_structured: ROCm doesn't support CUTLASS");
-    return std::make_tuple(Tensor{}, Tensor{});
 #endif
 }

View File

@@ -33,11 +33,12 @@ CUTE_HOST_DEVICE
 auto
 make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
                                   TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
+    constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
+    constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value;
     using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
     constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
     // Divide by 2 because right now we always use 2 for the ValLayout
-    constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2;
+    constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
     constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
     // This gives the correct layout, idk why.
     // auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
@@ -46,7 +47,7 @@ make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
     //                    Stride<_1, _64, _8> >{},
     auto t = make_tile(Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>,   // (8, 2, 2) or (8, 4, 2)
                               Stride<_1, Int<MMAStride_N>, _8> >{},         // (1, 64, 8) or (1, 32, 8)
-                       make_layout(size<2>(TileShape_MNK{})));
+                       make_layout(Int<TileShape_K>{}));
     // if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
     return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t);
 }
@@ -60,13 +61,13 @@ CUTE_HOST_DEVICE
 auto
 make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
                                   TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
+    constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value;
+    constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
     using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
     constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
     // Divide by 2 because right now we always use 2 for the ValLayout
-    constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2;
-    constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
-    auto t = make_tile(make_layout(size<0>(TileShape_MNK{})),
+    constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
+    auto t = make_tile(make_layout(Int<TileShape_M>{}),
                        Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>,    // (8, 2, 2) or (8, 4, 2)
                               Stride<_1, Int<MMAStride_N>, _8> >{});         // (1, 64, 8) or (1, 32, 8)
     // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
@@ -444,8 +445,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
     // constexpr int kNWarps = Kernel_traits::kNWarps;
-    constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value;
-    constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
+    constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
     constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);

View File

@@ -92,7 +92,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
     constexpr int kNWarps = Kernel_traits::kNWarps;
-    constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

View File

@@ -34,10 +34,8 @@ struct Flash_kernel_traits {
         MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
         MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
     >;
-    using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
 #else
     using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
-    using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
 #endif
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
@@ -78,7 +76,7 @@ struct Flash_fwd_kernel_traits : public Base {
     using TiledMma = TiledMMA<
         typename Base::MMA_Atom_Arch,
         Layout<Shape<Int<kNWarps>,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
-        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+        Tile<Int<16 * kNWarps>, _16, _16>>;
     using SmemLayoutAtomQ = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
@@ -220,18 +218,15 @@ struct Flash_bwd_kernel_traits : public Base {
     using TiledMmaSdP = TiledMMA<
         typename Base::MMA_Atom_Arch,
         Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
-        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+        Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
     using TiledMmadKV = TiledMMA<
         typename Base::MMA_Atom_Arch,
         Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
-        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+        Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
     using TiledMmadQ = TiledMMA<
         typename Base::MMA_Atom_Arch,
         Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>,  // 2x4x1 or 4x2x1 thread group
-        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+        Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
     using SmemLayoutAtomQdO = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
                     Layout<Shape<_8, Int<kBlockKSmem>>,

View File

@@ -256,7 +256,7 @@ class TestSparseSemiStructured(TestCase):
         if dtype is torch.int8:
             # This should fail
             if backend == "cutlass":
-                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
+                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
                     sparse_result = torch.mm(A_sparse, B)
             else:
                 with self.assertRaisesRegex(RuntimeError,
@@ -287,7 +287,7 @@ class TestSparseSemiStructured(TestCase):
         # padding with int8 throws an error because transposing B yields a contiguous output
         # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
         if backend == "cutlass":
-            with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
+            with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
                 sparse_result = torch.mm(A_sparse, B.t())
         else:
             with self.assertRaisesRegex(RuntimeError,