From f5331aade57725b03c36d5cc6c683f6a6bc0692d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aleksandar=20Samard=C5=BEi=C4=87?= Date: Sat, 13 Apr 2024 18:35:02 +0000 Subject: [PATCH] Simplify ATen sparse semi-structured operators based on CUTLASS (#123473) Pull Request resolved: https://github.com/pytorch/pytorch/pull/123473 Approved by: https://github.com/cpuhrsch --- aten/src/ATen/native/native_functions.yaml | 9 + .../sparse/cuda/SparseSemiStructuredLinear.cu | 178 +--- .../sparse/cuda/SparseSemiStructuredOps.cu | 979 ++++++++++++++++++ ...asDecompTest.test_has_decomposition.expect | 2 + .../check_forward_backward_compatibility.py | 1 - test/test_sparse_semi_structured.py | 75 +- torch/_dynamo/trace_rules.py | 2 + torch/_meta_registrations.py | 60 ++ torch/sparse/semi_structured.py | 18 +- 9 files changed, 1139 insertions(+), 185 deletions(-) create mode 100644 aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 02e7c5caa251..6e96a8a6aabc 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -3342,10 +3342,19 @@ dispatch: CUDA: _cslt_sparse_mm_search +# DEPRECATED: Use torch.__sparse_semi_structured_mm/torch._sparse_semi_structured_addmm instead - func: _sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor dispatch: CUDA: _sparse_semi_structured_linear +- func: _sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor + dispatch: + CUDA: _sparse_semi_structured_mm + +- func: _sparse_semi_structured_addmm(Tensor input, Tensor mat1, Tensor mat1_meta, Tensor mat2, *, Scalar alpha=1, Scalar beta=1, ScalarType? out_dtype=None) -> Tensor + dispatch: + CUDA: _sparse_semi_structured_addmm + - func: _mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? activation=None) -> Tensor dispatch: CUDA: _mixed_dtypes_linear diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu index e997f49f3f43..47ee1568beb1 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu @@ -603,6 +603,10 @@ Tensor _sparse_semi_structured_linear( const Tensor& meta, const c10::optional& bias_opt, const c10::optional activation_opt, const c10::optional out_dtype_opt) { + TORCH_WARN_ONCE("_sparse_semi_structured_linear is deprecated and will be " + "removed in a future PyTorch release. Please use " + "_sparse_semi_structured_mm/_sparse_semi_structured_addmm " + "instead."); #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) AT_ERROR("_sparse_semi_structured_linear: CUTLASS not supported"); return Tensor{}; @@ -893,177 +897,3 @@ Tensor _sparse_semi_structured_linear( } } // namespace at::native - -// Following is just for testing purposes. -namespace at::native { - -#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) -#else -// Copied from tools/util/include/host_reorder.h, from CUTLASS source -// tree. 
This is for simplicity - namely, this file is not under -// include/cutlass in this tree, as other CUTLASS include files -// needed, so it would require changing PyTorch CMake configuration; -// furthermore, including this file produces build errors in PyTorch -// at the moment. -template -static void reorder_meta(cutlass::TensorRef dest, - cutlass::TensorRef src, - const int problem_size_m, const int problem_size_k) { - for (int m = 0; m < problem_size_m; m++) { - for (int k = 0; k < problem_size_k; k++) { - // First reorder the rows. - int group = (sizeof(Element) == 2) ? 32 : 16; - int interweave = (sizeof(Element) == 2) ? 4 : 2; - - int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8; - int dest_col = k; - - // Next swizzle the 2x2 blocks from Z to N. - if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) { - ++dest_row; - --dest_col; - } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) { - --dest_row; - ++dest_col; - } - - dest.at({dest_row, dest_col}) = src.at({m, k}); - } - } -} -#endif - -std::tuple -_to_sparse_semi_structured(const Tensor& dense) { -#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) - AT_ERROR("_to_sparse_semi_structured: CUTLASS not supported"); - return std::make_tuple(Tensor{}, Tensor{}); -#else - // Check dimensions of the dense matrix. - TORCH_CHECK(dense.dim() == 2, - "_to_sparse_semi_structured: Expected dense argument to be 2D " - "tensor, got ", dense.dim(), " dims"); - - // Determine PyTorch datatype for the metadata matrix. - auto meta_dtype = at::kChar; - auto ksparse = 0; - auto dense_elems_per_meta_elem = 0; - if (dense.dtype() == at::kChar) { - meta_dtype = at::kInt; - ksparse = 4; - dense_elems_per_meta_elem = 32; - } else if (dense.dtype() == at::kHalf || dense.dtype() == at::kBFloat16) { - meta_dtype = at::kShort; - ksparse = 4; - dense_elems_per_meta_elem = 16; - } else if (dense.dtype() == at::kFloat) { - meta_dtype = at::kShort; - ksparse = 2; - dense_elems_per_meta_elem = 8; - } else { - AT_ERROR("_to_sparse_semi_structured: Invalid dense argument datatype ", - dense.dtype(), " encountered"); - } - - const auto dense_nrows = dense.size(0); - const auto dense_ncols = dense.size(1); - - if (dense_nrows % (meta_dtype == at::kShort ? 32 : 16) != 0) { - AT_ERROR("_to_sparse_semi_structured: Number of rows of dense matrix must " - "be divisible by ", (meta_dtype == at::kShort ? 32 : 16), - ", but it is ", dense_nrows); - } - if (dense_ncols % dense_elems_per_meta_elem != 0) { - AT_ERROR("_to_sparse_semi_structured: Number of columns of dense matrix " - "must be divisible by ", dense_elems_per_meta_elem, ", but it is ", - dense_ncols); - } - - const auto dense_cpu = dense.to("cpu"); - - const auto mask_cpu = dense_cpu != at::zeros({1}, dense_cpu.options()); - - const auto sparse_cpu = - dense_cpu.masked_select(mask_cpu).view({dense_nrows, dense_ncols / 2}); - - const auto meta_nrows = dense_nrows; - const auto meta_ncols = dense_ncols / dense_elems_per_meta_elem; - auto meta_cpu = dense_cpu.new_empty({meta_nrows, meta_ncols}, - at::TensorOptions().dtype(meta_dtype)); - - auto* mask_cpu_ptr = mask_cpu.data_ptr(); - for (auto i = 0; i < meta_nrows; ++i) { - for (auto j = 0; j < meta_ncols; ++j) { - uint64_t meta_val = 0; - for (auto k = 0; k < dense_elems_per_meta_elem / ksparse; ++k, mask_cpu_ptr += ksparse) { - const auto mask_elems = - (ksparse == 4) ? 
std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[1], - mask_cpu_ptr[2], mask_cpu_ptr[3]) - : std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[0], - mask_cpu_ptr[1], mask_cpu_ptr[1]); - auto meta_quadruple = 0; - if (mask_elems == std::make_tuple(1, 1, 0, 0)) { - meta_quadruple = 4; // 0100 - } else if (mask_elems == std::make_tuple(1, 0, 1, 0)) { - meta_quadruple = 8; // 1000 - } else if (mask_elems == std::make_tuple(0, 1, 1, 0)) { - meta_quadruple = 9; // 1001 - } else if (mask_elems == std::make_tuple(1, 0, 0, 1)) { - meta_quadruple = 12; // 1100 - } else if (mask_elems == std::make_tuple(0, 1, 0, 1)) { - meta_quadruple = 13; // 1101 - } else if (mask_elems == std::make_tuple(0, 0, 1, 1)) { - meta_quadruple = 14; // 1110 - } else { - AT_ERROR("_to_sparse_semi_structured: dense argument does not match ", - (dense.dtype() != at::kFloat) ? "2:4" : "1:2", - "sparsity pattern"); - } - meta_val = meta_val | (meta_quadruple << (4 * k)); - } - const auto idx = i * meta_ncols + j; - if (meta_dtype == at::kShort) { - using MetaElement = int16_t; - const auto meta_cpu_ptr = meta_cpu.data_ptr(); - meta_cpu_ptr[idx] = (MetaElement)meta_val; - } else if (meta_dtype == at::kInt) { - using MetaElement = int32_t; - const auto meta_cpu_ptr = meta_cpu.data_ptr(); - meta_cpu_ptr[idx] = (MetaElement)meta_val; - } - } - } - - auto meta_reordered_cpu = meta_cpu.new_empty({meta_nrows, meta_ncols}); - using MetaLayout = cutlass::layout::RowMajor; - using MetaReorderedLayout = cutlass::layout::ColumnMajorInterleaved<2>; - if (meta_dtype == at::kShort) { - using MetaElement = int16_t; - auto meta_cpu_ref = - cutlass::TensorRef( - meta_cpu.data_ptr(), - MetaLayout::packed({meta_nrows, meta_ncols})); - auto meta_reordered_cpu_ref = - cutlass::TensorRef( - meta_reordered_cpu.data_ptr(), - MetaReorderedLayout::packed({meta_nrows, meta_ncols})); - reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols); - } else if (meta_dtype == at::kInt) { - using MetaElement = int32_t; - auto meta_cpu_ref = - cutlass::TensorRef( - meta_cpu.data_ptr(), - MetaLayout::packed({meta_nrows, meta_ncols})); - auto meta_reordered_cpu_ref = - cutlass::TensorRef( - meta_reordered_cpu.data_ptr(), - MetaReorderedLayout::packed({meta_nrows, meta_ncols})); - reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols); - } - - return std::make_tuple(sparse_cpu.to(dense.device()), - meta_reordered_cpu.to(dense.device())); -#endif -} - -} // namespace at::native diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu new file mode 100644 index 000000000000..8c05acc66bc9 --- /dev/null +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu @@ -0,0 +1,979 @@ +#include +#include +#include +#include + +#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) +#else +#include +#include +#include +#include +#include +#include +#endif + +#include +#include + +#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) +#else +#define CUTLASS_STATUS_CHECK(status) \ + { \ + TORCH_CHECK(status == cutlass::Status::kSuccess, \ + __func__, " : CUTLASS error: ", \ + cutlassGetStatusString(status)); \ + } +#endif + +namespace at::native { + +#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) +#else +// Wrapper function for CUTLASS sparse GEMM implementation, used +// solely to simplify dispatching from +// sparse_semi_structured_mad_op() 
function below.
+template <
+    typename ElementInputA,
+    typename ElementInputB,
+    typename ElementOutput,
+    typename ElementAccumulator,
+    typename ThreadblockShape,
+    typename WarpShape,
+    typename InstructionShape,
+    typename LayoutInputA,
+    typename LayoutInputB,
+    bool use_tensor_c>
+void spgemm_cutlass(
+    const Tensor& tensor_a, const at::IntArrayRef::value_type& tensor_a_stride,
+    const Tensor& tensor_b, const at::IntArrayRef::value_type& tensor_b_stride,
+    const Tensor& tensor_c, const Tensor& tensor_e, const Scalar& alpha,
+    const Scalar& beta, Tensor& tensor_d) {
+  // Fix the CUTLASS sparse GEMM template arguments that are not
+  // provided as template arguments of this function, and create an
+  // alias for the particular instantiation of this template.
+  using LayoutOutput = cutlass::layout::RowMajor; // Result of the operation will be provided in row-major format.
+  using MMAOp = cutlass::arch::OpClassTensorOp; // Tensor cores are to be used for maximum performance.
+  using SmArch = cutlass::arch::Sm80; // Only CC 8.x devices are supported at the moment.
+  using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This choice provides good performance across a wide range of operand sizes.
+  constexpr int NumStages = 3; // This choice provides good performance across a wide range of operand sizes.
+  using Operator = cutlass::arch::OpMultiplyAdd;
+  constexpr int NumEVTEpilogueStages = 1;
+
+  constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
+  constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
+  constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;
+
+  using ElementComputeEpilogue = ElementAccumulator; // Typically slightly slower, but more precise, than if ElementOutput were used.
+ constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits::value; + using ElementC = ElementOutput; + using LayoutC = LayoutOutput; + constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using TensorCTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementC, + AlignmentC, + NumEVTEpilogueStages>; + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementOutput, + AlignmentOutput, + NumEVTEpilogueStages>; + + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + using Alpha = + cutlass::epilogue::threadblock::VisitorScalarBroadcast; + using AlphaArguments = typename Alpha::Arguments; + + using ApplyAlpha = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + using EVTApplyAlpha = cutlass::epilogue::threadblock::Sm80EVT< + ApplyAlpha, + Alpha, + Accum>; + + using Beta = + cutlass::epilogue::threadblock::VisitorScalarBroadcast; + using BetaArguments = typename Beta::Arguments; + + using TensorCScalar = + cutlass::epilogue::threadblock::VisitorScalarBroadcast; + using TensorCTensor = + cutlass::epilogue::threadblock::VisitorColBroadcast< + TensorCTileThreadMap, + ElementC, + cute::Stride>; + using TensorC = std::conditional_t; + using TensorCArguments = typename TensorC::Arguments; + + using ApplyBeta = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + using EVTApplyBeta = cutlass::epilogue::threadblock::Sm80EVT< + ApplyBeta, + Beta, + TensorC>; + + using ApplySum = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + using EVTApplySum = cutlass::epilogue::threadblock::Sm80EVT< + ApplySum, + EVTApplyAlpha, + EVTApplyBeta>; + + using Output = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride>; + + using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT< + Output, + EVTApplySum>; + + using Gemm = cutlass::gemm::device::SparseGemmWithVisitor< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementC, + LayoutC, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EVTOutput, + SwizzleThreadBlock, + NumStages, + AlignmentInputA, + AlignmentInputB, + Operator, + NumEVTEpilogueStages>; + + // Datatype and layout of metadata matrix are inferred from sparse + // GEMM template. + using ElementInputE = typename Gemm::ElementE; + using LayoutInputE = cutlass::layout::RowMajor; + using ReorderedLayoutInputE = typename Gemm::LayoutE; + static_assert( + std::is_same>::value, + "Matrix layout used by CUTLASS for reordered metadata for sparse GEMM " + "change, thus code doing conversions from/to dense matrix has to be " + "updated."); + + constexpr auto kSparse = Gemm::kSparse; + constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; + + // Operand sizes. + const int length_m = tensor_a.size(0); + const int length_k = tensor_b.size(0); + const int length_n = tensor_b.size(1); + const auto tensor_e_ncols = length_k / kSparse / kElementsPerElementE; + + // Determine PyTorch datatype for the metadata matrix. 
+ auto tensor_e_dtype = at::kChar; + switch (sizeof(ElementInputE)) { + case 2: + tensor_e_dtype = at::kShort; + break; + case 4: + tensor_e_dtype = at::kInt; + break; + default: + AT_ERROR(__func__, ": invalid size of meta tensor datatype " + "encountered"); + } + TORCH_CHECK(tensor_e.dtype() == tensor_e_dtype, + __func__, " : Expected meta datatype ", tensor_e_dtype, + ", but got ", tensor_e.dtype()); + + // Prepare arguments for CUTLASS sparse GEMM kernel. + cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + LayoutInputA layout_a(tensor_a_stride); + LayoutInputB layout_b(tensor_b_stride); + auto tensor_a_device_ref = + cutlass::TensorRef( + (ElementInputA*)tensor_a.data_ptr(), layout_a); + auto tensor_b_device_ref = + cutlass::TensorRef( + (ElementInputB*)tensor_b.data_ptr(), layout_b); + auto tensor_e_reordered_device_ref = + cutlass::TensorRef( + (ElementInputE*)tensor_e.data_ptr(), + ReorderedLayoutInputE::packed({length_m, tensor_e_ncols})); + + AlphaArguments alpha_arguments{ + [&]() -> AlphaArguments { + if constexpr (std::is_same::value || + std::is_same::value) { + return {ElementComputeEpilogue{alpha.to()}}; + } else { + return {alpha.to()}; + } + }() + }; + BetaArguments beta_arguments{ + [&]() -> BetaArguments { + if constexpr (std::is_same::value || + std::is_same::value) { + return {ElementComputeEpilogue{beta.to()}}; + } else { + return {beta.to()}; + } + }() + }; + TensorCArguments tensor_c_arguments{ + [&]() -> TensorCArguments { + if constexpr (use_tensor_c) { + return {(ElementC*)tensor_c.data_ptr(), + ElementC(0), + {cute::_1{}, cute::_0{}, problem_size.m()}}; + } else { + return {ElementC(0)}; + } + }() + }; + typename Output::Arguments output_arguments{ + (ElementOutput*)tensor_d.data_ptr(), + {problem_size.n(), cute::_1{}, problem_size.mn().product()} + }; + typename EVTOutput::Arguments callback_arguments{ + { + { + alpha_arguments, // Alpha + {}, // Accum + {} // ApplyAlpha + }, // EVTApplyAlpha + { + beta_arguments, // Beta + tensor_c_arguments, // TensorC + {} // ApplyBeta + }, // EVTApplyBeta + {} // ApplySum + }, // EVTApplySum + output_arguments // Output + }; // EVTOutput + + // Create a tuple of CUTLASS sparse GEMM kernel arguments. + typename Gemm::Arguments arguments{ + problem_size, + tensor_a_device_ref, + tensor_b_device_ref, + tensor_e_reordered_device_ref, + callback_arguments}; + + cutlass::Status status; + + // Create CUTLASS sparse GEMM kernel object. + Gemm gemm_op; + + // Verify that sparse GEMM operation with given arguments can be + // performed by CUTLASS. + status = gemm_op.can_implement(arguments); + CUTLASS_STATUS_CHECK(status); + + // Allocate workspace for CUTLASS sparse GEMM kernel. + const auto workspace_size = Gemm::get_workspace_size(arguments); + auto workspace = tensor_a.new_empty({(int64_t)workspace_size}, + at::TensorOptions().dtype(at::kByte)); + + // Initialize CUTLASS sparse GEMM object. + status = gemm_op.initialize(arguments, workspace.data_ptr(), + at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); + + // Perform sparse GEMM operation. + status = gemm_op.run(at::cuda::getCurrentCUDAStream()); + CUTLASS_STATUS_CHECK(status); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +// Dispatch according to the input tensors layouts combination. 
+template < + typename ElementInputA, + typename ElementInputB, + typename ElementOutput, + typename ElementAccumulator, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + bool EnableRowMajorRowMajorLayouts, + bool EnableRowMajorColumnMajorLayouts, + bool EnableColumnMajorRowMajorLayouts, + bool EnableColumnMajorColumnMajorLayouts, + bool use_tensor_c> +void spgemm_cutlass_dispatch_layouts( + const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c, + const Tensor& tensor_e, const Scalar& alpha, const Scalar& beta, + Tensor& tensor_d) { + // Determine layouts (row-major or column-major) of input tensors. + const auto strides_a = tensor_a.strides(); + auto tensor_a_row_major = strides_a[1] == 1; + auto tensor_a_stride = tensor_a_row_major ? strides_a[0] : strides_a[1]; + const auto strides_b = tensor_b.strides(); + auto tensor_b_row_major = strides_b[1] == 1; + auto tensor_b_stride = tensor_b_row_major ? strides_b[0] : strides_b[1]; + + // Perform dispatching. + if constexpr (EnableRowMajorRowMajorLayouts) { + if (tensor_a_row_major && tensor_b_row_major) { + spgemm_cutlass< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + use_tensor_c>( + tensor_a, + tensor_a_stride, + tensor_b, + tensor_b_stride, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + return; + } + } + if constexpr (EnableRowMajorColumnMajorLayouts) { + if (tensor_a_row_major && !tensor_b_row_major) { + spgemm_cutlass< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + use_tensor_c>( + tensor_a, + tensor_a_stride, + tensor_b, + tensor_b_stride, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + return; + } + } + if constexpr (EnableColumnMajorRowMajorLayouts) { + if (!tensor_a_row_major && tensor_b_row_major) { + spgemm_cutlass< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + use_tensor_c>( + tensor_a, + tensor_a_stride, + tensor_b, + tensor_b_stride, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + return; + } + } + if constexpr (EnableColumnMajorColumnMajorLayouts) { + if (!tensor_a_row_major && !tensor_b_row_major) { + spgemm_cutlass< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + use_tensor_c>( + tensor_a, + tensor_a_stride, + tensor_b, + tensor_b_stride, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + return; + } + } + + AT_ERROR(__func__, "_dispatch_layouts: Combination of ", + tensor_a_row_major ? "row-major" : "column_major", " and ", + tensor_b_row_major ? "row-major" : "column_major", + " layouts for input tensors is not supported"); +} + +// Dispatch according to the tensor_c tensor being provided or not. 
+template < + typename ElementInputA, + typename ElementInputB, + typename ElementOutput, + typename ElementAccumulator, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + bool EnableRowMajorRowMajorLayouts, + bool EnableRowMajorColumnMajorLayouts, + bool EnableColumnMajorRowMajorLayouts, + bool EnableColumnMajorColumnMajorLayouts> +void spgemm_cutlass_dispatch_layouts_tensor_c( + const Tensor& tensor_a, const Tensor& tensor_b, const Tensor& tensor_c, + const Tensor& tensor_e, const Scalar& alpha, const Scalar& beta, + Tensor& tensor_d) { + if (tensor_c.numel() > 0) { + spgemm_cutlass_dispatch_layouts< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + EnableRowMajorRowMajorLayouts, + EnableRowMajorColumnMajorLayouts, + EnableColumnMajorRowMajorLayouts, + EnableColumnMajorColumnMajorLayouts, + true>( + tensor_a, + tensor_b, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + } else { + spgemm_cutlass_dispatch_layouts< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + EnableRowMajorRowMajorLayouts, + EnableRowMajorColumnMajorLayouts, + EnableColumnMajorRowMajorLayouts, + EnableColumnMajorColumnMajorLayouts, + false>( + tensor_a, + tensor_b, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + } +} +#endif + +// Perform multiply-add operation, using corresponding CUTLASS +// sparse GEMM kernel, to given arguments: +// result = alpha * mat1 @ mat2 + beta * input +// The "mat2" tensor is a dense tensor, while the "mat1" tensor is a +// sparse semi-structured matrix. The "input" tensor is optional; if +// provided, it should be a vector, with the number of elements equal +// to the number of rows of "mat1" matrix. It is assumed that "mat1" +// and "mat2" are 2D tensors, supplied either in row-major or +// column-major layouts (different layouts between these two tensors +// are OK, but not all combinations of formats are supported for some +// datatypes of these matrices). The "mat1_meta" argument contains +// sparse semi-strucutred metadata. +// +// There exists numerous limitations of CUTLASS sparse GEMM kernel, +// with regards to sizes and alignments of input tensors, their +// layouts and datatypes, and so on; this is the reason for large +// number of checks throughout the code. +// +// TODO: The "input" tensor has to be a vector, such that it could be +// broadcasted to columns of mat1 * mat2. The case of broadcasting to +// rows of mat1 * mat2 could be also supported, if "input" tensor is a +// vector of corresponding length; and same for the case when "input" +// tensor is a matrix of same size as mat1 * mat2 product. If these +// updates made here, then remember to update corresponding bits in +// the Inductor code that are handling meta registrations and +// lowerings of aten._sparse_semi_structured_mm and +// aten._sparse_semi_structured_addmm operators. +Tensor sparse_semi_structured_mad_op( + const Tensor& mat1, const Tensor& mat1_meta, const Tensor& mat2, + const c10::optional& input_opt, const Scalar& alpha, + const Scalar& beta, const c10::optional out_dtype_opt) { +#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) + AT_ERROR(__func__, " : CUTLASS not supported"); + return Tensor{}; +#else + // No need to check that all tensors are on CUDA device, as this + // is provided by dispatch. 
+ + const auto& input = input_opt.value_or(Tensor{}); + const auto out_dtype = out_dtype_opt.value_or(mat2.scalar_type()); + + // For now, only CC 8.x devices are supported. + const auto dprops = at::cuda::getCurrentDeviceProperties(); + const auto is_sm8x = dprops->major == 8; + TORCH_CHECK(is_sm8x, + __func__, " : Supported only on GPUs with compute capability " + "8.x"); + + // Validate datatypes of input tensors. + TORCH_CHECK(mat2.dtype() == at::kChar || + mat2.dtype() == at::kHalf || + mat2.dtype() == at::kBFloat16 || + mat2.dtype() == at::kFloat, + __func__, " : The mat2 datatype ", mat2.dtype(), + " is not supported"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + __func__, " : Expected mat1 datatype ", mat2.dtype(), + ", but got ", mat1.dtype()); + if (input.numel() != 0) { + TORCH_CHECK(input.dtype() == out_dtype, + __func__, " : Expected input datatype ", out_dtype, + ", but got ", input.dtype()); + } + + // Validate layouts of input tensors. + TORCH_CHECK(mat1.layout() == Layout::Strided, + __func__, " : Expected mat1 argument to be strided, but got " + "layout ", mat1.layout()); + TORCH_CHECK(mat1.dim() == 2, + __func__, " : Expected mat1 argument to be 2D tensor, got ", + mat1.dim(), " dims"); + const auto strides_a = mat1.strides(); + TORCH_CHECK(strides_a[0] == 1 || strides_a[1] == 1, + __func__, " : Invalid strides for mat1 argument: row stride = ", + strides_a[0], ", column stride = ", strides_a[1]); + TORCH_CHECK(mat2.layout() == Layout::Strided, + __func__, " : Expected mat2 argument to be " + "strided, but got layout ", mat2.layout()); + TORCH_CHECK(mat2.dim() == 2, + __func__, " : Expected mat2 argument to be 2D tensor, got ", + mat2.dim(), " dims"); + const auto strides_b = mat2.strides(); + TORCH_CHECK(strides_b[0] == 1 || strides_b[1] == 1, + __func__, " : Invalid strides for mat2 argument: row stride = ", + strides_b[0], ", column stride = ", strides_b[1]); + if (input.numel() != 0) { + TORCH_CHECK(input.layout() == Layout::Strided, + __func__, " : Expected input argument to be strided, but " + "got layout ", input.layout()); + TORCH_CHECK(input.dim() == 1, + __func__, " : Expected input argument to be 1D tensor, " + "got ", input.dim(), " dims"); + } + + // Validate sizes of input tensors. + TORCH_CHECK(mat1.size(1) == mat2.size(0) / 2, + __func__, " : Expected mat1 argument to have ", + mat2.size(0) / 2, " columns, but got ", mat1.size(1)); + if (input.numel() != 0) { + TORCH_CHECK(input.size(0) == mat1.size(0), + __func__, " : Expected input argument to have ", + mat1.size(0), " elements, but got ", input.size(0)); + } + + // Introduce alias names for arguments, according to the CUTLASS + // naming conventions. + const auto& tensor_a = mat1; + const auto& tensor_b = mat2; + const auto& tensor_c = input; + const auto& tensor_e = mat1_meta; + + // Create output tensor. + Tensor tensor_d = + tensor_b.new_empty({tensor_a.size(0), tensor_b.size(1)}, + at::TensorOptions().dtype(out_dtype)); + + // Call wrapper function for CUTLASS sparse GEMM, dispatching on + // the input datatype, and then on input tensors layouts. + // According to the input tensors datatypes and layouts, + // corresponding template arguments are supplied for instantiating + // the wrapper function. The tile sizes template arguments are + // selected according to the CUTLASS profiler results, for number + // of runs. 
+ AT_DISPATCH_SWITCH( + tensor_a.scalar_type(), + "sparse_semi_structured_mad_op", + AT_DISPATCH_CASE( + at::ScalarType::Char, + [&]() { + using ElementInputA = int8_t; + using ElementInputB = int8_t; + using ElementAccumulator = int32_t; + using ThreadblockShape = + cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + const auto EnableRowMajorRowMajorLayouts = false; + const auto EnableRowMajorColumnMajorLayouts = true; + const auto EnableColumnMajorRowMajorLayouts = false; + const auto EnableColumnMajorColumnMajorLayouts = false; + if (out_dtype == at::kInt) { + using ElementOutput = int32_t; + spgemm_cutlass_dispatch_layouts_tensor_c< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + EnableRowMajorRowMajorLayouts, + EnableRowMajorColumnMajorLayouts, + EnableColumnMajorRowMajorLayouts, + EnableColumnMajorColumnMajorLayouts>( + tensor_a, + tensor_b, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + } else if (out_dtype == at::kChar) { + using ElementOutput = int8_t; + spgemm_cutlass_dispatch_layouts_tensor_c< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + EnableRowMajorRowMajorLayouts, + EnableRowMajorColumnMajorLayouts, + EnableColumnMajorRowMajorLayouts, + EnableColumnMajorColumnMajorLayouts>( + tensor_a, + tensor_b, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + } + }) + AT_DISPATCH_CASE( + at::ScalarType::Half, + [&]() { + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + const auto EnableRowMajorRowMajorLayouts = true; + const auto EnableRowMajorColumnMajorLayouts = true; + const auto EnableColumnMajorRowMajorLayouts = true; + const auto EnableColumnMajorColumnMajorLayouts = true; + spgemm_cutlass_dispatch_layouts_tensor_c< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + EnableRowMajorRowMajorLayouts, + EnableRowMajorColumnMajorLayouts, + EnableColumnMajorRowMajorLayouts, + EnableColumnMajorColumnMajorLayouts>( + tensor_a, + tensor_b, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + }) + AT_DISPATCH_CASE( + at::ScalarType::BFloat16, + [&]() { + using ElementInputA = cutlass::bfloat16_t; + using ElementInputB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + const auto EnableRowMajorRowMajorLayouts = true; + const auto EnableRowMajorColumnMajorLayouts = true; + const auto EnableColumnMajorRowMajorLayouts = true; + const auto EnableColumnMajorColumnMajorLayouts = true; + spgemm_cutlass_dispatch_layouts_tensor_c< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + EnableRowMajorRowMajorLayouts, + EnableRowMajorColumnMajorLayouts, + EnableColumnMajorRowMajorLayouts, + EnableColumnMajorColumnMajorLayouts>( + 
tensor_a, + tensor_b, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + }) + AT_DISPATCH_CASE( + at::ScalarType::Float, + [&]() { + using ElementInputA = float; + using ElementInputB = float; + using ElementOutput = float; + using ElementAccumulator = float; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + const auto EnableRowMajorRowMajorLayouts = true; + const auto EnableRowMajorColumnMajorLayouts = true; + const auto EnableColumnMajorRowMajorLayouts = true; + const auto EnableColumnMajorColumnMajorLayouts = true; + spgemm_cutlass_dispatch_layouts_tensor_c< + ElementInputA, + ElementInputB, + ElementOutput, + ElementAccumulator, + ThreadblockShape, + WarpShape, + InstructionShape, + EnableRowMajorRowMajorLayouts, + EnableRowMajorColumnMajorLayouts, + EnableColumnMajorRowMajorLayouts, + EnableColumnMajorColumnMajorLayouts>( + tensor_a, + tensor_b, + tensor_c, + tensor_e, + alpha, + beta, + tensor_d); + })); + + return tensor_d; +#endif +} + +// Implementation of aten._sparse_semi_structured_mm operator. +Tensor _sparse_semi_structured_mm( + const Tensor& mat1, const Tensor& mat1_meta, const Tensor& mat2, + const c10::optional out_dtype_opt) { + return sparse_semi_structured_mad_op(mat1, mat1_meta, mat2, + c10::optional(), 1, 0, + out_dtype_opt); +} + +// Implementation of aten._sparse_semi_structured_addmm operator. +Tensor _sparse_semi_structured_addmm( + const Tensor& input, const Tensor& mat1, const Tensor& mat1_meta, + const Tensor& mat2, const Scalar& alpha, const Scalar& beta, + const c10::optional out_dtype_opt) { + return sparse_semi_structured_mad_op(mat1, mat1_meta, mat2, input, alpha, + beta, out_dtype_opt); +} + +} // namespace at::native + +// Following is just for testing purposes. +namespace at::native { + +#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) +#else +// Copied from tools/util/include/host_reorder.h, from CUTLASS source +// tree. This is for simplicity - namely, this file is not under +// include/cutlass in this tree, as other CUTLASS include files +// needed, so it would require changing PyTorch CMake configuration; +// furthermore, including this file produces build errors in PyTorch +// at the moment. +template +static void reorder_meta(cutlass::TensorRef dest, + cutlass::TensorRef src, + const int problem_size_m, const int problem_size_k) { + for (int m = 0; m < problem_size_m; m++) { + for (int k = 0; k < problem_size_k; k++) { + // First reorder the rows. + int group = (sizeof(Element) == 2) ? 32 : 16; + int interweave = (sizeof(Element) == 2) ? 4 : 2; + + int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8; + int dest_col = k; + + // Next swizzle the 2x2 blocks from Z to N. + if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) { + ++dest_row; + --dest_col; + } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) { + --dest_row; + ++dest_col; + } + + dest.at({dest_row, dest_col}) = src.at({m, k}); + } + } +} +#endif + +std::tuple +_to_sparse_semi_structured(const Tensor& dense) { +#if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) + AT_ERROR(__func__, " : CUTLASS not supported"); + return std::make_tuple(Tensor{}, Tensor{}); +#else + // Check dimensions of the dense matrix. 
+ TORCH_CHECK(dense.dim() == 2, + __func__, " : Expected dense argument to be 2D tensor, got ", + dense.dim(), " dims"); + + // Determine PyTorch datatype for the metadata matrix. + auto meta_dtype = at::kChar; + auto ksparse = 0; + auto dense_elems_per_meta_elem = 0; + if (dense.dtype() == at::kChar) { + meta_dtype = at::kInt; + ksparse = 4; + dense_elems_per_meta_elem = 32; + } else if (dense.dtype() == at::kHalf || dense.dtype() == at::kBFloat16) { + meta_dtype = at::kShort; + ksparse = 4; + dense_elems_per_meta_elem = 16; + } else if (dense.dtype() == at::kFloat) { + meta_dtype = at::kShort; + ksparse = 2; + dense_elems_per_meta_elem = 8; + } else { + AT_ERROR("_to_sparse_semi_structured: Invalid dense argument datatype ", + dense.dtype(), " encountered"); + } + + const auto dense_nrows = dense.size(0); + const auto dense_ncols = dense.size(1); + + if (dense_nrows % (meta_dtype == at::kShort ? 32 : 16) != 0) { + AT_ERROR("_to_sparse_semi_structured: Number of rows of dense matrix must " + "be divisible by ", (meta_dtype == at::kShort ? 32 : 16), + ", but it is ", dense_nrows); + } + if (dense_ncols % dense_elems_per_meta_elem != 0) { + AT_ERROR("_to_sparse_semi_structured: Number of columns of dense matrix " + "must be divisible by ", dense_elems_per_meta_elem, ", but it is ", + dense_ncols); + } + + const auto dense_cpu = dense.to("cpu"); + + const auto mask_cpu = dense_cpu != at::zeros({1}, dense_cpu.options()); + + const auto sparse_cpu = + dense_cpu.masked_select(mask_cpu).view({dense_nrows, dense_ncols / 2}); + + const auto meta_nrows = dense_nrows; + const auto meta_ncols = dense_ncols / dense_elems_per_meta_elem; + auto meta_cpu = dense_cpu.new_empty({meta_nrows, meta_ncols}, + at::TensorOptions().dtype(meta_dtype)); + + auto* mask_cpu_ptr = mask_cpu.data_ptr(); + for (auto i = 0; i < meta_nrows; ++i) { + for (auto j = 0; j < meta_ncols; ++j) { + uint64_t meta_val = 0; + for (auto k = 0; k < dense_elems_per_meta_elem / ksparse; ++k, mask_cpu_ptr += ksparse) { + const auto mask_elems = + (ksparse == 4) ? std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[1], + mask_cpu_ptr[2], mask_cpu_ptr[3]) + : std::make_tuple(mask_cpu_ptr[0], mask_cpu_ptr[0], + mask_cpu_ptr[1], mask_cpu_ptr[1]); + auto meta_quadruple = 0; + if (mask_elems == std::make_tuple(1, 1, 0, 0)) { + meta_quadruple = 4; // 0100 + } else if (mask_elems == std::make_tuple(1, 0, 1, 0)) { + meta_quadruple = 8; // 1000 + } else if (mask_elems == std::make_tuple(0, 1, 1, 0)) { + meta_quadruple = 9; // 1001 + } else if (mask_elems == std::make_tuple(1, 0, 0, 1)) { + meta_quadruple = 12; // 1100 + } else if (mask_elems == std::make_tuple(0, 1, 0, 1)) { + meta_quadruple = 13; // 1101 + } else if (mask_elems == std::make_tuple(0, 0, 1, 1)) { + meta_quadruple = 14; // 1110 + } else { + AT_ERROR("_to_sparse_semi_structured: dense argument does not match ", + (dense.dtype() != at::kFloat) ? 
"2:4" : "1:2", + "sparsity pattern"); + } + meta_val = meta_val | (meta_quadruple << (4 * k)); + } + const auto idx = i * meta_ncols + j; + if (meta_dtype == at::kShort) { + using MetaElement = int16_t; + const auto meta_cpu_ptr = meta_cpu.data_ptr(); + meta_cpu_ptr[idx] = (MetaElement)meta_val; + } else if (meta_dtype == at::kInt) { + using MetaElement = int32_t; + const auto meta_cpu_ptr = meta_cpu.data_ptr(); + meta_cpu_ptr[idx] = (MetaElement)meta_val; + } + } + } + + auto meta_reordered_cpu = meta_cpu.new_empty({meta_nrows, meta_ncols}); + using MetaLayout = cutlass::layout::RowMajor; + using MetaReorderedLayout = cutlass::layout::ColumnMajorInterleaved<2>; + if (meta_dtype == at::kShort) { + using MetaElement = int16_t; + auto meta_cpu_ref = + cutlass::TensorRef( + meta_cpu.data_ptr(), + MetaLayout::packed({meta_nrows, meta_ncols})); + auto meta_reordered_cpu_ref = + cutlass::TensorRef( + meta_reordered_cpu.data_ptr(), + MetaReorderedLayout::packed({meta_nrows, meta_ncols})); + reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols); + } else if (meta_dtype == at::kInt) { + using MetaElement = int32_t; + auto meta_cpu_ref = + cutlass::TensorRef( + meta_cpu.data_ptr(), + MetaLayout::packed({meta_nrows, meta_ncols})); + auto meta_reordered_cpu_ref = + cutlass::TensorRef( + meta_reordered_cpu.data_ptr(), + MetaReorderedLayout::packed({meta_nrows, meta_ncols})); + reorder_meta(meta_reordered_cpu_ref, meta_cpu_ref, meta_nrows, meta_ncols); + } + + return std::make_tuple(sparse_cpu.to(dense.device()), + meta_reordered_cpu.to(dense.device())); +#endif +} + +} // namespace at::native diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 0be956c17f1d..79a345571349 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -523,7 +523,9 @@ aten::_sparse_mask_projection aten::_sparse_mask_projection.out aten::_sparse_mm_reduce_impl aten::_sparse_mm_reduce_impl_backward +aten::_sparse_semi_structured_addmm aten::_sparse_semi_structured_linear +aten::_sparse_semi_structured_mm aten::_sparse_softmax aten::_sparse_softmax.out aten::_sparse_softmax_backward_data diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 65c4a1196e19..5a4aac572c17 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -134,7 +134,6 @@ ALLOW_LIST = [ ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)), ("aten::sym_constrain_range", datetime.date(2023, 12, 31)), ("aten::_efficient_attention_forward", datetime.date(2024, 1, 15)), - ("aten::_sparse_semi_structured_linear", datetime.date(2024, 1, 15)), ("onednn::qconv1d_pointwise", datetime.date(2023, 12, 31)), ("onednn::qconv2d_pointwise", datetime.date(2023, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2023, 12, 31)), diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index fcb316ee3019..a09e2647eb7f 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -157,7 +157,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase): """ Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile We expect: - (1) The sparse tensor 
subclass should turn nn.Linear into `aten._structured_sparse_linear` + `aten.contiguous()` + (1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_addmm` + `aten.contiguous()` (2) Inductor should fuse the .contiguous() call into the relu """ @@ -207,7 +207,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase): @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") def test_mlp_contiguous_relu_compile_cutlass(self): """ - test for CUTLASS meta registrations (_sparse_semi_structured_linear) + torch.compile + test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile """ for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]: SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape) @@ -258,7 +258,7 @@ class TestSparseSemiStructured(TestCase): if dtype is torch.int8: # This should fail if backend == "cutlass": - with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"): + with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"): sparse_result = torch.mm(A_sparse, B) else: with self.assertRaisesRegex(RuntimeError, @@ -291,7 +291,7 @@ class TestSparseSemiStructured(TestCase): # padding with int8 throws an error because transposing B yields a contiguous output # and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS. if backend == "cutlass": - with self.assertRaisesRegex(RuntimeError, "two_four_sgemm_dispatch_layouts"): + with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"): sparse_result = torch.mm(A_sparse, B.t()) else: with self.assertRaisesRegex(RuntimeError, @@ -575,6 +575,73 @@ class TestSparseSemiStructured(TestCase): torch.backends.cuda.matmul.allow_tf32 = orig + @unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS") + @parametrize("backend", ["cutlass"]) + @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) + def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend): + SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass") + if backend == "cutlass" and IS_WINDOWS: + self.skipTest("CUTLASS not supported on Windows") + + def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol): + mat1 = rand_sparse_semi_structured(m, k, dtype, device) + # mat2 transposed as int8 case supports only row-major/column-major combination + mat2 = make_tensor((n, k), dtype=dtype, device=device).t() + input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None + + if use_input: + if dtype.is_floating_point: + alpha = 1.3 + beta = -0.7 + else: + alpha = 2 + beta = -3 + + dtype_dense = torch.float32 + mat1_dense = mat1.to(dtype_dense) + mat2_dense = mat2.to(dtype_dense) + if not use_input: + output0 = torch.mm(mat1_dense, mat2_dense) + else: + input_dense = input.to(dtype_dense)[:, None] + output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta) + + compressed = to_sparse_semi_structured(mat1) + + mat1_sparse = compressed.values() + mat1_meta = compressed.indices() + + if not use_input: + output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out) + else: + output1 = torch._sparse_semi_structured_addmm( + input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out + ) + torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol) + + if dtype == torch.float32: + # Inputs are 
converted to TF32 internally for sparse GEMM, + # so make dense GEMM to do the same for matching results. + orig = torch.backends.cuda.matmul.allow_tf32 + torch.backends.cuda.matmul.allow_tf32 = True + + dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32} + rtol, atol = 1e-3, 1e-3 + if dtype == torch.bfloat16: + rtol, atol = 5e-3, 5e-3 + elif dtype == torch.float32: + rtol, atol = 1e-3, 75e-2 + for m, n, k, use_input in \ + itertools.product(range(3), range(3), range(3), (False, True)): + m = 2 ** m * 32 + n = 2 ** n * 32 + k = 2 ** k * 128 + run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol) + + if dtype == torch.float32: + torch.backends.cuda.matmul.allow_tf32 = orig + + @unittest.skipIf(not has_triton(), "Test needs triton and recent GPU arch") @parametrize("backend", ["cutlass"]) @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index a16333cd7f48..393c649133df 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1542,7 +1542,9 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._sparse_csr_prod", "torch._sparse_csr_sum", "torch._sparse_log_softmax_backward_data", + "torch._sparse_semi_structured_addmm", "torch._sparse_semi_structured_linear", + "torch._sparse_semi_structured_mm", "torch._sparse_softmax_backward_data", "torch._sparse_sparse_matmul", "torch._sparse_sum", diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index fd69525b6deb..70e91dc3f2b7 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -431,6 +431,66 @@ def meta_sparse_structured_linear( return output +@register_meta(aten._sparse_semi_structured_mm) +def meta_sparse_structured_mm( + mat1: Tensor, + mat1_meta: Tensor, + mat2: Tensor, + out_dtype: Optional[torch.dtype] = None, +): + assert len(mat1.shape) == 2 + assert len(mat1_meta.shape) == 2 + assert len(mat2.shape) == 2 + assert mat1.size(1) == mat2.size(0) / 2 + output_sizes = [mat1.size(0), mat2.size(1)] + + if out_dtype is not None: + assert ( + mat2.dtype == torch.int8 and out_dtype == torch.int32 + ), "out_dtype is only supported for i8i8->i32 linear operator" + output = mat2.new_empty( + output_sizes, + dtype=mat2.dtype if out_dtype is None else out_dtype, + ) + + return output + + +@register_meta(aten._sparse_semi_structured_addmm) +def meta_sparse_structured_addmm( + input: Tensor, + mat1: Tensor, + mat1_meta: Tensor, + mat2: Tensor, + *, + alpha=1, + beta=1, + out_dtype: Optional[torch.dtype] = None, +): + assert ( + len(input.shape) == 1 + ), "only input broadcasted to columns of mat1 * mat2 product is supported" + assert len(mat1.shape) == 2 + assert len(mat1_meta.shape) == 2 + assert len(mat2.shape) == 2 + assert input.size(0) == mat1.size( + 0 + ), "only input broadcasted to columns of mat1 * mat2 product is supported" + assert mat1.size(1) == mat2.size(0) / 2 + output_sizes = [mat1.size(0), mat2.size(1)] + + if out_dtype is not None: + assert ( + mat2.dtype == torch.int8 and out_dtype == torch.int32 + ), "out_dtype is only supported for i8i8->i32 linear operator" + output = mat2.new_empty( + output_sizes, + dtype=mat2.dtype if out_dtype is None else out_dtype, + ) + + return output + + @register_meta(aten._cslt_sparse_mm) def meta__cslt_sparse_mm( compressed_A: torch.Tensor, diff --git a/torch/sparse/semi_structured.py b/torch/sparse/semi_structured.py index 03c15c0eee51..7c86b0d43b51 100644 --- 
a/torch/sparse/semi_structured.py
+++ b/torch/sparse/semi_structured.py
@@ -47,7 +47,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
     -`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
     - `def from_dense()` - backend specific compression routines
-    - `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_linear)
+    - `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
     """

     _DEFAULT_ALG_ID: int = 0
@@ -371,11 +371,12 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
     """
     This class implements semi-structured sparsity for the CUTLASS backend.

+    In this implementation, the specified elements and metadata are stored separately, in packed and meta respectively.

-    When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
-    and sparse_semi_structured_from_dense for conversion to the compressed format.
+    When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and
+    sparse_semi_structured_from_dense for conversion to the compressed format.
     """

     _DTYPE_SHAPE_CONSTRAINTS = {
@@ -436,9 +437,14 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
                 f"`{cls_name}` matmul: operation is not supported"
             )
         else:
-            res = torch._sparse_semi_structured_linear(
-                B.t(), self.packed, self.meta, bias=bias
-            ).t()
+            if bias is None:
+                res = torch._sparse_semi_structured_mm(
+                    self.packed, self.meta, B
+                )
+            else:
+                res = torch._sparse_semi_structured_addmm(
+                    bias, self.packed, self.meta, B
+                )
             return res[: self.shape[0]]
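
The following is a minimal usage sketch, not part of the patch itself, of the two operators introduced here, torch._sparse_semi_structured_mm and torch._sparse_semi_structured_addmm. It is distilled from the test_sparse_semi_structured_ops_cutlass test added above and assumes a CUDA device with compute capability 8.x and fp16 inputs. The explicit 2:4 mask construction is illustrative (the test uses its rand_sparse_semi_structured helper instead); to_sparse_semi_structured, SparseSemiStructuredTensor._FORCE_CUTLASS, and the values()/indices() accessors are used the same way as in the test.

import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

SparseSemiStructuredTensor._FORCE_CUTLASS = True  # route compression through the CUTLASS backend

m, k, n = 64, 128, 64  # shapes within the range exercised by the new test
# Dense matrix that already satisfies the 2:4 pattern: in every group of four
# elements along a row, only the first two are (potentially) nonzero.
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool, device="cuda").repeat(m, k // 4)
mat1 = torch.randn(m, k, dtype=torch.half, device="cuda") * mask
mat2 = torch.randn(k, n, dtype=torch.half, device="cuda")
bias = torch.randn(m, dtype=torch.half, device="cuda")

compressed = to_sparse_semi_structured(mat1)
mat1_packed = compressed.values()   # m x k/2 tensor of specified elements
mat1_meta = compressed.indices()    # reordered 2:4 metadata

# result = mat1 @ mat2, with mat1 supplied in compressed form
out_mm = torch._sparse_semi_structured_mm(mat1_packed, mat1_meta, mat2)
# result = beta * bias[:, None] + alpha * mat1 @ mat2
out_addmm = torch._sparse_semi_structured_addmm(
    bias, mat1_packed, mat1_meta, mat2, alpha=1.0, beta=1.0
)

# Compare against a dense float32 reference, as the test does for fp16.
ref = torch.addmm(bias.float()[:, None], mat1.float(), mat2.float())
torch.testing.assert_close(out_addmm.float(), ref, rtol=1e-3, atol=1e-3)

As in the test, out_dtype can be passed to either op (for example, int8 inputs producing an int32 result), and the alpha/beta scalars are only exposed by the addmm variant.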