Simplify ATen sparse semi-structured operators based on CUTLASS (#123473)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123473
Approved by: https://github.com/cpuhrsch
Aleksandar Samardžić
2024-04-13 18:35:02 +00:00
committed by PyTorch MergeBot
parent 635c238bad
commit f5331aade5
9 changed files with 1139 additions and 185 deletions

@@ -3342,10 +3342,19 @@
dispatch:
CUDA: _cslt_sparse_mm_search
# DEPRECATED: Use torch._sparse_semi_structured_mm/torch._sparse_semi_structured_addmm instead
- func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor
dispatch:
CUDA: _sparse_semi_structured_linear
- func: _sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor
dispatch:
CUDA: _sparse_semi_structured_mm
- func: _sparse_semi_structured_addmm(Tensor input, Tensor mat1, Tensor mat1_meta, Tensor mat2, *, Scalar alpha=1, Scalar beta=1, ScalarType? out_dtype=None) -> Tensor
dispatch:
CUDA: _sparse_semi_structured_addmm
- func: _mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? activation=None) -> Tensor
dispatch:
CUDA: _mixed_dtypes_linear
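As a rough sketch of how the two new operators are meant to be called from Python (modeled on the test_sparse_semi_structured_ops_cutlass test added later in this commit; the shapes, the _FORCE_CUTLASS toggle and the construction of a 2:4-sparse mat1 below are illustrative assumptions, and a compute capability 8.x GPU is required):

import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

SparseSemiStructuredTensor._FORCE_CUTLASS = True  # use the CUTLASS compressed format

m, k, n = 64, 256, 64
# Build a matrix with a valid 2:4 pattern (positions 0 and 2 of every group of 4 kept).
vals = torch.randn(m, k // 2, dtype=torch.half, device="cuda")
mat1 = torch.stack((vals, torch.zeros_like(vals)), dim=-1).reshape(m, k)
mat2 = torch.randn(k, n, dtype=torch.half, device="cuda")

compressed = to_sparse_semi_structured(mat1)
out = torch._sparse_semi_structured_mm(compressed.values(), compressed.indices(), mat2)

bias = torch.randn(m, dtype=torch.half, device="cuda")
out2 = torch._sparse_semi_structured_addmm(
    bias, compressed.values(), compressed.indices(), mat2, alpha=1.0, beta=1.0)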

@@ -603,6 +603,10 @@ Tensor _sparse_semi_structured_linear(
const Tensor& meta, const c10::optional<Tensor>& bias_opt,
const c10::optional<c10::string_view> activation_opt,
const c10::optional<c10::ScalarType> out_dtype_opt) {
TORCH_WARN_ONCE("_sparse_semi_structured_linear is deprecated and will be "
"removed in a future PyTorch release. Please use "
"_sparse_semi_structured_mm/_sparse_semi_structured_addmm "
"instead.");
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
AT_ERROR("_sparse_semi_structured_linear: CUTLASS not supported");
return Tensor{};
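The migration this warning asks for mirrors the change made to SparseSemiStructuredTensorCUTLASS._mm at the end of this commit; roughly, with packed/meta standing for the compressed weight and its metadata and B for the dense operand:

# Before (deprecated):
res = torch._sparse_semi_structured_linear(B.t(), packed, meta, bias=bias).t()

# After:
if bias is None:
    res = torch._sparse_semi_structured_mm(packed, meta, B)
else:
    res = torch._sparse_semi_structured_addmm(bias, packed, meta, B)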
@@ -893,177 +897,3 @@ Tensor _sparse_semi_structured_linear(
}
} // namespace at::native
// Following is just for testing purposes.
namespace at::native {
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
// Copied from tools/util/include/host_reorder.h in the CUTLASS source
// tree. This is done for simplicity: that file is not under
// include/cutlass in this tree, unlike the other CUTLASS include files
// needed, so using it directly would require changing the PyTorch
// CMake configuration; furthermore, including it currently produces
// build errors in PyTorch.
template <typename Element, typename LayoutDest, typename LayoutSrc>
static void reorder_meta(cutlass::TensorRef<Element, LayoutDest> dest,
cutlass::TensorRef<Element, LayoutSrc> src,
const int problem_size_m, const int problem_size_k) {
for (int m = 0; m < problem_size_m; m++) {
for (int k = 0; k < problem_size_k; k++) {
// First reorder the rows.
int group = (sizeof(Element) == 2) ? 32 : 16;
int interweave = (sizeof(Element) == 2) ? 4 : 2;
int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
int dest_col = k;
// Next swizzle the 2x2 blocks from Z to N.
if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) {
++dest_row;
--dest_col;
} else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) {
--dest_row;
++dest_col;
}
dest.at({dest_row, dest_col}) = src.at({m, k});
}
}
}
#endif
std::tuple<Tensor, Tensor>
_to_sparse_semi_structured(const Tensor& dense) {
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
AT_ERROR("_to_sparse_semi_structured: CUTLASS not supported");
return std::make_tuple(Tensor{}, Tensor{});
#else
// Check dimensions of the dense matrix.
TORCH_CHECK(dense.dim() == 2,
"_to_sparse_semi_structured: Expected dense argument to be 2D "
"tensor, got ", dense.dim(), " dims");
// Determine PyTorch datatype for the metadata matrix.
auto meta_dtype = at::kChar;
auto ksparse = 0;
auto dense_elems_per_meta_elem = 0;
if (dense.dtype() == at::kChar) {
meta_dtype = at::kInt;
ksparse = 4;
dense_elems_per_meta_elem = 32;
} else if (dense.dtype() == at::kHalf || dense.dtype() == at::kBFloat16) {
meta_dtype = at::kShort;
ksparse = 4;
dense_elems_per_meta_elem = 16;
} else if (dense.dtype() == at::kFloat) {
meta_dtype = at::kShort;
ksparse = 2;
dense_elems_per_meta_elem = 8;
} else {
AT_ERROR("_to_sparse_semi_structured: Invalid dense argument datatype ",
dense.dtype(), " encountered");
}
const auto dense_nrows = dense.size(0);
const auto dense_ncols = dense.size(1);
if (dense_nrows % (meta_dtype == at::kShort ? 32 : 16) != 0) {
AT_ERROR("_to_sparse_semi_structured: Number of rows of dense matrix must "
"be divisible by ", (meta_dtype == at::kShort ? 32 : 16),
", but it is ", dense_nrows);
}
if (dense_ncols % dense_elems_per_meta_elem != 0) {
AT_ERROR("_to_sparse_semi_structured: Number of columns of dense matrix "
"must be divisible by ", dense_elems_per_meta_elem, ", but it is ",
dense_ncols);
}
const auto dense_cpu = dense.to("cpu");
const auto mask_cpu = dense_cpu != at::zeros({1}, dense_cpu.options());
const auto sparse_cpu =
dense_cpu.masked_select(mask_cpu).view({dense_nrows, dense_ncols / 2});
const auto meta_nrows = dense_nrows;
const auto meta_ncols = dense_ncols / dense_elems_per_meta_elem;
auto meta_cpu = dense_cpu.new_empty({meta_nrows, meta_ncols},
at::TensorOptions().dtype(meta_dtype));
auto* mask_cpu_ptr = mask_cpu.data_ptr<bool>();
for (auto i = 0; i < meta_nrows; ++i) {
for (auto j = 0; j < meta_ncols; ++j) {
uint64_t meta_val = 0;
for (auto k = 0; k < dense_elems_per_meta_elem / ksparse; ++k, mask_cpu_ptr += ksparse) {
const auto mask_elems =
(ksparse == 4) ? std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[1],
mask_cpu_ptr[2], mask_cpu_ptr[3])
: std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[0],
mask_cpu_ptr[1], mask_cpu_ptr[1]);
auto meta_quadruple = 0;
if (mask_elems == std::make_tuple(1, 1, 0, 0)) {
meta_quadruple = 4; // 0100
} else if (mask_elems == std::make_tuple(1, 0, 1, 0)) {
meta_quadruple = 8; // 1000
} else if (mask_elems == std::make_tuple(0, 1, 1, 0)) {
meta_quadruple = 9; // 1001
} else if (mask_elems == std::make_tuple(1, 0, 0, 1)) {
meta_quadruple = 12; // 1100
} else if (mask_elems == std::make_tuple(0, 1, 0, 1)) {
meta_quadruple = 13; // 1101
} else if (mask_elems == std::make_tuple(0, 0, 1, 1)) {
meta_quadruple = 14; // 1110
} else {
AT_ERROR("_to_sparse_semi_structured: dense argument does not match ",
(dense.dtype() != at::kFloat) ? "2:4" : "1:2",
"sparsity pattern");
}
meta_val = meta_val | (meta_quadruple << (4 * k));
}
const auto idx = i * meta_ncols + j;
if (meta_dtype == at::kShort) {
using MetaElement = int16_t;
const auto meta_cpu_ptr = meta_cpu.data_ptr<MetaElement>();
meta_cpu_ptr[idx] = (MetaElement)meta_val;
} else if (meta_dtype == at::kInt) {
using MetaElement = int32_t;
const auto meta_cpu_ptr = meta_cpu.data_ptr<MetaElement>();
meta_cpu_ptr[idx] = (MetaElement)meta_val;
}
}
}
auto meta_reordered_cpu = meta_cpu.new_empty({meta_nrows, meta_ncols});
using MetaLayout = cutlass::layout::RowMajor;
using MetaReorderedLayout = cutlass::layout::ColumnMajorInterleaved<2>;
if (meta_dtype == at::kShort) {
using MetaElement = int16_t;
auto meta_cpu_ref =
cutlass::TensorRef<MetaElement, MetaLayout>(
meta_cpu.data_ptr<MetaElement>(),
MetaLayout::packed({meta_nrows, meta_ncols}));
auto meta_reordered_cpu_ref =
cutlass::TensorRef<MetaElement, MetaReorderedLayout>(
meta_reordered_cpu.data_ptr<MetaElement>(),
MetaReorderedLayout::packed({meta_nrows, meta_ncols}));
reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols);
} else if (meta_dtype == at::kInt) {
using MetaElement = int32_t;
auto meta_cpu_ref =
cutlass::TensorRef<MetaElement, MetaLayout>(
meta_cpu.data_ptr<MetaElement>(),
MetaLayout::packed({meta_nrows, meta_ncols}));
auto meta_reordered_cpu_ref =
cutlass::TensorRef<MetaElement, MetaReorderedLayout>(
meta_reordered_cpu.data_ptr<MetaElement>(),
MetaReorderedLayout::packed({meta_nrows, meta_ncols}));
reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols);
}
return std::make_tuple(sparse_cpu.to(dense.device()),
meta_reordered_cpu.to(dense.device()));
#endif
}
} // namespace at::native

View File

@@ -0,0 +1,979 @@
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <ATen/Dispatch.h>
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <cutlass/layout/layout.h>
#include <cutlass/tensor_ref.h>
#include <cutlass/gemm/device/gemm_sparse_with_visitor.h>
#include <cutlass/epilogue/threadblock/fusion/visitors.hpp>
#endif
#include <type_traits>
#include <tuple>
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
#define CUTLASS_STATUS_CHECK(status) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
__func__, " : CUTLASS error: ", \
cutlassGetStatusString(status)); \
}
#endif
namespace at::native {
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
// Wrapper function for CUTLASS sparse GEMM implementation, used
// solely to simplify dispatching from
// sparse_semi_structured_mad_op() function below.
template <
typename ElementInputA,
typename ElementInputB,
typename ElementOutput,
typename ElementAccumulator,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
typename LayoutInputA,
typename LayoutInputB,
bool use_tensor_c>
void spgemm_cutlass(
const Tensor& tensor_a, const at::IntArrayRef::value_type& tensor_a_stride,
const Tensor& tensor_b, const at::IntArrayRef::value_type& tensor_b_stride,
const Tensor& tensor_c, const Tensor& tensor_e, const Scalar& alpha,
const Scalar& beta, Tensor& tensor_d) {
// Fix CUTLASS sparse GEMM template arguments that are not
// provided as template argument of this function, and create an
// alias for particular instantiation of this template.
using LayoutOutput = cutlass::layout::RowMajor; // Result of the operation will be provided in row-major format.
using MMAOp = cutlass::arch::OpClassTensorOp; // Tensor cores are to be used for maximum performance.
using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across wide range of operand sizes.
constexpr int NumStages = 3; // This choice provides good performance across wide range of operand sizes.
using Operator = cutlass::arch::OpMultiplyAdd;
constexpr int NumEVTEpilogueStages = 1;
constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
using ElementComputeEpilogue = ElementAccumulator; // Typically slightly slower, but more precise than if ElementOutput used.
constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits<ElementComputeEpilogue>::value;
using ElementC = ElementOutput;
using LayoutC = LayoutOutput;
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
using TensorCTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
ElementC,
AlignmentC,
NumEVTEpilogueStages>;
using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout<
ThreadblockShape,
WarpShape,
ElementOutput,
AlignmentOutput,
NumEVTEpilogueStages>;
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using Alpha =
cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementComputeEpilogue>;
using AlphaArguments = typename Alpha::Arguments;
using ApplyAlpha = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTApplyAlpha = cutlass::epilogue::threadblock::Sm80EVT<
ApplyAlpha,
Alpha,
Accum>;
using Beta =
cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementComputeEpilogue>;
using BetaArguments = typename Beta::Arguments;
using TensorCScalar =
cutlass::epilogue::threadblock::VisitorScalarBroadcast<ElementC>;
using TensorCTensor =
cutlass::epilogue::threadblock::VisitorColBroadcast<
TensorCTileThreadMap,
ElementC,
cute::Stride<cute::_1, cute::_0, int64_t>>;
using TensorC = std::conditional_t<use_tensor_c, TensorCTensor, TensorCScalar>;
using TensorCArguments = typename TensorC::Arguments;
using ApplyBeta = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTApplyBeta = cutlass::epilogue::threadblock::Sm80EVT<
ApplyBeta,
Beta,
TensorC>;
using ApplySum = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT<
ApplySum,
EVTApplyAlpha,
EVTApplyBeta>;
using Output = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest,
cute::Stride<int64_t, cute::_1, int64_t>>;
using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT<
Output,
EVTApplySum>;
using Gemm = cutlass::gemm::device::SparseGemmWithVisitor<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementC,
LayoutC,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EVTOutput,
SwizzleThreadBlock,
NumStages,
AlignmentInputA,
AlignmentInputB,
Operator,
NumEVTEpilogueStages>;
// Datatype and layout of metadata matrix are inferred from sparse
// GEMM template.
using ElementInputE = typename Gemm::ElementE;
using LayoutInputE = cutlass::layout::RowMajor;
using ReorderedLayoutInputE = typename Gemm::LayoutE;
static_assert(
std::is_same<ReorderedLayoutInputE,
cutlass::layout::ColumnMajorInterleaved<2>>::value,
"Matrix layout used by CUTLASS for reordered metadata for sparse GEMM "
"change, thus code doing conversions from/to dense matrix has to be "
"updated.");
constexpr auto kSparse = Gemm::kSparse;
constexpr int kElementsPerElementE = Gemm::kElementsPerElementE;
// Operand sizes.
const int length_m = tensor_a.size(0);
const int length_k = tensor_b.size(0);
const int length_n = tensor_b.size(1);
const auto tensor_e_ncols = length_k / kSparse / kElementsPerElementE;
// Determine PyTorch datatype for the metadata matrix.
auto tensor_e_dtype = at::kChar;
switch (sizeof(ElementInputE)) {
case 2:
tensor_e_dtype = at::kShort;
break;
case 4:
tensor_e_dtype = at::kInt;
break;
default:
AT_ERROR(__func__, ": invalid size of meta tensor datatype "
"encountered");
}
TORCH_CHECK(tensor_e.dtype() == tensor_e_dtype,
__func__, " : Expected meta datatype ", tensor_e_dtype,
", but got ", tensor_e.dtype());
// Prepare arguments for CUTLASS sparse GEMM kernel.
cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
LayoutInputA layout_a(tensor_a_stride);
LayoutInputB layout_b(tensor_b_stride);
auto tensor_a_device_ref =
cutlass::TensorRef<ElementInputA, LayoutInputA>(
(ElementInputA*)tensor_a.data_ptr(), layout_a);
auto tensor_b_device_ref =
cutlass::TensorRef<ElementInputB, LayoutInputB>(
(ElementInputB*)tensor_b.data_ptr(), layout_b);
auto tensor_e_reordered_device_ref =
cutlass::TensorRef<ElementInputE, ReorderedLayoutInputE>(
(ElementInputE*)tensor_e.data_ptr(),
ReorderedLayoutInputE::packed({length_m, tensor_e_ncols}));
AlphaArguments alpha_arguments{
[&]() -> AlphaArguments {
if constexpr (std::is_same<ElementComputeEpilogue, cutlass::half_t>::value ||
std::is_same<ElementComputeEpilogue, cutlass::bfloat16_t>::value) {
return {ElementComputeEpilogue{alpha.to<float>()}};
} else {
return {alpha.to<ElementComputeEpilogue>()};
}
}()
};
BetaArguments beta_arguments{
[&]() -> BetaArguments {
if constexpr (std::is_same<ElementComputeEpilogue, cutlass::half_t>::value ||
std::is_same<ElementComputeEpilogue, cutlass::bfloat16_t>::value) {
return {ElementComputeEpilogue{beta.to<float>()}};
} else {
return {beta.to<ElementComputeEpilogue>()};
}
}()
};
TensorCArguments tensor_c_arguments{
[&]() -> TensorCArguments {
if constexpr (use_tensor_c) {
return {(ElementC*)tensor_c.data_ptr(),
ElementC(0),
{cute::_1{}, cute::_0{}, problem_size.m()}};
} else {
return {ElementC(0)};
}
}()
};
typename Output::Arguments output_arguments{
(ElementOutput*)tensor_d.data_ptr(),
{problem_size.n(), cute::_1{}, problem_size.mn().product()}
};
typename EVTOutput::Arguments callback_arguments{
{
{
alpha_arguments, // Alpha
{}, // Accum
{} // ApplyAlpha
}, // EVTApplyAlpha
{
beta_arguments, // Beta
tensor_c_arguments, // TensorC
{} // ApplyBeta
}, // EVTApplyBeta
{} // ApplySum
}, // EVTApplySum
output_arguments // Output
}; // EVTOutput
// Create a tuple of CUTLASS sparse GEMM kernel arguments.
typename Gemm::Arguments arguments{
problem_size,
tensor_a_device_ref,
tensor_b_device_ref,
tensor_e_reordered_device_ref,
callback_arguments};
cutlass::Status status;
// Create CUTLASS sparse GEMM kernel object.
Gemm gemm_op;
// Verify that sparse GEMM operation with given arguments can be
// performed by CUTLASS.
status = gemm_op.can_implement(arguments);
CUTLASS_STATUS_CHECK(status);
// Allocate workspace for CUTLASS sparse GEMM kernel.
const auto workspace_size = Gemm::get_workspace_size(arguments);
auto workspace = tensor_a.new_empty({(int64_t)workspace_size},
at::TensorOptions().dtype(at::kByte));
// Initialize CUTLASS sparse GEMM object.
status = gemm_op.initialize(arguments, workspace.data_ptr(),
at::cuda::getCurrentCUDAStream());
CUTLASS_STATUS_CHECK(status);
// Perform sparse GEMM operation.
status = gemm_op.run(at::cuda::getCurrentCUDAStream());
CUTLASS_STATUS_CHECK(status);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
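In dense terms, the epilogue visitor tree assembled above (Alpha/Accum into EVTApplyAlpha, Beta/TensorC into EVTApplyBeta, then EVTApplySum and Output) evaluates the following; a reference sketch, with tensor_c broadcast as a column vector when use_tensor_c is set:

import torch

def epilogue_reference(acc, alpha, beta, tensor_c=None):
    # acc: (m, n) accumulator produced by the sparse GEMM (mat1 @ mat2).
    scaled = alpha * acc                 # EVTApplyAlpha
    if tensor_c is None:
        c = torch.zeros_like(acc)        # TensorCScalar path (use_tensor_c == false)
    else:
        c = tensor_c.reshape(-1, 1)      # TensorCTensor: broadcast along columns
    return scaled + beta * c             # EVTApplyBeta, then EVTApplySum / Output store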
// Dispatch according to the input tensors layouts combination.
template <
typename ElementInputA,
typename ElementInputB,
typename ElementOutput,
typename ElementAccumulator,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
bool EnableRowMajorRowMajorLayouts,
bool EnableRowMajorColumnMajorLayouts,
bool EnableColumnMajorRowMajorLayouts,
bool EnableColumnMajorColumnMajorLayouts,
bool use_tensor_c>
void spgemm_cutlass_dispatch_layouts(
const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
const Tensor& tensor_e, const Scalar& alpha, const Scalar& beta,
Tensor& tensor_d) {
// Determine layouts (row-major or column-major) of input tensors.
const auto strides_a = tensor_a.strides();
auto tensor_a_row_major = strides_a[1] == 1;
auto tensor_a_stride = tensor_a_row_major ? strides_a[0] : strides_a[1];
const auto strides_b = tensor_b.strides();
auto tensor_b_row_major = strides_b[1] == 1;
auto tensor_b_stride = tensor_b_row_major ? strides_b[0] : strides_b[1];
// Perform dispatching.
if constexpr (EnableRowMajorRowMajorLayouts) {
if (tensor_a_row_major && tensor_b_row_major) {
spgemm_cutlass<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
cutlass::layout::RowMajor,
cutlass::layout::RowMajor,
use_tensor_c>(
tensor_a,
tensor_a_stride,
tensor_b,
tensor_b_stride,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
return;
}
}
if constexpr (EnableRowMajorColumnMajorLayouts) {
if (tensor_a_row_major && !tensor_b_row_major) {
spgemm_cutlass<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
cutlass::layout::RowMajor,
cutlass::layout::ColumnMajor,
use_tensor_c>(
tensor_a,
tensor_a_stride,
tensor_b,
tensor_b_stride,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
return;
}
}
if constexpr (EnableColumnMajorRowMajorLayouts) {
if (!tensor_a_row_major && tensor_b_row_major) {
spgemm_cutlass<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
cutlass::layout::ColumnMajor,
cutlass::layout::RowMajor,
use_tensor_c>(
tensor_a,
tensor_a_stride,
tensor_b,
tensor_b_stride,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
return;
}
}
if constexpr (EnableColumnMajorColumnMajorLayouts) {
if (!tensor_a_row_major && !tensor_b_row_major) {
spgemm_cutlass<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
cutlass::layout::ColumnMajor,
cutlass::layout::ColumnMajor,
use_tensor_c>(
tensor_a,
tensor_a_stride,
tensor_b,
tensor_b_stride,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
return;
}
}
AT_ERROR(__func__, "_dispatch_layouts: Combination of ",
tensor_a_row_major ? "row-major" : "column_major", " and ",
tensor_b_row_major ? "row-major" : "column_major",
" layouts for input tensors is not supported");
}
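The layout detection above reduces to a check of the innermost stride; a sketch of the same rule for a 2-D strided tensor (the helper name is illustrative):

def infer_layout(t):
    # Row-major when the column stride is 1, column-major otherwise;
    # the remaining stride is the leading dimension handed to CUTLASS.
    row_stride, col_stride = t.stride()
    row_major = (col_stride == 1)
    return ("row-major" if row_major else "column-major",
            row_stride if row_major else col_stride)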
// Dispatch according to the tensor_c tensor being provided or not.
template <
typename ElementInputA,
typename ElementInputB,
typename ElementOutput,
typename ElementAccumulator,
typename ThreadblockShape,
typename WarpShape,
typename InstructionShape,
bool EnableRowMajorRowMajorLayouts,
bool EnableRowMajorColumnMajorLayouts,
bool EnableColumnMajorRowMajorLayouts,
bool EnableColumnMajorColumnMajorLayouts>
void spgemm_cutlass_dispatch_layouts_tensor_c(
const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c,
const Tensor& tensor_e, const Scalar& alpha, const Scalar& beta,
Tensor& tensor_d) {
if (tensor_c.numel() > 0) {
spgemm_cutlass_dispatch_layouts<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts,
true>(
tensor_a,
tensor_b,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
} else {
spgemm_cutlass_dispatch_layouts<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts,
false>(
tensor_a,
tensor_b,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
}
}
#endif
// Perform the multiply-add operation, using the corresponding CUTLASS
// sparse GEMM kernel, on the given arguments:
//     result = alpha * mat1 @ mat2 + beta * input
// The "mat2" tensor is dense, while the "mat1" tensor is a sparse
// semi-structured matrix. The "input" tensor is optional; if
// provided, it should be a vector, with the number of elements equal
// to the number of rows of the "mat1" matrix. It is assumed that
// "mat1" and "mat2" are 2D tensors, supplied either in row-major or
// column-major layout (the two tensors may use different layouts, but
// not all combinations of layouts are supported for all datatypes of
// these matrices). The "mat1_meta" argument contains the sparse
// semi-structured metadata.
//
// The CUTLASS sparse GEMM kernel comes with numerous limitations with
// regard to the sizes, alignments, layouts and datatypes of the input
// tensors; this is the reason for the large number of checks
// throughout the code.
//
// TODO: The "input" tensor has to be a vector, so that it can be
// broadcast to the columns of the mat1 * mat2 product. Broadcasting to
// the rows of mat1 * mat2 could also be supported, if "input" were a
// vector of the corresponding length; likewise for the case where
// "input" is a matrix of the same size as the mat1 * mat2 product. If
// these updates are made here, remember to also update the
// corresponding bits of the Inductor code that handle meta
// registrations and lowerings of the aten._sparse_semi_structured_mm
// and aten._sparse_semi_structured_addmm operators.
Tensor sparse_semi_structured_mad_op(
const Tensor& mat1, const Tensor& mat1_meta, const Tensor& mat2,
const c10::optional<Tensor>& input_opt, const Scalar& alpha,
const Scalar& beta, const c10::optional<c10::ScalarType> out_dtype_opt) {
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
AT_ERROR(__func__, " : CUTLASS not supported");
return Tensor{};
#else
// No need to check that all tensors are on CUDA device, as this
// is provided by dispatch.
const auto& input = input_opt.value_or(Tensor{});
const auto out_dtype = out_dtype_opt.value_or(mat2.scalar_type());
// For now, only CC 8.x devices are supported.
const auto dprops = at::cuda::getCurrentDeviceProperties();
const auto is_sm8x = dprops->major == 8;
TORCH_CHECK(is_sm8x,
__func__, " : Supported only on GPUs with compute capability "
"8.x");
// Validate datatypes of input tensors.
TORCH_CHECK(mat2.dtype() == at::kChar ||
mat2.dtype() == at::kHalf ||
mat2.dtype() == at::kBFloat16 ||
mat2.dtype() == at::kFloat,
__func__, " : The mat2 datatype ", mat2.dtype(),
" is not supported");
TORCH_CHECK(mat1.dtype() == mat2.dtype(),
__func__, " : Expected mat1 datatype ", mat2.dtype(),
", but got ", mat1.dtype());
if (input.numel() != 0) {
TORCH_CHECK(input.dtype() == out_dtype,
__func__, " : Expected input datatype ", out_dtype,
", but got ", input.dtype());
}
// Validate layouts of input tensors.
TORCH_CHECK(mat1.layout() == Layout::Strided,
__func__, " : Expected mat1 argument to be strided, but got "
"layout ", mat1.layout());
TORCH_CHECK(mat1.dim() == 2,
__func__, " : Expected mat1 argument to be 2D tensor, got ",
mat1.dim(), " dims");
const auto strides_a = mat1.strides();
TORCH_CHECK(strides_a[0] == 1 || strides_a[1] == 1,
__func__, " : Invalid strides for mat1 argument: row stride = ",
strides_a[0], ", column stride = ", strides_a[1]);
TORCH_CHECK(mat2.layout() == Layout::Strided,
__func__, " : Expected mat2 argument to be "
"strided, but got layout ", mat2.layout());
TORCH_CHECK(mat2.dim() == 2,
__func__, " : Expected mat2 argument to be 2D tensor, got ",
mat2.dim(), " dims");
const auto strides_b = mat2.strides();
TORCH_CHECK(strides_b[0] == 1 || strides_b[1] == 1,
__func__, " : Invalid strides for mat2 argument: row stride = ",
strides_b[0], ", column stride = ", strides_b[1]);
if (input.numel() != 0) {
TORCH_CHECK(input.layout() == Layout::Strided,
__func__, " : Expected input argument to be strided, but "
"got layout ", input.layout());
TORCH_CHECK(input.dim() == 1,
__func__, " : Expected input argument to be 1D tensor, "
"got ", input.dim(), " dims");
}
// Validate sizes of input tensors.
TORCH_CHECK(mat1.size(1) == mat2.size(0) / 2,
__func__, " : Expected mat1 argument to have ",
mat2.size(0) / 2, " columns, but got ", mat1.size(1));
if (input.numel() != 0) {
TORCH_CHECK(input.size(0) == mat1.size(0),
__func__, " : Expected input argument to have ",
mat1.size(0), " elements, but got ", input.size(0));
}
// Introduce alias names for arguments, according to the CUTLASS
// naming conventions.
const auto& tensor_a = mat1;
const auto& tensor_b = mat2;
const auto& tensor_c = input;
const auto& tensor_e = mat1_meta;
// Create output tensor.
Tensor tensor_d =
tensor_b.new_empty({tensor_a.size(0), tensor_b.size(1)},
at::TensorOptions().dtype(out_dtype));
// Call wrapper function for CUTLASS sparse GEMM, dispatching on
// the input datatype, and then on input tensors layouts.
// According to the input tensors datatypes and layouts,
// corresponding template arguments are supplied for instantiating
// the wrapper function. The tile size template arguments are
// selected according to CUTLASS profiler results collected over a
// number of runs.
AT_DISPATCH_SWITCH(
tensor_a.scalar_type(),
"sparse_semi_structured_mad_op",
AT_DISPATCH_CASE(
at::ScalarType::Char,
[&]() {
using ElementInputA = int8_t;
using ElementInputB = int8_t;
using ElementAccumulator = int32_t;
using ThreadblockShape =
cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>;
const auto EnableRowMajorRowMajorLayouts = false;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = false;
const auto EnableColumnMajorColumnMajorLayouts = false;
if (out_dtype == at::kInt) {
using ElementOutput = int32_t;
spgemm_cutlass_dispatch_layouts_tensor_c<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts>(
tensor_a,
tensor_b,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
} else if (out_dtype == at::kChar) {
using ElementOutput = int8_t;
spgemm_cutlass_dispatch_layouts_tensor_c<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts>(
tensor_a,
tensor_b,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
}
})
AT_DISPATCH_CASE(
at::ScalarType::Half,
[&]() {
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementAccumulator = float;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
const auto EnableRowMajorRowMajorLayouts = true;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = true;
const auto EnableColumnMajorColumnMajorLayouts = true;
spgemm_cutlass_dispatch_layouts_tensor_c<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts>(
tensor_a,
tensor_b,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
})
AT_DISPATCH_CASE(
at::ScalarType::BFloat16,
[&]() {
using ElementInputA = cutlass::bfloat16_t;
using ElementInputB = cutlass::bfloat16_t;
using ElementOutput = cutlass::bfloat16_t;
using ElementAccumulator = float;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
const auto EnableRowMajorRowMajorLayouts = true;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = true;
const auto EnableColumnMajorColumnMajorLayouts = true;
spgemm_cutlass_dispatch_layouts_tensor_c<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts>(
tensor_a,
tensor_b,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
})
AT_DISPATCH_CASE(
at::ScalarType::Float,
[&]() {
using ElementInputA = float;
using ElementInputB = float;
using ElementOutput = float;
using ElementAccumulator = float;
using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
const auto EnableRowMajorRowMajorLayouts = true;
const auto EnableRowMajorColumnMajorLayouts = true;
const auto EnableColumnMajorRowMajorLayouts = true;
const auto EnableColumnMajorColumnMajorLayouts = true;
spgemm_cutlass_dispatch_layouts_tensor_c<
ElementInputA,
ElementInputB,
ElementOutput,
ElementAccumulator,
ThreadblockShape,
WarpShape,
InstructionShape,
EnableRowMajorRowMajorLayouts,
EnableRowMajorColumnMajorLayouts,
EnableColumnMajorRowMajorLayouts,
EnableColumnMajorColumnMajorLayouts>(
tensor_a,
tensor_b,
tensor_c,
tensor_e,
alpha,
beta,
tensor_d);
}));
return tensor_d;
#endif
}
// Implementation of aten._sparse_semi_structured_mm operator.
Tensor _sparse_semi_structured_mm(
const Tensor& mat1, const Tensor& mat1_meta, const Tensor& mat2,
const c10::optional<c10::ScalarType> out_dtype_opt) {
return sparse_semi_structured_mad_op(mat1, mat1_meta, mat2,
c10::optional<Tensor>(), 1, 0,
out_dtype_opt);
}
// Implementation of aten._sparse_semi_structured_addmm operator.
Tensor _sparse_semi_structured_addmm(
const Tensor& input, const Tensor& mat1, const Tensor& mat1_meta,
const Tensor& mat2, const Scalar& alpha, const Scalar& beta,
const c10::optional<c10::ScalarType> out_dtype_opt) {
return sparse_semi_structured_mad_op(mat1, mat1_meta, mat2, input, alpha,
beta, out_dtype_opt);
}
} // namespace at::native
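As a dense reference for the semantics of the two operators above (and of the underlying sparse_semi_structured_mad_op); mat1_dense stands for the uncompressed form of (mat1, mat1_meta), and the helper names are hypothetical:

import torch

def addmm_reference(input, mat1_dense, mat2, *, alpha=1, beta=1, out_dtype=None):
    # result = alpha * mat1 @ mat2 + beta * input, with the 1-D input
    # broadcast across the columns of the product.
    out = alpha * (mat1_dense @ mat2) + beta * input[:, None]
    return out if out_dtype is None else out.to(out_dtype)

def mm_reference(mat1_dense, mat2, *, out_dtype=None):
    # _sparse_semi_structured_mm is the no-input, alpha=1, beta=0 special case.
    out = mat1_dense @ mat2
    return out if out_dtype is None else out.to(out_dtype)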
// Following is just for testing purposes.
namespace at::native {
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
#else
// Copied from tools/util/include/host_reorder.h in the CUTLASS source
// tree. This is done for simplicity: that file is not under
// include/cutlass in this tree, unlike the other CUTLASS include files
// needed, so using it directly would require changing the PyTorch
// CMake configuration; furthermore, including it currently produces
// build errors in PyTorch.
template <typename Element, typename LayoutDest, typename LayoutSrc>
static void reorder_meta(cutlass::TensorRef<Element, LayoutDest> dest,
cutlass::TensorRef<Element, LayoutSrc> src,
const int problem_size_m, const int problem_size_k) {
for (int m = 0; m < problem_size_m; m++) {
for (int k = 0; k < problem_size_k; k++) {
// First reorder the rows.
int group = (sizeof(Element) == 2) ? 32 : 16;
int interweave = (sizeof(Element) == 2) ? 4 : 2;
int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8;
int dest_col = k;
// Next swizzle the 2x2 blocks from Z to N.
if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) {
++dest_row;
--dest_col;
} else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) {
--dest_row;
++dest_col;
}
dest.at({dest_row, dest_col}) = src.at({m, k});
}
}
}
#endif
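For readers without the CUTLASS sources at hand, a Python transcription of the logical permutation performed by reorder_meta() above (the additional ColumnMajorInterleaved<2> storage layout applied to the destination tensor is not modeled here):

import torch

def reorder_meta_reference(src: torch.Tensor) -> torch.Tensor:
    # src: (m, k) metadata tensor, int16 (fp16/bf16 inputs) or int32 (int8 inputs).
    dest = torch.empty_like(src)
    group = 32 if src.element_size() == 2 else 16
    interweave = 4 if src.element_size() == 2 else 2
    for m in range(src.size(0)):
        for k in range(src.size(1)):
            # First reorder the rows ...
            dest_row = m // group * group + (m % 8) * interweave + (m % group) // 8
            dest_col = k
            # ... then swizzle each 2x2 block from a Z to an N pattern.
            if dest_row % 2 == 0 and dest_col % 2 == 1:
                dest_row += 1
                dest_col -= 1
            elif dest_row % 2 == 1 and dest_col % 2 == 0:
                dest_row -= 1
                dest_col += 1
            dest[dest_row, dest_col] = src[m, k]
    return dest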
std::tuple<Tensor, Tensor>
_to_sparse_semi_structured(const Tensor& dense) {
#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080)
AT_ERROR(__func__, " : CUTLASS not supported");
return std::make_tuple(Tensor{}, Tensor{});
#else
// Check dimensions of the dense matrix.
TORCH_CHECK(dense.dim() == 2,
__func__, " : Expected dense argument to be 2D tensor, got ",
dense.dim(), " dims");
// Determine PyTorch datatype for the metadata matrix.
auto meta_dtype = at::kChar;
auto ksparse = 0;
auto dense_elems_per_meta_elem = 0;
if (dense.dtype() == at::kChar) {
meta_dtype = at::kInt;
ksparse = 4;
dense_elems_per_meta_elem = 32;
} else if (dense.dtype() == at::kHalf || dense.dtype() == at::kBFloat16) {
meta_dtype = at::kShort;
ksparse = 4;
dense_elems_per_meta_elem = 16;
} else if (dense.dtype() == at::kFloat) {
meta_dtype = at::kShort;
ksparse = 2;
dense_elems_per_meta_elem = 8;
} else {
AT_ERROR("_to_sparse_semi_structured: Invalid dense argument datatype ",
dense.dtype(), " encountered");
}
const auto dense_nrows = dense.size(0);
const auto dense_ncols = dense.size(1);
if (dense_nrows % (meta_dtype == at::kShort ? 32 : 16) != 0) {
AT_ERROR("_to_sparse_semi_structured: Number of rows of dense matrix must "
"be divisible by ", (meta_dtype == at::kShort ? 32 : 16),
", but it is ", dense_nrows);
}
if (dense_ncols % dense_elems_per_meta_elem != 0) {
AT_ERROR("_to_sparse_semi_structured: Number of columns of dense matrix "
"must be divisible by ", dense_elems_per_meta_elem, ", but it is ",
dense_ncols);
}
const auto dense_cpu = dense.to("cpu");
const auto mask_cpu = dense_cpu != at::zeros({1}, dense_cpu.options());
const auto sparse_cpu =
dense_cpu.masked_select(mask_cpu).view({dense_nrows, dense_ncols / 2});
const auto meta_nrows = dense_nrows;
const auto meta_ncols = dense_ncols / dense_elems_per_meta_elem;
auto meta_cpu = dense_cpu.new_empty({meta_nrows, meta_ncols},
at::TensorOptions().dtype(meta_dtype));
auto* mask_cpu_ptr = mask_cpu.data_ptr<bool>();
for (auto i = 0; i < meta_nrows; ++i) {
for (auto j = 0; j < meta_ncols; ++j) {
uint64_t meta_val = 0;
for (auto k = 0; k < dense_elems_per_meta_elem / ksparse; ++k, mask_cpu_ptr += ksparse) {
const auto mask_elems =
(ksparse == 4) ? std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[1],
mask_cpu_ptr[2], mask_cpu_ptr[3])
: std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[0],
mask_cpu_ptr[1], mask_cpu_ptr[1]);
auto meta_quadruple = 0;
if (mask_elems == std::make_tuple(1, 1, 0, 0)) {
meta_quadruple = 4; // 0100
} else if (mask_elems == std::make_tuple(1, 0, 1, 0)) {
meta_quadruple = 8; // 1000
} else if (mask_elems == std::make_tuple(0, 1, 1, 0)) {
meta_quadruple = 9; // 1001
} else if (mask_elems == std::make_tuple(1, 0, 0, 1)) {
meta_quadruple = 12; // 1100
} else if (mask_elems == std::make_tuple(0, 1, 0, 1)) {
meta_quadruple = 13; // 1101
} else if (mask_elems == std::make_tuple(0, 0, 1, 1)) {
meta_quadruple = 14; // 1110
} else {
AT_ERROR("_to_sparse_semi_structured: dense argument does not match ",
(dense.dtype() != at::kFloat) ? "2:4" : "1:2",
"sparsity pattern");
}
meta_val = meta_val | (meta_quadruple << (4 * k));
}
const auto idx = i * meta_ncols + j;
if (meta_dtype == at::kShort) {
using MetaElement = int16_t;
const auto meta_cpu_ptr = meta_cpu.data_ptr<MetaElement>();
meta_cpu_ptr[idx] = (MetaElement)meta_val;
} else if (meta_dtype == at::kInt) {
using MetaElement = int32_t;
const auto meta_cpu_ptr = meta_cpu.data_ptr<MetaElement>();
meta_cpu_ptr[idx] = (MetaElement)meta_val;
}
}
}
auto meta_reordered_cpu = meta_cpu.new_empty({meta_nrows, meta_ncols});
using MetaLayout = cutlass::layout::RowMajor;
using MetaReorderedLayout = cutlass::layout::ColumnMajorInterleaved<2>;
if (meta_dtype == at::kShort) {
using MetaElement = int16_t;
auto meta_cpu_ref =
cutlass::TensorRef<MetaElement, MetaLayout>(
meta_cpu.data_ptr<MetaElement>(),
MetaLayout::packed({meta_nrows, meta_ncols}));
auto meta_reordered_cpu_ref =
cutlass::TensorRef<MetaElement, MetaReorderedLayout>(
meta_reordered_cpu.data_ptr<MetaElement>(),
MetaReorderedLayout::packed({meta_nrows, meta_ncols}));
reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols);
} else if (meta_dtype == at::kInt) {
using MetaElement = int32_t;
auto meta_cpu_ref =
cutlass::TensorRef<MetaElement, MetaLayout>(
meta_cpu.data_ptr<MetaElement>(),
MetaLayout::packed({meta_nrows, meta_ncols}));
auto meta_reordered_cpu_ref =
cutlass::TensorRef<MetaElement, MetaReorderedLayout>(
meta_reordered_cpu.data_ptr<MetaElement>(),
MetaReorderedLayout::packed({meta_nrows, meta_ncols}));
reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols);
}
return std::make_tuple(sparse_cpu.to(dense.device()),
meta_reordered_cpu.to(dense.device()));
#endif
}
} // namespace at::native
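The 4-bit codes assigned in _to_sparse_semi_structured() amount to the following lookup from the mask of each 4-element group (for float32 the 1:2 mask elements are duplicated first, so the same table applies); each nibble packs the two kept positions as a pair of 2-bit indices:

META_QUADRUPLE = {
    (1, 1, 0, 0): 0b0100,   #  4: keep positions 0 and 1
    (1, 0, 1, 0): 0b1000,   #  8: keep positions 0 and 2
    (0, 1, 1, 0): 0b1001,   #  9: keep positions 1 and 2
    (1, 0, 0, 1): 0b1100,   # 12: keep positions 0 and 3
    (0, 1, 0, 1): 0b1101,   # 13: keep positions 1 and 3
    (0, 0, 1, 1): 0b1110,   # 14: keep positions 2 and 3
}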

@@ -523,7 +523,9 @@ aten::_sparse_mask_projection
aten::_sparse_mask_projection.out
aten::_sparse_mm_reduce_impl
aten::_sparse_mm_reduce_impl_backward
aten::_sparse_semi_structured_addmm
aten::_sparse_semi_structured_linear
aten::_sparse_semi_structured_mm
aten::_sparse_softmax
aten::_sparse_softmax.out
aten::_sparse_softmax_backward_data

@@ -134,7 +134,6 @@ ALLOW_LIST = [
("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
("aten::_efficient_attention_forward", datetime.date(2024, 1, 15)),
("aten::_sparse_semi_structured_linear", datetime.date(2024, 1, 15)),
("onednn::qconv1d_pointwise", datetime.date(2023, 12, 31)),
("onednn::qconv2d_pointwise", datetime.date(2023, 12, 31)),
("onednn::qconv3d_pointwise", datetime.date(2023, 12, 31)),

@@ -157,7 +157,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
"""
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
We expect:
(1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()`
(1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_addmm` + `aten.contiguous()`
(2) Inductor should fuse the .contiguous() call into the relu
"""
@@ -207,7 +207,7 @@
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
def test_mlp_contiguous_relu_compile_cutlass(self):
"""
test for CUTLASS meta registrations (_sparse_semi_structured_linear) + torch.compile
test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
"""
for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)
@@ -258,7 +258,7 @@ class TestSparseSemiStructured(TestCase):
if dtype is torch.int8:
# This should fail
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B)
else:
with self.assertRaisesRegex(RuntimeError,
@@ -291,7 +291,7 @@
# padding with int8 throws an error because transposing B yields a contiguous output
# and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"):
with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B.t())
else:
with self.assertRaisesRegex(RuntimeError,
@@ -575,6 +575,73 @@ class TestSparseSemiStructured(TestCase):
torch.backends.cuda.matmul.allow_tf32 = orig
@unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows don't support CUTLASS")
@parametrize("backend", ["cutlass"])
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
mat1 = rand_sparse_semi_structured(m, k, dtype, device)
# mat2 transposed as int8 case supports only row-major/column-major combination
mat2 = make_tensor((n, k), dtype=dtype, device=device).t()
input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None
if use_input:
if dtype.is_floating_point:
alpha = 1.3
beta = -0.7
else:
alpha = 2
beta = -3
dtype_dense = torch.float32
mat1_dense = mat1.to(dtype_dense)
mat2_dense = mat2.to(dtype_dense)
if not use_input:
output0 = torch.mm(mat1_dense, mat2_dense)
else:
input_dense = input.to(dtype_dense)[:, None]
output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta)
compressed = to_sparse_semi_structured(mat1)
mat1_sparse = compressed.values()
mat1_meta = compressed.indices()
if not use_input:
output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out)
else:
output1 = torch._sparse_semi_structured_addmm(
input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out
)
torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
if dtype == torch.float32:
# Inputs are converted to TF32 internally for sparse GEMM,
# so make the dense GEMM do the same so that the results match.
orig = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = True
dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
rtol, atol = 1e-3, 1e-3
if dtype == torch.bfloat16:
rtol, atol = 5e-3, 5e-3
elif dtype == torch.float32:
rtol, atol = 1e-3, 75e-2
for m, n, k, use_input in \
itertools.product(range(3), range(3), range(3), (False, True)):
m = 2 ** m * 32
n = 2 ** n * 32
k = 2 ** k * 128
run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol)
if dtype == torch.float32:
torch.backends.cuda.matmul.allow_tf32 = orig
@unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch")
@parametrize("backend", ["cutlass"])
@dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)

@@ -1542,7 +1542,9 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._sparse_csr_prod",
"torch._sparse_csr_sum",
"torch._sparse_log_softmax_backward_data",
"torch._sparse_semi_structured_addmm",
"torch._sparse_semi_structured_linear",
"torch._sparse_semi_structured_mm",
"torch._sparse_softmax_backward_data",
"torch._sparse_sparse_matmul",
"torch._sparse_sum",

@@ -431,6 +431,66 @@ def meta_sparse_structured_linear(
return output
@register_meta(aten._sparse_semi_structured_mm)
def meta_sparse_structured_mm(
mat1: Tensor,
mat1_meta: Tensor,
mat2: Tensor,
out_dtype: Optional[torch.dtype] = None,
):
assert len(mat1.shape) == 2
assert len(mat1_meta.shape) == 2
assert len(mat2.shape) == 2
assert mat1.size(1) == mat2.size(0) / 2
output_sizes = [mat1.size(0), mat2.size(1)]
if out_dtype is not None:
assert (
mat2.dtype == torch.int8 and out_dtype == torch.int32
), "out_dtype is only supported for i8i8->i32 linear operator"
output = mat2.new_empty(
output_sizes,
dtype=mat2.dtype if out_dtype is None else out_dtype,
)
return output
@register_meta(aten._sparse_semi_structured_addmm)
def meta_sparse_structured_addmm(
input: Tensor,
mat1: Tensor,
mat1_meta: Tensor,
mat2: Tensor,
*,
alpha=1,
beta=1,
out_dtype: Optional[torch.dtype] = None,
):
assert (
len(input.shape) == 1
), "only input broadcasted to columns of mat1 * mat2 product is supported"
assert len(mat1.shape) == 2
assert len(mat1_meta.shape) == 2
assert len(mat2.shape) == 2
assert input.size(0) == mat1.size(
0
), "only input broadcasted to columns of mat1 * mat2 product is supported"
assert mat1.size(1) == mat2.size(0) / 2
output_sizes = [mat1.size(0), mat2.size(1)]
if out_dtype is not None:
assert (
mat2.dtype == torch.int8 and out_dtype == torch.int32
), "out_dtype is only supported for i8i8->i32 linear operator"
output = mat2.new_empty(
output_sizes,
dtype=mat2.dtype if out_dtype is None else out_dtype,
)
return output
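These registrations only need to reproduce the shape and dtype rules of the CUDA kernels; under meta/fake tensors both operators effectively resolve to the following (a sketch, with _result_spec a hypothetical helper):

def _result_spec(mat1, mat2, out_dtype=None):
    # mat1 holds packed values, so it has half the columns of the original
    # operand: (m, k/2) against a (k, n) mat2; the output is (m, n).
    assert mat1.size(1) == mat2.size(0) // 2
    dtype = mat2.dtype if out_dtype is None else out_dtype  # override only for i8i8 -> i32
    return (mat1.size(0), mat2.size(1)), dtype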
@register_meta(aten._cslt_sparse_mm)
def meta__cslt_sparse_mm(
compressed_A: torch.Tensor,

@@ -47,7 +47,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
-`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
- `def from_dense()` - backend specific compression routines
- `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_linear)
- `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
"""
_DEFAULT_ALG_ID: int = 0
@@ -371,11 +371,12 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
"""
This class implements semi-structured sparsity for the CUTLASS backend.
In this implementation, the specified elements and metadata are stored separately,
in packed and meta respectively.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
and sparse_semi_structured_from_dense for conversion to the compressed format.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
sparse_semi_structured_from_dense for conversion to the compressed format.
"""
_DTYPE_SHAPE_CONSTRAINTS = {
@@ -436,9 +437,14 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
f"`{cls_name}` matmul: operation is not supported"
)
else:
res = torch._sparse_semi_structured_linear(
B.t(), self.packed, self.meta, bias=bias
).t()
if bias is None:
res = torch._sparse_semi_structured_mm(
self.packed, self.meta, B
)
else:
res = torch._sparse_semi_structured_addmm(
bias, self.packed, self.meta, B
)
return res[: self.shape[0]]
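Putting the subclass change together, a rough end-to-end sketch of the CUTLASS path after this commit (the module, shapes and mask construction are illustrative; a compute capability 8.x GPU is required):

import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

SparseSemiStructuredTensor._FORCE_CUTLASS = True  # route matmuls through the new CUTLASS ops

linear = torch.nn.Linear(128, 128, dtype=torch.half, device="cuda")
# Impose a 2:4 pattern on the weight (keep positions 0 and 2 of every group of 4).
mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool, device="cuda").repeat(128, 32)
linear.weight = torch.nn.Parameter(
    to_sparse_semi_structured(linear.weight.detach().masked_fill(~mask, 0)))

x = torch.randn(64, 128, dtype=torch.half, device="cuda")
y = linear(x)  # with a bias present, this should reach torch._sparse_semi_structured_addmm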