Update cutlass from 3.3.0 to 3.4.1 (#120434)

### COPY OF https://github.com/pytorch/pytorch/pull/120010

### Update
I have rolled the two blocking changes into this PR. I also imported it to fbcode to verify that nothing breaks:
D53870253

This copy was generated by merging all the internal-only changes into one atomic commit and re-exporting it to GitHub.

### Current Status
- [PR #118935](https://github.com/pytorch/pytorch/pull/118935) aims to update the flash attention kernels to a more recent version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120434
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
Author: Driss Guessous
Date: 2024-02-23 03:57:26 +00:00
Committed by: PyTorch MergeBot
Parent: f609f2050f
Commit: 36c1cc962a
6 changed files with 266 additions and 149 deletions

View File

@@ -3,31 +3,36 @@
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/Dispatch.h>
#ifndef USE_ROCM
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/thread/linear_combination_relu.h>
#include <cutlass/epilogue/thread/linear_combination_silu.h>
#include <cutlass/gemm/device/gemm_sparse_row_broadcast.h>
#include <cutlass/gemm/device/gemm_sparse_with_visitor.h>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#endif
#include <type_traits>
#include <tuple>
#ifndef USE_ROCM
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#define CUTLASS_STATUS_CHECK(status) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
"Got CUTLASS error: ", cutlassGetStatusString(status)); \
}
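// Hypothetical usage sketch (this call site is not part of this diff): wrap
// any CUTLASS call that returns a cutlass::Status, e.g.
//   CUTLASS_STATUS_CHECK(gemm_op.initialize(arguments, workspace));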
namespace {
enum class Activation{NONE, RELU, SILU};
}
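// Activation is passed below as a non-type template parameter so that the
// epilogue visitor for each activation can be selected at compile time.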
#endif
namespace at::native {
#ifndef USE_ROCM
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
// Wrapper function for CUTLASS sparse GEMM implementation, used
// solely to simplify dispatching from
// _sparse_semi_structured_linear() function below.
@@ -36,14 +41,14 @@ template <
typename ElementInputB,
typename ElementOutput,
typename ElementAccumulator,
typename ElementComputeEpilogue,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOp,
typename LayoutInputA,
typename LayoutInputB>
Tensor two_four_sgemm_cutlass(
typename LayoutInputB,
bool use_bias,
Activation activation>
Tensor two_four_sgemm(
const Tensor& tensor_a,
const at::IntArrayRef::value_type& tensor_a_stride,
const Tensor& tensor_b,
@@ -57,22 +62,107 @@ Tensor two_four_sgemm_cutlass(
using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
using Gemm = cutlass::gemm::device::SparseGemmRowBroadcast<
using Operator = cutlass::arch::OpMultiplyAdd;
constexpr int NumEVTEpilogueStages = 1;
constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
using ElementComputeEpilogue = ElementAccumulator;
constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
using ElementC = ElementOutput;
using LayoutC = LayoutOutput;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
ElementC,
AlignmentC,
NumEVTEpilogueStages>;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
ElementOutput,
AlignmentOutput,
NumEVTEpilogueStages>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using BiasScalar =
cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
using BiasTensor =
cutlass::epilogue::threadblock::VisitorColBroadcast<
BiasTileThreadMap,
ElementC,
cute::Stride<cute::_1, cute::_0, int64_t>>;
using Bias = std::conditional_t<use_bias, BiasTensor, BiasScalar>;
using BiasArguments = typename Bias::Arguments;
using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTApplyBias = cutlass::epilogue::threadblock::Sm80EVT<
ApplyBias,
Accum,
Bias>;
using ApplyActivationNone = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::epilogue::thread::Identity,
ElementComputeEpilogue,
ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using ApplyActivationReLu = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::epilogue::thread::ReLu,
ElementComputeEpilogue,
ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using ApplyActivationSiLu = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::epilogue::thread::SiLu,
ElementComputeEpilogue,
ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using ApplyActivation =
std::conditional_t<
activation == Activation::NONE,
ApplyActivationNone,
std::conditional_t<
activation == Activation::RELU,
ApplyActivationReLu,
ApplyActivationSiLu>>;
using EVTApplyActivation = cutlass::epilogue::threadblock::Sm80EVT<
ApplyActivation,
EVTApplyBias>;
using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, cute::_1, int64_t>>;
using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
Output,
EVTApplyActivation>;
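// Taken together, these Sm80EVT aliases compose an epilogue visitor tree that
// is evaluated per output tile; conceptually (a sketch of the composition
// only, not actual CUTLASS evaluation code):
//
//   D = Output(ApplyActivation(ApplyBias(Accum, Bias)))
//
// i.e. d[i][j] = activation(acc[i][j] + bias[i]), with Bias degenerating to a
// broadcast scalar zero when use_bias is false.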
using Gemm = cutlass::gemm::device::SparseGemmWithVisitor<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementC,
LayoutC,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
EVTOutput,
SwizzleThreadBlock,
NumStages>;
NumStages,
AlignmentInputA,
AlignmentInputB,
Operator,
NumEVTEpilogueStages>;
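// SparseGemmWithVisitor replaces the former SparseGemmRowBroadcast kernel:
// the thread-level EpilogueOp is gone, and the epilogue is instead described
// by the EVTOutput tree above, with the operand alignments, the math Operator
// and the epilogue stage count passed explicitly.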
// Datatype and layout of metadata matrix are inferred from sparse
// GEMM template.
@@ -105,11 +195,11 @@ Tensor two_four_sgemm_cutlass(
meta_dtype = at::kInt;
break;
default:
AT_ERROR("two_four_sgemm_cutlass: invalid size of meta tensor datatype "
AT_ERROR("two_four_sgemm: invalid size of meta tensor datatype "
"encountered");
}
TORCH_CHECK(meta.dtype() == meta_dtype,
"two_four_sgemm_cutlass: Expected meta datatype ", meta_dtype,
"two_four_sgemm: Expected meta datatype ", meta_dtype,
", but got ", meta.dtype());
// Determine PyTorch datatype for the output matrix.
@@ -125,64 +215,69 @@ Tensor two_four_sgemm_cutlass(
} else if constexpr (std::is_same_v<ElementOutput, float>) {
tensor_d_dtype = at::kFloat;
} else {
AT_ERROR("two_four_sgemm_cutlass: invalid datatype for sparse GEMM ",
" output encountered");
AT_ERROR("two_four_sgemm: invalid datatype for sparse GEMM output ",
"encountered");
}
if (tensor_c.numel() != 0) {
if constexpr (use_bias) {
TORCH_CHECK(tensor_c.dtype() == tensor_d_dtype,
"two_four_sgemm_cutlass: Expected sparse GEMM bias "
"datatype ", tensor_d_dtype, ", but got ",
tensor_c.dtype());
"two_four_sgemm: Expected sparse GEMM bias datatype ",
tensor_d_dtype, ", but got ", tensor_c.dtype());
}
// Create output matrix.
Tensor tensor_d;
if (tensor_c.numel() != 0) {
tensor_d = tensor_c.new_empty({length_m, length_n});
} else {
tensor_d =
tensor_a.new_empty({length_m, length_n},
at::TensorOptions().dtype(tensor_d_dtype));
}
Tensor tensor_d =
tensor_a.new_empty({length_m, length_n},
at::TensorOptions().dtype(tensor_d_dtype));
// Prepare arguments for CUTLASS sparse GEMM kernel.
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
LayoutInputA layout_a(tensor_a_stride);
LayoutInputB layout_b(tensor_b_stride);
LayoutOutput layout_c(tensor_c.numel() != 0 ? tensor_c.stride(0) : 0);
LayoutOutput layout_d(tensor_d.stride(0));
auto tensor_a_device_ref =
cutlass::TensorRef<ElementInputA, LayoutInputA>(
(ElementInputA*)tensor_a.data_ptr(), layout_a);
auto tensor_b_device_ref =
cutlass::TensorRef<ElementInputB, LayoutInputB>(
(ElementInputB*)tensor_b.data_ptr(), layout_b);
auto tensor_c_device_ref =
cutlass::TensorRef<ElementOutput, LayoutOutput>(
(ElementOutput*)(tensor_c.numel() != 0 ?
tensor_c.data_ptr() : tensor_d.data_ptr()),
layout_c);
auto tensor_d_device_ref =
cutlass::TensorRef<ElementOutput, LayoutOutput>(
(ElementOutput*)tensor_d.data_ptr(), layout_d);
auto tensor_e_reordered_device_ref =
cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
(ElementInputE*)meta.data_ptr(),
ReorderedLayoutInputE::packed({length_m, meta_ncols}));
ElementComputeEpilogue alpha(1);
ElementComputeEpilogue beta(tensor_c.numel() != 0 ? 1 : 0);
constexpr int split_k_slices = 1;
BiasArguments bias_arguments{
[&]() -> BiasArguments {
if constexpr (use_bias) {
return {(ElementC*)tensor_c.data_ptr(),
ElementC(0),
{cute::_1{}, cute::_0{}, problem_size.m()}};
} else {
return {ElementC(0)};
}
}()
};
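// The immediately-invoked lambda above selects the bias arguments at compile
// time: a column broadcast of tensor_c when use_bias is true, otherwise a
// broadcast of the scalar zero.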
typename Output::Arguments output_arguments{
(ElementOutput*)tensor_d.data_ptr(),
{problem_size.n(), cute::_1{}, problem_size.mn().product()}
};
typename EVTOutput::Arguments callback_arguments{
{
{
{}, // Accum
bias_arguments, // Bias
{} // ApplyBias
}, // EVTApplyBias
{} // ApplyActivation
}, // EVTApplyActivation
output_arguments, // Output
}; // EVTOutput
// Create a tuple of CUTLASS sparse GEMM kernel arguments.
typename Gemm::Arguments arguments{
problem_size,
tensor_a_device_ref,
tensor_b_device_ref,
tensor_c_device_ref,
tensor_d_device_ref,
tensor_e_reordered_device_ref,
{alpha, beta},
split_k_slices};
callback_arguments};
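// With the visitor-based kernel the epilogue is driven entirely by
// callback_arguments; the separate C/D tensor references, the (alpha, beta)
// scalars and split_k_slices of the previous argument list are no longer
// needed here.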
cutlass::Status status;
@@ -219,16 +314,16 @@ template <
typename ElementInputB,
typename ElementOutput,
typename ElementAccumulator,
typename ElementComputeEpilogue,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename EpilogueOp,
bool EnableRowMajorRowMajorLayouts,
bool EnableRowMajorColumnMajorLayouts,
bool EnableColumnMajorRowMajorLayouts,
bool EnableColumnMajorColumnMajorLayouts>
Tensor two_four_sgemm_cutlass_dispatch_layouts(
bool EnableColumnMajorColumnMajorLayouts,
bool use_bias,
Activation activation>
Tensor two_four_sgemm_dispatch_layouts(
const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
const Tensor& meta) {
// Determine layouts (row-major or column-major) of input tensors.
@@ -242,18 +337,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
// Perform dispatching.
if constexpr (EnableRowMajorRowMajorLayouts) {
if (tensor_a_row_major && tensor_b_row_major) {
return two_four_sgemm_cutlass<
return two_four_sgemm<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor>(
cutlass::layout::RowMajor,
use_bias,
activation>(
tensor_a,
tensor_a_stride,
tensor_b,
@@ -264,18 +359,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
}
if constexpr (EnableRowMajorColumnMajorLayouts) {
if (tensor_a_row_major && !tensor_b_row_major) {
return two_four_sgemm_cutlass<
return two_four_sgemm<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor>(
cutlass::layout::ColumnMajor,
use_bias,
activation>(
tensor_a,
tensor_a_stride,
tensor_b,
@@ -286,18 +381,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
}
if constexpr (EnableColumnMajorRowMajorLayouts) {
if (!tensor_a_row_major && tensor_b_row_major) {
return two_four_sgemm_cutlass<
return two_four_sgemm<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor>(
cutlass::layout::RowMajor,
use_bias,
activation>(
tensor_a,
tensor_a_stride,
tensor_b,
@@ -308,18 +403,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
}
if constexpr (EnableColumnMajorColumnMajorLayouts) {
if (!tensor_a_row_major && !tensor_b_row_major) {
return two_four_sgemm_cutlass<
return two_four_sgemm<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor>(
cutlass::layout::ColumnMajor,
use_bias,
activation>(
tensor_a,
tensor_a_stride,
tensor_b,
@@ -329,20 +424,77 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
}
}
AT_ERROR("two_four_sgemm_cutlass_dispatch_layouts: Combination of ",
AT_ERROR("two_four_sgemm_dispatch_layouts: Combination of ",
tensor_a_row_major ? "row-major" : "column_major", " and ",
tensor_b_row_major ? "row-major" : "column_major",
" layouts for input tensors is not supported");
return Tensor{};
}
// Dispatch according to the bias tensor being provided or not.
template <
typename ElementInputA,
typename ElementInputB,
typename ElementOutput,
typename ElementAccumulator,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
bool EnableRowMajorRowMajorLayouts,
bool EnableRowMajorColumnMajorLayouts,
bool EnableColumnMajorRowMajorLayouts,
bool EnableColumnMajorColumnMajorLayouts,
Activation activation>
Tensor two_four_sgemm_dispatch_layouts_bias(
const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
const Tensor& meta) {
if (tensor_c.numel() > 0) {
return two_four_sgemm_dispatch_layouts<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts,
true,
activation>(
tensor_a,
tensor_b,
tensor_c,
meta);
} else {
return two_four_sgemm_dispatch_layouts<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts,
false,
activation>(
tensor_a,
tensor_b,
tensor_c,
meta);
}
}
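// This extra dispatch layer turns the runtime check tensor_c.numel() > 0 into
// the compile-time use_bias flag consumed by two_four_sgemm above.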
// Dispatch according to the activation functions enabled.
template <
typename ElementInputA,
typename ElementInputB,
typename ElementOutput,
typename ElementAccumulator,
typename ElementComputeEpilogue,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
@@ -353,32 +505,25 @@ template <
bool EnableActivationNone,
bool EnableActivationReLU,
bool EnableActivationSiLU>
Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
Tensor two_four_sgemm_dispatch_layouts_bias_activation(
const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
const Tensor& meta, const c10::string_view& activation) {
// Perform dispatching.
if constexpr (EnableActivationNone) {
if (activation == "none") {
using EpilogueOp =
cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementComputeEpilogue>;
return two_four_sgemm_cutlass_dispatch_layouts<
return two_four_sgemm_dispatch_layouts_bias<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts>(
EnableColumnMajorColumnMajorLayouts,
Activation::NONE>(
tensor_a,
tensor_b,
tensor_c,
@@ -387,26 +532,19 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
}
if constexpr (EnableActivationReLU) {
if (activation == "relu") {
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationRelu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementComputeEpilogue>;
return two_four_sgemm_cutlass_dispatch_layouts<
return two_four_sgemm_dispatch_layouts_bias<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts>(
EnableColumnMajorColumnMajorLayouts,
Activation::RELU>(
tensor_a,
tensor_b,
tensor_c,
@@ -415,26 +553,19 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
}
if constexpr (EnableActivationSiLU) {
if (activation == "silu") {
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationSilu<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
ElementAccumulator,
ElementComputeEpilogue>;
return two_four_sgemm_cutlass_dispatch_layouts<
return two_four_sgemm_dispatch_layouts_bias<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts>(
EnableColumnMajorColumnMajorLayouts,
Activation::SILU>(
tensor_a,
tensor_b,
tensor_c,
@@ -442,8 +573,8 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
}
}
AT_ERROR("two_four_sgemm_cutlass_dispatch_layouts: Activation \"",
activation, "\" is not supported for given input tensors");
AT_ERROR("two_four_sgemm_dispatch_layouts: Activation \"", activation,
"\" is not supported for given input tensors");
return Tensor{};
}
#endif
@@ -472,7 +603,10 @@ Tensor _sparse_semi_structured_linear(
const Tensor& meta, const c10::optional<Tensor>& bias_opt,
const c10::optional<c10::string_view> activation_opt,
const c10::optional<c10::ScalarType> out_dtype_opt) {
#ifndef USE_ROCM
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
AT_ERROR("_sparse_semi_structured_linear: CUTLASS not supported");
return Tensor{};
#else
// No need to check that all tensors are on CUDA device, as this
// is provided by dispatch.
@@ -575,7 +709,6 @@ Tensor _sparse_semi_structured_linear(
using ElementInputA = int8_t;
using ElementInputB = int8_t;
using ElementAccumulator = int32_t;
using ElementComputeEpilogue = int32_t;
using ThreadblockShape =
cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
@@ -589,12 +722,11 @@ Tensor _sparse_semi_structured_linear(
const auto EnableActivationSiLU = false;
if (out_dtype_opt.has_value()) {
using ElementOutput = int32_t;
output = two_four_sgemm_cutlass_dispatch_layouts_activation<
output = two_four_sgemm_dispatch_layouts_bias_activation<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
@@ -612,12 +744,11 @@ Tensor _sparse_semi_structured_linear(
activation);
} else {
using ElementOutput = int8_t;
output = two_four_sgemm_cutlass_dispatch_layouts_activation<
output = two_four_sgemm_dispatch_layouts_bias_activation<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
@@ -643,7 +774,6 @@ Tensor _sparse_semi_structured_linear(
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
@@ -654,12 +784,11 @@ Tensor _sparse_semi_structured_linear(
const auto EnableActivationNone = true;
const auto EnableActivationReLU = true;
const auto EnableActivationSiLU = true;
output = two_four_sgemm_cutlass_dispatch_layouts_activation<
output = two_four_sgemm_dispatch_layouts_bias_activation<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
@@ -684,7 +813,6 @@ Tensor _sparse_semi_structured_linear(
using ElementInputB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
@@ -695,12 +823,11 @@ Tensor _sparse_semi_structured_linear(
const auto EnableActivationNone = true;
const auto EnableActivationReLU = true;
const auto EnableActivationSiLU = true;
output = two_four_sgemm_cutlass_dispatch_layouts_activation<
output = two_four_sgemm_dispatch_layouts_bias_activation<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
@@ -725,7 +852,6 @@ Tensor _sparse_semi_structured_linear(
using ElementInputB = float;
using ElementOutput = float;
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@@ -736,12 +862,11 @@ Tensor _sparse_semi_structured_linear(
const auto EnableActivationNone = true;
const auto EnableActivationReLU = true;
const auto EnableActivationSiLU = true;
output = two_four_sgemm_cutlass_dispatch_layouts_activation<
output = two_four_sgemm_dispatch_layouts_bias_activation<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ElementComputeEpilogue,
ThreadblockShape,
WarpShape,
InstructionShape,
@@ -764,9 +889,6 @@ Tensor _sparse_semi_structured_linear(
auto output_sizes = input_sizes;
output_sizes.back() = weight.size(0);
return output.transpose(-1, -2).reshape(output_sizes);
#else
AT_ERROR("_sparse_semi_structured_linear: ROCm doesn't support CUTLASS");
return Tensor{};
#endif
}
@@ -775,7 +897,8 @@ Tensor _sparse_semi_structured_linear(
// Following is just for testing purposes.
namespace at::native {
#ifndef USE_ROCM
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
// Copied from tools/util/include/host_reorder.h, from CUTLASS source
// tree. This is for simplicity - namely, this file is not under
// include/cutlass in this tree, as other CUTLASS include files
@@ -812,7 +935,10 @@ static void reorder_meta(cutlass::TensorRef<Element, LayoutDest> dest,
std::tuple<Tensor, Tensor>
_to_sparse_semi_structured(const Tensor& dense) {
#ifndef USE_ROCM
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
AT_ERROR("_to_sparse_semi_structured: CUTLASS not supported");
return std::make_tuple(Tensor{}, Tensor{});
#else
// Check dimensions of the dense matrix.
TORCH_CHECK(dense.dim() == 2,
"_to_sparse_semi_structured: Expected dense argument to be 2D "
@@ -937,9 +1063,6 @@ _to_sparse_semi_structured(const Tensor& dense) {
return std::make_tuple(sparse_cpu.to(dense.device()),
meta_reordered_cpu.to(dense.device()));
#else
AT_ERROR("_to_sparse_semi_structured: ROCm doesn't support CUTLASS");
return std::make_tuple(Tensor{}, Tensor{});
#endif
}

View File

@@ -33,11 +33,12 @@ CUTE_HOST_DEVICE
auto
make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value;
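// With the CUTLASS bump, TiledMMA no longer exposes the TiledShape_MNK alias
// used on the removed line above; the tile extents are queried through
// tile_size_mnk<I>() instead.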
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
// Divide by 2 because right now we always use 2 for the ValLayout
constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2;
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
// This gives the correct layout, idk why.
// auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
@@ -46,7 +47,7 @@ make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
// Stride<_1, _64, _8> >{},
auto t = make_tile(Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
Stride<_1, Int<MMAStride_N>, _8> >{}, // (1, 64, 8) or (1, 32, 8)
make_layout(size<2>(TileShape_MNK{})));
make_layout(Int<TileShape_K>{}));
// if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t);
}
@@ -60,13 +61,13 @@ CUTE_HOST_DEVICE
auto
make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
TiledMMA const& tiled_mma) {
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value;
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
// Divide by 2 because right now we always use 2 for the ValLayout
constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2;
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
auto t = make_tile(make_layout(size<0>(TileShape_MNK{})),
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
auto t = make_tile(make_layout(Int<TileShape_M>{}),
Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
Stride<_1, Int<MMAStride_N>, _8> >{}); // (1, 64, 8) or (1, 32, 8)
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
@@ -444,8 +445,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
// constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value;
constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value;
constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);

View File

@@ -92,7 +92,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kNWarps = Kernel_traits::kNWarps;
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

View File

@@ -34,10 +34,8 @@ struct Flash_kernel_traits {
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
@@ -78,7 +76,7 @@ struct Flash_fwd_kernel_traits : public Base {
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile<Int<16 * kNWarps>, _16, _16>>;
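// The third TiledMMA template parameter is now a permutation/tile shape
// rather than the ValLayoutMNK alias removed from the base traits above; the
// explicit 16-wide N/K tiles take over what the old 1x2x1 / 1x2x2 value
// groups expressed.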
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
@@ -220,18 +218,15 @@ struct Flash_bwd_kernel_traits : public Base {
using TiledMmaSdP = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
using TiledMmadKV = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
using TiledMmadQ = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
using SmemLayoutAtomQdO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<_8, Int<kBlockKSmem>>,

View File

@@ -256,7 +256,7 @@ class TestSparseSemiStructured(TestCase):
if dtype is torch.int8:
# This should fail
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B)
else:
with self.assertRaisesRegex(RuntimeError,
@@ -287,7 +287,7 @@ class TestSparseSemiStructured(TestCase):
# padding with int8 throws an error because transposing B yields a contiguous output
# and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B.t())
else:
with self.assertRaisesRegex(RuntimeError,