Update cutlass from 3.3.0 to 3.4.1 (#120434)
### COPY OF https://github.com/pytorch/pytorch/pull/120010 ###

### Update

I have rolled the two blocking changes into this PR, and I also imported it to fbcode to verify that nothing breaks: D53870253. This copy was generated by merging all the internal-only changes into one atomic commit and re-exporting to GitHub.

### Current Status

- [PR](https://github.com/pytorch/pytorch/pull/118935) aims to update the flash attention kernels to a more recent version.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120434
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
commit 36c1cc962a (parent f609f2050f), committed by PyTorch MergeBot
@@ -3,31 +3,36 @@
 #include <ATen/cuda/CUDAUtils.h>
 #include <ATen/Dispatch.h>
 
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+#else
 #include <cuda_runtime.h>
 #include <cutlass/cutlass.h>
 #include <cutlass/layout/layout.h>
 #include <cutlass/tensor_ref.h>
-#include <cutlass/epilogue/thread/linear_combination.h>
-#include <cutlass/epilogue/thread/linear_combination_relu.h>
-#include <cutlass/epilogue/thread/linear_combination_silu.h>
-#include <cutlass/gemm/device/gemm_sparse_row_broadcast.h>
+#include <cutlass/gemm/device/gemm_sparse_with_visitor.h>
+#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
 #endif
 
 #include <type_traits>
 #include <tuple>
 
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+#else
 #define CUTLASS_STATUS_CHECK(status)                                    \
 {                                                                       \
     TORCH_CHECK(status == cutlass::Status::kSuccess,                    \
                 "Got CUTLASS error: ", cutlassGetStatusString(status)); \
 }
 
+namespace {
+  enum class Activation{NONE, RELU, SILU};
+}
 #endif
 
 namespace at::native {
 
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+#else
 // Wrapper function for CUTLASS sparse GEMM implementation, used
 // solely to simplify dispatching from
 // _sparse_semi_structured_linear() function below.
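A note on the hunk above: CUTLASS entry points are now compiled only for non-MSVC CUDA >= 11.8 builds, and every CUTLASS call is funneled through `CUTLASS_STATUS_CHECK`. Below is a minimal stand-alone sketch of that check pattern; `Status`, `statusString`, and `STATUS_CHECK` are illustrative stand-ins, while the real code uses `cutlass::Status`, `cutlassGetStatusString`, and `TORCH_CHECK`:

```cpp
#include <stdexcept>
#include <string>

enum class Status { kSuccess, kErrorInternal };  // stand-in for cutlass::Status

inline const char* statusString(Status s) {      // stand-in for cutlassGetStatusString
  return s == Status::kSuccess ? "success" : "internal error";
}

// Same shape as CUTLASS_STATUS_CHECK above: compare against kSuccess and
// fail loudly with the decoded status string.
#define STATUS_CHECK(status)                                            \
  {                                                                     \
    if ((status) != Status::kSuccess) {                                 \
      throw std::runtime_error(                                         \
          std::string("Got CUTLASS error: ") + statusString(status));   \
    }                                                                   \
  }

int main() {
  STATUS_CHECK(Status::kSuccess);          // passes silently
  // STATUS_CHECK(Status::kErrorInternal); // would throw with the message
}
```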
@@ -36,14 +41,14 @@ template <
     typename ElementInputB,
     typename ElementOutput,
     typename ElementAccumulator,
-    typename ElementComputeEpilogue,
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
-    typename EpilogueOp,
     typename LayoutInputA,
-    typename LayoutInputB>
-Tensor two_four_sgemm_cutlass(
+    typename LayoutInputB,
+    bool use_bias,
+    Activation activation>
+Tensor two_four_sgemm(
     const Tensor& tensor_a,
     const at::IntArrayRef::value_type& tensor_a_stride,
     const Tensor& tensor_b,
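The signature change above turns the wrapper into a compile-time family: besides the element and layout types, every (use_bias, activation) pair is now a separate instantiation with its branches baked in. A toy sketch of the idea, where `sgemm_like` and its float arguments are hypothetical stand-ins for the tensor-based wrapper:

```cpp
#include <cstdio>

enum class Activation { NONE, RELU, SILU };

// Each (use_bias, activation) pair is a distinct instantiation; the
// discarded branches are never compiled into it.
template <bool use_bias, Activation activation>
float sgemm_like(float acc, float bias) {
  if constexpr (use_bias) acc += bias;
  if constexpr (activation == Activation::RELU) acc = acc > 0.0f ? acc : 0.0f;
  // SiLU case omitted here; Activation::SILU behaves like NONE in this toy.
  return acc;
}

int main() {
  // 2 bias states x 3 activations = 6 distinct instantiations per
  // element-type/layout combination.
  std::printf("%f %f\n",
              sgemm_like<true, Activation::RELU>(-2.0f, 1.0f),    // 0.0
              sgemm_like<false, Activation::NONE>(-2.0f, 1.0f));  // -2.0
}
```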
@@ -57,22 +62,107 @@ Tensor two_four_sgemm_cutlass(
     using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
     using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
     constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
-    using Gemm = cutlass::gemm::device::SparseGemmRowBroadcast<
+    using Operator = cutlass::arch::OpMultiplyAdd;
+    constexpr int NumEVTEpilogueStages = 1;
+
+    constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
+    constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
+    constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
+
+    using ElementComputeEpilogue = ElementAccumulator;
+    constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
+    using ElementC = ElementOutput;
+    using LayoutC = LayoutOutput;
+    constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+
+    using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
+        ThreadblockShape,
+        WarpShape,
+        ElementC,
+        AlignmentC,
+        NumEVTEpilogueStages>;
+    using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
+        ThreadblockShape,
+        WarpShape,
+        ElementOutput,
+        AlignmentOutput,
+        NumEVTEpilogueStages>;
+
+    using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
+
+    using BiasScalar =
+        cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
+    using BiasTensor =
+        cutlass::epilogue::threadblock::VisitorColBroadcast<
+            BiasTileThreadMap,
+            ElementC,
+            cute::Stride<cute::_1, cute::_0, int64_t>>;
+    using Bias = std::conditional_t<use_bias, BiasTensor, BiasScalar>;
+    using BiasArguments = typename Bias::Arguments;
+
+    using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute<
+        cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
+        cutlass::FloatRoundStyle::round_to_nearest>;
+    using EVTApplyBias = cutlass::epilogue::threadblock::Sm80EVT<
+        ApplyBias,
+        Accum,
+        Bias>;
+
+    using ApplyActivationNone = cutlass::epilogue::threadblock::VisitorCompute<
+        cutlass::epilogue::thread::Identity,
+        ElementComputeEpilogue,
+        ElementComputeEpilogue,
+        cutlass::FloatRoundStyle::round_to_nearest>;
+    using ApplyActivationReLu = cutlass::epilogue::threadblock::VisitorCompute<
+        cutlass::epilogue::thread::ReLu,
+        ElementComputeEpilogue,
+        ElementComputeEpilogue,
+        cutlass::FloatRoundStyle::round_to_nearest>;
+    using ApplyActivationSiLu = cutlass::epilogue::threadblock::VisitorCompute<
+        cutlass::epilogue::thread::SiLu,
+        ElementComputeEpilogue,
+        ElementComputeEpilogue,
+        cutlass::FloatRoundStyle::round_to_nearest>;
+    using ApplyActivation =
+        std::conditional_t<
+            activation == Activation::NONE,
+            ApplyActivationNone,
+            std::conditional_t<
+                activation == Activation::RELU,
+                ApplyActivationReLu,
+                ApplyActivationSiLu>>;
+    using EVTApplyActivation = cutlass::epilogue::threadblock::Sm80EVT<
+        ApplyActivation,
+        EVTApplyBias>;
+
+    using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
+        OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
+        cute::Stride<int64_t, cute::_1, int64_t>>;
+
+    using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
+        Output,
+        EVTApplyActivation>;
+
+    using Gemm = cutlass::gemm::device::SparseGemmWithVisitor<
         ElementInputA,
         LayoutInputA,
         ElementInputB,
         LayoutInputB,
-        ElementOutput,
-        LayoutOutput,
+        ElementC,
+        LayoutC,
         ElementAccumulator,
         MMAOp,
         SmArch,
         ThreadblockShape,
         WarpShape,
         InstructionShape,
-        EpilogueOp,
+        EVTOutput,
         SwizzleThreadBlock,
-        NumStages>;
+        NumStages,
+        AlignmentInputA,
+        AlignmentInputB,
+        Operator,
+        NumEVTEpilogueStages>;
 
     // Datatype and layout of metadata matrix are inferred from sparse
     // GEMM template.
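The `ApplyActivation` alias above is ordinary nested `std::conditional_t` selection on a non-type template parameter. A self-contained sketch of the same mechanism, with plain functors standing in for the CUTLASS `VisitorCompute` nodes:

```cpp
#include <cmath>
#include <cstdio>
#include <type_traits>

enum class Activation { NONE, RELU, SILU };

struct Identity { static float apply(float x) { return x; } };
struct ReLu     { static float apply(float x) { return x > 0.0f ? x : 0.0f; } };
struct SiLu     { static float apply(float x) { return x / (1.0f + std::exp(-x)); } };

template <Activation activation>
float run(float x) {
  // Same nested conditional_t shape as ApplyActivation in the diff:
  // the enum value picks the functor type at compile time.
  using Apply = std::conditional_t<
      activation == Activation::NONE,
      Identity,
      std::conditional_t<activation == Activation::RELU, ReLu, SiLu>>;
  return Apply::apply(x);
}

int main() {
  std::printf("%f %f %f\n",
              run<Activation::NONE>(-1.0f),    // -1.000000
              run<Activation::RELU>(-1.0f),    //  0.000000
              run<Activation::SILU>(-1.0f));   // -0.268941
}
```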
@@ -105,11 +195,11 @@ Tensor two_four_sgemm_cutlass(
             meta_dtype = at::kInt;
             break;
         default:
-            AT_ERROR("two_four_sgemm_cutlass: invalid size of meta tensor datatype "
+            AT_ERROR("two_four_sgemm: invalid size of meta tensor datatype "
                      "encountered");
     }
     TORCH_CHECK(meta.dtype() == meta_dtype,
-                "two_four_sgemm_cutlass: Expected meta datatype ", meta_dtype,
+                "two_four_sgemm: Expected meta datatype ", meta_dtype,
                 ", but got ", meta.dtype());
 
     // Determine PyTorch datatype for the output matrix.
@@ -125,64 +215,69 @@ Tensor two_four_sgemm_cutlass(
     } else if constexpr (std::is_same_v<ElementOutput, float>) {
         tensor_d_dtype = at::kFloat;
     } else {
-        AT_ERROR("two_four_sgemm_cutlass: invalid datatype for sparse GEMM ",
-                 " output encountered");
+        AT_ERROR("two_four_sgemm: invalid datatype for sparse GEMM output ",
+                 "encountered");
     }
-    if (tensor_c.numel() != 0) {
+    if constexpr (use_bias) {
         TORCH_CHECK(tensor_c.dtype() == tensor_d_dtype,
-                    "two_four_sgemm_cutlass: Expected sparse GEMM bias "
-                    "datatype ", tensor_d_dtype, ", but got ",
-                    tensor_c.dtype());
+                    "two_four_sgemm: Expected sparse GEMM bias datatype ",
+                    tensor_d_dtype, ", but got ", tensor_c.dtype());
     }
 
     // Create output matrix.
-    Tensor tensor_d;
-    if (tensor_c.numel() != 0) {
-        tensor_d = tensor_c.new_empty({length_m, length_n});
-    } else {
-        tensor_d =
-            tensor_a.new_empty({length_m, length_n},
-                               at::TensorOptions().dtype(tensor_d_dtype));
-    }
+    Tensor tensor_d =
+        tensor_a.new_empty({length_m, length_n},
+                           at::TensorOptions().dtype(tensor_d_dtype));
 
     // Prepare arguments for CUTLASS sparse GEMM kernel.
     cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
     LayoutInputA layout_a(tensor_a_stride);
     LayoutInputB layout_b(tensor_b_stride);
-    LayoutOutput layout_c(tensor_c.numel() != 0 ? tensor_c.stride(0) : 0);
-    LayoutOutput layout_d(tensor_d.stride(0));
     auto tensor_a_device_ref =
         cutlass::TensorRef<ElementInputA, LayoutInputA>(
             (ElementInputA*)tensor_a.data_ptr(), layout_a);
     auto tensor_b_device_ref =
         cutlass::TensorRef<ElementInputB, LayoutInputB>(
             (ElementInputB*)tensor_b.data_ptr(), layout_b);
-    auto tensor_c_device_ref =
-        cutlass::TensorRef<ElementOutput, LayoutOutput>(
-            (ElementOutput*)(tensor_c.numel() != 0 ?
-                             tensor_c.data_ptr() : tensor_d.data_ptr()),
-            layout_c);
-    auto tensor_d_device_ref =
-        cutlass::TensorRef<ElementOutput, LayoutOutput>(
-            (ElementOutput*)tensor_d.data_ptr(), layout_d);
     auto tensor_e_reordered_device_ref =
         cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
            (ElementInputE*)meta.data_ptr(),
            ReorderedLayoutInputE::packed({length_m, meta_ncols}));
-    ElementComputeEpilogue alpha(1);
-    ElementComputeEpilogue beta(tensor_c.numel() != 0 ? 1 : 0);
-    constexpr int split_k_slices = 1;
+    BiasArguments bias_arguments{
+        [&]() -> BiasArguments {
+            if constexpr (use_bias) {
+                return {(ElementC*)tensor_c.data_ptr(),
+                        ElementC(0),
+                        {cute::_1{}, cute::_0{}, problem_size.m()}};
+            } else {
+                return {ElementC(0)};
+            }
+        }()
+    };
+    typename Output::Arguments output_arguments{
+        (ElementOutput*)tensor_d.data_ptr(),
+        {problem_size.n(), cute::_1{}, problem_size.mn().product()}
+    };
+    typename EVTOutput::Arguments callback_arguments{
+        {
+            {
+                {},              // Accum
+                bias_arguments,  // Bias
+                {}               // ApplyBias
+            },                   // EVTApplyBias
+            {}                   // ApplyActivation
+        },                       // EVTApplyActivation
+        output_arguments,        // Output
+    };                           // EVTOutput
 
     // Create a tuple of CUTLASS sparse GEMM kernel arguments.
     typename Gemm::Arguments arguments{
         problem_size,
         tensor_a_device_ref,
         tensor_b_device_ref,
-        tensor_c_device_ref,
-        tensor_d_device_ref,
         tensor_e_reordered_device_ref,
-        {alpha, beta},
-        split_k_slices};
+        callback_arguments};
 
     cutlass::Status status;
 
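The `bias_arguments` initializer above uses an immediately-invoked lambda so each `if constexpr` branch can return a differently shaped aggregate; only the branch matching `use_bias` is ever instantiated. A minimal sketch with hypothetical argument structs:

```cpp
#include <cstdio>

struct ScalarArgs { float scalar; };                  // "no bias" arguments
struct TensorArgs { const float* ptr; float scalar; };// "bias column" arguments

template <bool use_bias>
auto make_bias_arguments(const float* bias_ptr) {
  // Immediately-invoked lambda, as in the diff: a plain ?: could not pick
  // between the two initializer shapes, but if constexpr can.
  return [&]() {
    if constexpr (use_bias) {
      return TensorArgs{bias_ptr, 0.0f};  // broadcast a bias column
    } else {
      return ScalarArgs{0.0f};            // broadcast a zero scalar
    }
  }();
}

int main() {
  float bias[4] = {1, 2, 3, 4};
  auto with_bias = make_bias_arguments<true>(bias);     // deduces TensorArgs
  auto no_bias   = make_bias_arguments<false>(nullptr); // deduces ScalarArgs
  std::printf("%d %f\n", with_bias.ptr == bias, no_bias.scalar);  // 1 0.000000
}
```

In the real code both branches return the single type `BiasArguments`, but that type itself differs per instantiation because `Bias` is a `std::conditional_t`, so the same one-branch-only property applies.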
@@ -219,16 +314,16 @@ template <
     typename ElementInputB,
     typename ElementOutput,
     typename ElementAccumulator,
-    typename ElementComputeEpilogue,
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
-    typename EpilogueOp,
     bool EnableRowMajorRowMajorLayouts,
     bool EnableRowMajorColumnMajorLayouts,
     bool EnableColumnMajorRowMajorLayouts,
-    bool EnableColumnMajorColumnMajorLayouts>
-Tensor two_four_sgemm_cutlass_dispatch_layouts(
+    bool EnableColumnMajorColumnMajorLayouts,
+    bool use_bias,
+    Activation activation>
+Tensor two_four_sgemm_dispatch_layouts(
     const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
     const Tensor& meta) {
     // Determine layouts (row-major or column-major) of input tensors.
@@ -242,18 +337,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
     // Perform dispatching.
     if constexpr (EnableRowMajorRowMajorLayouts) {
         if (tensor_a_row_major && tensor_b_row_major) {
-            return two_four_sgemm_cutlass<
+            return two_four_sgemm<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 cutlass::layout::RowMajor,
-                cutlass::layout::RowMajor>(
+                cutlass::layout::RowMajor,
+                use_bias,
+                activation>(
                 tensor_a,
                 tensor_a_stride,
                 tensor_b,
@@ -264,18 +359,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
     }
     if constexpr (EnableRowMajorColumnMajorLayouts) {
         if (tensor_a_row_major && !tensor_b_row_major) {
-            return two_four_sgemm_cutlass<
+            return two_four_sgemm<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 cutlass::layout::RowMajor,
-                cutlass::layout::ColumnMajor>(
+                cutlass::layout::ColumnMajor,
+                use_bias,
+                activation>(
                 tensor_a,
                 tensor_a_stride,
                 tensor_b,
@@ -286,18 +381,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
     }
     if constexpr (EnableColumnMajorRowMajorLayouts) {
         if (!tensor_a_row_major && tensor_b_row_major) {
-            return two_four_sgemm_cutlass<
+            return two_four_sgemm<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 cutlass::layout::ColumnMajor,
-                cutlass::layout::RowMajor>(
+                cutlass::layout::RowMajor,
+                use_bias,
+                activation>(
                 tensor_a,
                 tensor_a_stride,
                 tensor_b,
@@ -308,18 +403,18 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
     }
     if constexpr (EnableColumnMajorColumnMajorLayouts) {
         if (!tensor_a_row_major && !tensor_b_row_major) {
-            return two_four_sgemm_cutlass<
+            return two_four_sgemm<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 cutlass::layout::ColumnMajor,
-                cutlass::layout::ColumnMajor>(
+                cutlass::layout::ColumnMajor,
+                use_bias,
+                activation>(
                 tensor_a,
                 tensor_a_stride,
                 tensor_b,
@@ -329,20 +424,77 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts(
         }
     }
 
-    AT_ERROR("two_four_sgemm_cutlass_dispatch_layouts: Combination of ",
+    AT_ERROR("two_four_sgemm_dispatch_layouts: Combination of ",
             tensor_a_row_major ? "row-major" : "column_major", " and ",
             tensor_b_row_major ? "row-major" : "column_major",
             " layouts for input tensors is not supported");
     return Tensor{};
 }
 
+// Dispatch according to the bias tensor being provided or not.
+template <
+    typename ElementInputA,
+    typename ElementInputB,
+    typename ElementOutput,
+    typename ElementAccumulator,
+    typename ThreadblockShape,
+    typename WarpShape,
+    typename InstructionShape,
+    bool EnableRowMajorRowMajorLayouts,
+    bool EnableRowMajorColumnMajorLayouts,
+    bool EnableColumnMajorRowMajorLayouts,
+    bool EnableColumnMajorColumnMajorLayouts,
+    Activation activation>
+Tensor two_four_sgemm_dispatch_layouts_bias(
+    const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
+    const Tensor& meta) {
+    if (tensor_c.numel() > 0) {
+        return two_four_sgemm_dispatch_layouts<
+            ElementInputA,
+            ElementInputB,
+            ElementOutput,
+            ElementAccumulator,
+            ThreadblockShape,
+            WarpShape,
+            InstructionShape,
+            EnableRowMajorRowMajorLayouts,
+            EnableRowMajorColumnMajorLayouts,
+            EnableColumnMajorRowMajorLayouts,
+            EnableColumnMajorColumnMajorLayouts,
+            true,
+            activation>(
+            tensor_a,
+            tensor_b,
+            tensor_c,
+            meta);
+    } else {
+        return two_four_sgemm_dispatch_layouts<
+            ElementInputA,
+            ElementInputB,
+            ElementOutput,
+            ElementAccumulator,
+            ThreadblockShape,
+            WarpShape,
+            InstructionShape,
+            EnableRowMajorRowMajorLayouts,
+            EnableRowMajorColumnMajorLayouts,
+            EnableColumnMajorRowMajorLayouts,
+            EnableColumnMajorColumnMajorLayouts,
+            false,
+            activation>(
+            tensor_a,
+            tensor_b,
+            tensor_c,
+            meta);
+    }
+}
+
 // Dispatch according to the activation functions enabled.
 template <
     typename ElementInputA,
     typename ElementInputB,
     typename ElementOutput,
     typename ElementAccumulator,
-    typename ElementComputeEpilogue,
     typename ThreadblockShape,
     typename WarpShape,
     typename InstructionShape,
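`two_four_sgemm_dispatch_layouts_bias` is the standard runtime-to-compile-time bridge: one runtime check on `tensor_c` picks between the `true` and `false` instantiations of the next layer. A sketch under hypothetical names:

```cpp
#include <cstdio>

template <bool use_bias>
float linear(float x, float w, float b) {
  if constexpr (use_bias) {
    return x * w + b;   // compiled only into the use_bias == true instantiation
  } else {
    return x * w;       // the no-bias instantiation never touches b
  }
}

// Runtime-to-compile-time bridge, mirroring the dispatch helper above:
// branch once at runtime, then call a fully specialized template.
float linear_dispatch(float x, float w, const float* bias) {
  if (bias != nullptr) {
    return linear<true>(x, w, *bias);
  } else {
    return linear<false>(x, w, 0.0f);
  }
}

int main() {
  float b = 1.0f;
  std::printf("%f %f\n",
              linear_dispatch(2.0f, 3.0f, &b),       // 7.0
              linear_dispatch(2.0f, 3.0f, nullptr)); // 6.0
}
```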
@@ -353,32 +505,25 @@ template <
     bool EnableActivationNone,
     bool EnableActivationReLU,
     bool EnableActivationSiLU>
-Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
+Tensor two_four_sgemm_dispatch_layouts_bias_activation(
     const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
     const Tensor& meta, const c10::string_view& activation) {
     // Perform dispatching.
     if constexpr (EnableActivationNone) {
         if (activation == "none") {
-            using EpilogueOp =
-                cutlass::epilogue::thread::LinearCombination<
-                    ElementOutput,
-                    128 / cutlass::sizeof_bits<ElementOutput>::value,
-                    ElementAccumulator,
-                    ElementComputeEpilogue>;
-            return two_four_sgemm_cutlass_dispatch_layouts<
+            return two_four_sgemm_dispatch_layouts_bias<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 EnableRowMajorRowMajorLayouts,
                 EnableRowMajorColumnMajorLayouts,
                 EnableColumnMajorRowMajorLayouts,
-                EnableColumnMajorColumnMajorLayouts>(
+                EnableColumnMajorColumnMajorLayouts,
+                Activation::NONE>(
                 tensor_a,
                 tensor_b,
                 tensor_c,
@@ -387,26 +532,19 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
     }
     if constexpr (EnableActivationReLU) {
         if (activation == "relu") {
-            using EpilogueOp =
-                cutlass::epilogue::thread::LinearCombinationRelu<
-                    ElementOutput,
-                    128 / cutlass::sizeof_bits<ElementOutput>::value,
-                    ElementAccumulator,
-                    ElementComputeEpilogue>;
-            return two_four_sgemm_cutlass_dispatch_layouts<
+            return two_four_sgemm_dispatch_layouts_bias<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 EnableRowMajorRowMajorLayouts,
                 EnableRowMajorColumnMajorLayouts,
                 EnableColumnMajorRowMajorLayouts,
-                EnableColumnMajorColumnMajorLayouts>(
+                EnableColumnMajorColumnMajorLayouts,
+                Activation::RELU>(
                 tensor_a,
                 tensor_b,
                 tensor_c,
@@ -415,26 +553,19 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
     }
     if constexpr (EnableActivationSiLU) {
         if (activation == "silu") {
-            using EpilogueOp =
-                cutlass::epilogue::thread::LinearCombinationSilu<
-                    ElementOutput,
-                    128 / cutlass::sizeof_bits<ElementOutput>::value,
-                    ElementAccumulator,
-                    ElementComputeEpilogue>;
-            return two_four_sgemm_cutlass_dispatch_layouts<
+            return two_four_sgemm_dispatch_layouts_bias<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
-                EpilogueOp,
                 EnableRowMajorRowMajorLayouts,
                 EnableRowMajorColumnMajorLayouts,
                 EnableColumnMajorRowMajorLayouts,
-                EnableColumnMajorColumnMajorLayouts>(
+                EnableColumnMajorColumnMajorLayouts,
+                Activation::SILU>(
                 tensor_a,
                 tensor_b,
                 tensor_c,
@@ -442,8 +573,8 @@ Tensor two_four_sgemm_cutlass_dispatch_layouts_activation(
         }
     }
 
-    AT_ERROR("two_four_sgemm_cutlass_dispatch_layouts: Activation \"",
-             activation, "\" is not supported for given input tensors");
+    AT_ERROR("two_four_sgemm_dispatch_layouts: Activation \"", activation,
+             "\" is not supported for given input tensors");
     return Tensor{};
 }
 #endif
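The activation dispatcher above follows the same shape: compile-time `Enable*` switches prune which string cases even exist, and each surviving case forwards to a distinct `Activation` instantiation. A compilable sketch with hypothetical names:

```cpp
#include <cstdio>
#include <stdexcept>
#include <string_view>

enum class Activation { NONE, RELU, SILU };

template <Activation activation>
const char* kernel_name() {
  if constexpr (activation == Activation::NONE) return "gemm_none";
  else if constexpr (activation == Activation::RELU) return "gemm_relu";
  else return "gemm_silu";
}

// String -> enum dispatch; disabled cases are pruned at compile time, so
// their kernel instantiations are never built.
template <bool EnableNone, bool EnableReLU, bool EnableSiLU>
const char* dispatch(std::string_view activation) {
  if constexpr (EnableNone) {
    if (activation == "none") return kernel_name<Activation::NONE>();
  }
  if constexpr (EnableReLU) {
    if (activation == "relu") return kernel_name<Activation::RELU>();
  }
  if constexpr (EnableSiLU) {
    if (activation == "silu") return kernel_name<Activation::SILU>();
  }
  throw std::runtime_error("activation not supported for given inputs");
}

int main() {
  // Mirrors the int8 path below, which enables only the "none" activation.
  std::printf("%s\n", dispatch<true, false, false>("none"));  // gemm_none
}
```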
@@ -472,7 +603,10 @@ Tensor _sparse_semi_structured_linear(
       const Tensor& meta, const c10::optional<Tensor>& bias_opt,
       const c10::optional<c10::string_view> activation_opt,
       const c10::optional<c10::ScalarType> out_dtype_opt) {
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+    AT_ERROR("_sparse_semi_structured_linear: CUTLASS not supported");
+    return Tensor{};
+#else
     // No need to check that all tensors are on CUDA device, as this
     // is provided by dispatch.
 
@@ -575,7 +709,6 @@ Tensor _sparse_semi_structured_linear(
         using ElementInputA = int8_t;
         using ElementInputB = int8_t;
         using ElementAccumulator = int32_t;
-        using ElementComputeEpilogue = int32_t;
         using ThreadblockShape =
             cutlass::gemm::GemmShape<128, 128, 128>;
         using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
@@ -589,12 +722,11 @@ Tensor _sparse_semi_structured_linear(
         const auto EnableActivationSiLU = false;
         if (out_dtype_opt.has_value()) {
             using ElementOutput = int32_t;
-            output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+            output = two_four_sgemm_dispatch_layouts_bias_activation<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
@@ -612,12 +744,11 @@ Tensor _sparse_semi_structured_linear(
                 activation);
         } else {
             using ElementOutput = int8_t;
-            output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+            output = two_four_sgemm_dispatch_layouts_bias_activation<
                 ElementInputA,
                 ElementInputB,
                 ElementOutput,
                 ElementAccumulator,
-                ElementComputeEpilogue,
                 ThreadblockShape,
                 WarpShape,
                 InstructionShape,
@@ -643,7 +774,6 @@ Tensor _sparse_semi_structured_linear(
         using ElementInputB = cutlass::half_t;
         using ElementOutput = cutlass::half_t;
         using ElementAccumulator = float;
-        using ElementComputeEpilogue = float;
         using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
         using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
         using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
@@ -654,12 +784,11 @@ Tensor _sparse_semi_structured_linear(
         const auto EnableActivationNone = true;
         const auto EnableActivationReLU = true;
         const auto EnableActivationSiLU = true;
-        output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+        output = two_four_sgemm_dispatch_layouts_bias_activation<
            ElementInputA,
            ElementInputB,
            ElementOutput,
            ElementAccumulator,
-           ElementComputeEpilogue,
            ThreadblockShape,
            WarpShape,
            InstructionShape,
@@ -684,7 +813,6 @@ Tensor _sparse_semi_structured_linear(
         using ElementInputB = cutlass::bfloat16_t;
         using ElementOutput = cutlass::bfloat16_t;
         using ElementAccumulator = float;
-        using ElementComputeEpilogue = float;
         using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
         using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
         using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
@@ -695,12 +823,11 @@ Tensor _sparse_semi_structured_linear(
         const auto EnableActivationNone = true;
         const auto EnableActivationReLU = true;
         const auto EnableActivationSiLU = true;
-        output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+        output = two_four_sgemm_dispatch_layouts_bias_activation<
            ElementInputA,
            ElementInputB,
            ElementOutput,
            ElementAccumulator,
-           ElementComputeEpilogue,
            ThreadblockShape,
            WarpShape,
            InstructionShape,
@@ -725,7 +852,6 @@ Tensor _sparse_semi_structured_linear(
         using ElementInputB = float;
         using ElementOutput = float;
         using ElementAccumulator = float;
-        using ElementComputeEpilogue = float;
         using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
         using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
         using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
@@ -736,12 +862,11 @@ Tensor _sparse_semi_structured_linear(
         const auto EnableActivationNone = true;
         const auto EnableActivationReLU = true;
         const auto EnableActivationSiLU = true;
-        output = two_four_sgemm_cutlass_dispatch_layouts_activation<
+        output = two_four_sgemm_dispatch_layouts_bias_activation<
            ElementInputA,
            ElementInputB,
            ElementOutput,
            ElementAccumulator,
-           ElementComputeEpilogue,
            ThreadblockShape,
            WarpShape,
            InstructionShape,
@@ -764,9 +889,6 @@ Tensor _sparse_semi_structured_linear(
     auto output_sizes = input_sizes;
     output_sizes.back() = weight.size(0);
     return output.transpose(-1, -2).reshape(output_sizes);
-#else
-    AT_ERROR("_sparse_semi_structured_linear: ROCm doesn't support CUTLASS");
-    return Tensor{};
 #endif
 }
 
@@ -775,7 +897,8 @@ Tensor _sparse_semi_structured_linear(
 // Following is just for testing purposes.
 namespace at::native {
 
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+#else
 // Copied from tools/util/include/host_reorder.h, from CUTLASS source
 // tree. This is for simplicity - namely, this file is not under
 // include/cutlass in this tree, as other CUTLASS include files
@@ -812,7 +935,10 @@ static void reorder_meta(cutlass::TensorRef<Element, LayoutDest> dest,
 
 std::tuple<Tensor, Tensor>
 _to_sparse_semi_structured(const Tensor& dense) {
-#ifndef USE_ROCM
+#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
+    AT_ERROR("_to_sparse_semi_structured: CUTLASS not supported");
+    return std::make_tuple(Tensor{}, Tensor{});
+#else
     // Check dimensions of the dense matrix.
     TORCH_CHECK(dense.dim() == 2,
                 "_to_sparse_semi_structured: Expected dense argument to be 2D "
@@ -937,9 +1063,6 @@ _to_sparse_semi_structured(const Tensor& dense) {
 
     return std::make_tuple(sparse_cpu.to(dense.device()),
                            meta_reordered_cpu.to(dense.device()));
-#else
-    AT_ERROR("_to_sparse_semi_structured: ROCm doesn't support CUTLASS");
-    return std::make_tuple(Tensor{}, Tensor{});
 #endif
 }
 
@@ -33,11 +33,12 @@ CUTE_HOST_DEVICE
 auto
 make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
                                   TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
+    constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
+    constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value;
     using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
     constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
     // Divide by 2 because right now we always use 2 for the ValLayout
-    constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2;
+    constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
     constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
     // This gives the correct layout, idk why.
     // auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
@@ -46,7 +47,7 @@ make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
     //                    Stride<_1, _64, _8> >{},
     auto t = make_tile(Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>,  // (8, 2, 2) or (8, 4, 2)
                               Stride<_1, Int<MMAStride_N>, _8> >{},       // (1, 64, 8) or (1, 32, 8)
-                       make_layout(size<2>(TileShape_MNK{})));
+                       make_layout(Int<TileShape_K>{}));
     // if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
     return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t);
 }
@@ -60,13 +61,13 @@ CUTE_HOST_DEVICE
 auto
 make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
                                   TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
+    constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value;
+    constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
     using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
     constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
     // Divide by 2 because right now we always use 2 for the ValLayout
-    constexpr int kNWarpsN = decltype(size<1>(TileShape_MNK{}))::value / AtomShape_N / 2;
-    constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
-    auto t = make_tile(make_layout(size<0>(TileShape_MNK{})),
+    constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2; constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
+    auto t = make_tile(make_layout(Int<TileShape_M>{}),
                        Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>,  // (8, 2, 2) or (8, 4, 2)
                               Stride<_1, Int<MMAStride_N>, _8> >{});      // (1, 64, 8) or (1, 32, 8)
     // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
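For the tile arithmetic in these two functions, a worked constexpr example using the concrete shapes the comments mention — shapes (8, 2, 2)/(8, 4, 2) and strides (1, 64, 8)/(1, 32, 8). The constants here are illustrative; in the kernel they come from the TiledMMA type:

```cpp
#include <cstdio>

constexpr int AtomShape_N = 8;  // N extent of one 16x8x16 MMA atom
constexpr int ValLayout_N = 2;  // the fixed factor the comment refers to

constexpr int warps_n(int tile_n)     { return tile_n / AtomShape_N / ValLayout_N; }
constexpr int mma_stride_n(int mma_n) { return mma_n * AtomShape_N * ValLayout_N; }

static_assert(warps_n(64) == 4);       // (8, 4, 2) shape case
static_assert(warps_n(32) == 2);       // (8, 2, 2) shape case
static_assert(mma_stride_n(4) == 64);  // stride (1, 64, 8)
static_assert(mma_stride_n(2) == 32);  // stride (1, 32, 8)

int main() { std::printf("kNWarpsN(64) = %d\n", warps_n(64)); }
```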
@@ -444,8 +445,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
     // constexpr int kNWarps = Kernel_traits::kNWarps;
-    constexpr int MMA_N_SdP = kBlockN / decltype(size<1>(typename Kernel_traits::TiledMmaSdP::TiledShape_MNK{}))::value;
-    constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
+    constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value; constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
     constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
 
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
@@ -92,7 +92,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kBlockN = Kernel_traits::kBlockN;
     constexpr int kHeadDim = Kernel_traits::kHeadDim;
     constexpr int kNWarps = Kernel_traits::kNWarps;
-    constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
 
     const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
     if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
@@ -34,10 +34,8 @@ struct Flash_kernel_traits {
         MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
         MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
     >;
-    using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
 #else
     using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
-    using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
 #endif
 
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
@@ -78,7 +76,7 @@ struct Flash_fwd_kernel_traits : public Base {
     using TiledMma = TiledMMA<
         typename Base::MMA_Atom_Arch,
         Layout<Shape<Int<kNWarps>,_1,_1>>,  // 4x1x1 or 8x1x1 thread group
-        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+        Tile<Int<16 * kNWarps>, _16, _16>>;
 
     using SmemLayoutAtomQ = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
@@ -220,18 +218,15 @@ struct Flash_bwd_kernel_traits : public Base {
     using TiledMmaSdP = TiledMMA<
         typename Base::MMA_Atom_Arch,
         Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
-        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+        Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
 
     using TiledMmadKV = TiledMMA<
         typename Base::MMA_Atom_Arch,
         Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
-        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+        Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
 
     using TiledMmadQ = TiledMMA<
         typename Base::MMA_Atom_Arch,
         Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>,  // 2x4x1 or 4x2x1 thread group
-        typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
+        Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
 
     using SmemLayoutAtomQdO = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
                     Layout<Shape<_8, Int<kBlockKSmem>>,
@@ -256,7 +256,7 @@ class TestSparseSemiStructured(TestCase):
         if dtype is torch.int8:
             # This should fail
             if backend == "cutlass":
-                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
+                with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
                     sparse_result = torch.mm(A_sparse, B)
             else:
                 with self.assertRaisesRegex(RuntimeError,
@@ -287,7 +287,7 @@ class TestSparseSemiStructured(TestCase):
         # padding with int8 throws an error because transposing B yields a contiguous output
         # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
         if backend == "cutlass":
-            with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_cutlass_dispatch_layouts"):
+            with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
                 sparse_result = torch.mm(A_sparse, B.t())
         else:
             with self.assertRaisesRegex(RuntimeError,
Submodule third_party/cutlass updated: a75b4ac483...bbe579a9e3