ROCm Sparsity through hipSPARSELt (#150578)

TLDR:

- This pull request introduces support for hipSPARSELt in ROCm; the current use case is semi-structured sparsity (see the usage sketch below).
- Requires **ROCm 6.4** and **gfx942/gfx950**.
- The average performance uplift (compared to the equivalent dense operation) is ~20% on ROCm 6.4, with further gains expected in future releases.
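
To make the feature concrete, here is a minimal usage sketch (not part of this PR) of semi-structured sparsity through the `cusparselt` backend, which on ROCm now dispatches to hipSPARSELt. It assumes a PyTorch build with hipSPARSELt enabled and a gfx942/gfx950 GPU; the shapes, dtypes, tolerances, and the `_FORCE_CUTLASS` toggle are illustrative only.

```python
import torch
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured

# Prefer the cuSPARSELt/hipSPARSELt backend over CUTLASS for this sketch.
SparseSemiStructuredTensor._FORCE_CUTLASS = False

# Build a 2:4-sparse fp16 weight: keep 2 of every 4 elements along each row.
w = torch.randn(256, 256, dtype=torch.float16, device="cuda")
mask = torch.tensor([0, 0, 1, 1], dtype=torch.bool, device="cuda").tile(256, 64)
w_sparse = to_sparse_semi_structured(w * mask)

x = torch.randn(256, 64, dtype=torch.float16, device="cuda")
y_sparse = w_sparse @ x     # sparse matmul, served by hipSPARSELt on ROCm
y_dense = (w * mask) @ x    # dense reference (hipBLASLt)

# Loose tolerances: both paths accumulate in fp16-friendly precisions.
torch.testing.assert_close(y_sparse, y_dense, rtol=1e-2, atol=1e-2)
```
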

### Dense vs. Sparse Performance Comparison

#### **NT (Row-major)**
**Average Uplift**: `1.20`

| M     | N      | K      | hipsparselt-bench (us) | hipblaslt-bench get all (us) | Uplift |
|-------|--------|--------|-------------------------|-------------------------------|--------|
| 14336 | 8      | 4096   | 20.05                   | 25.3                          | 1.26   |
| 4096  | 8      | 14336  | 21.07                   | 25.28                         | 1.20   |
| 3072  | 3072   | 10240  | 299.05                  | 351.82                        | 1.18   |
| 3072  | 1536   | 768    | 18.56                   | 20.05                         | 1.08   |
| 3072  | 17664  | 768    | 163.13                  | 173.91                        | 1.07   |
| 3072  | 196608 | 768    | 1717.30                 | 1949.63                       | 1.14   |
| 3072  | 24576  | 768    | 206.84                  | 242.98                        | 1.17   |
| 3072  | 6144   | 768    | 53.90                   | 56.88                         | 1.06   |
| 3072  | 98304  | 768    | 833.77                  | 962.28                        | 1.15   |
| 768   | 1536   | 768    | 8.53                    | 19.65                         | 2.30   |
| 768   | 17664  | 768    | 46.02                   | 46.84                         | 1.02   |
| 768   | 196608 | 768    | 463.15                  | 540.46                        | 1.17   |
| 768   | 24576  | 768    | 54.32                   | 59.55                         | 1.10   |
| 768   | 6144   | 768    | 19.47                   | 20.15                         | 1.03   |
| 768   | 98304  | 768    | 231.88                  | 258.73                        | 1.12   |

---

#### **NN (Row-major)**
**Average Uplift**: `1.13`

| M   | N      | K     | hipsparselt-bench (us) | hipblaslt-bench get all (us) | Uplift |
|-----|--------|-------|-------------------------|-------------------------------|--------|
| 768 | 1536   | 3072  | 27.50                   | 28.78                         | 1.05   |
| 768 | 17664  | 3072  | 125.06                  | 158.94                        | 1.27   |
| 768 | 196608 | 3072  | 1568.38                 | 1767.12                       | 1.13   |
| 768 | 24576  | 3072  | 171.05                  | 203.49                        | 1.19   |
| 768 | 6144   | 3072  | 58.72                   | 60.39                         | 1.03   |
| 768 | 98304  | 3072  | 787.15                  | 887.60                        | 1.13   |

-------------------------

This pull request introduces support for hipSPARSELt in ROCm, alongside various updates and improvements to the codebase and test suite. The changes primarily involve adding configuration flags, updating conditional checks, and ensuring compatibility with hipSPARSELt.

### ROCm and hipSPARSELt Support:

* [`BUILD.bazel`](diffhunk://#diff-7fc57714ef13c3325ce2a1130202edced92fcccc0c6db34a72f7b57f60d552a3R292): Added `@AT_HIPSPARSELT_ENABLED@` substitution to enable hipSPARSELt support.
* [`aten/CMakeLists.txt`](diffhunk://#diff-0604597797bb21d7c39150f9429d6b2ace10b79ab308514ad03f76153ae8249bR104-R110): Introduced a conditional flag to enable hipSPARSELt support based on ROCm version.
* [`aten/src/ATen/CMakeLists.txt`](diffhunk://#diff-ce80f3115ab2f6be5142f0678a1fc92c6b2d7727766ce44f48726c99e720f777R37): Added `AT_HIPSPARSELT_ENABLED` configuration.
* [`aten/src/ATen/cuda/CUDAConfig.h.in`](diffhunk://#diff-8bb82da825ca87c28233abacffa1b0566c73a54990b7a77f3f5108d3718fea15R11): Defined `AT_HIPSPARSELT_ENABLED` macro.
* `caffe2/CMakeLists.txt`, `cmake/Dependencies.cmake`, `cmake/public/LoadHIP.cmake`: Included hipSPARSELt in the ROCm dependencies. [[1]](diffhunk://#diff-c5ee05f1e918772792ff6f2a3f579fc2f182e57b1709fd786ef6dc711fd68b27R1380) [[2]](diffhunk://#diff-12e8125164bbfc7556b1781a8ed516e333cc0bf058acb7197f7415be44606c72L1084-R1084) [[3]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5R153)
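
The flags above are compile-time gates. For a quick runtime sanity check on a given machine, a rough probe (not part of this PR) can mirror the same conditions, ROCm >= 6.4 and a gfx942/gfx950 GPU; `torch.version.hip` and the `gcnArchName` device property are assumed to be available in ROCm builds of PyTorch.

```python
import torch

def hipsparselt_probably_usable() -> bool:
    """Heuristic mirror of this PR's gating: ROCm >= 6.4 and a gfx942/gfx950 device."""
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    hip_major, hip_minor = (int(v) for v in torch.version.hip.split(".")[:2])
    arch = torch.cuda.get_device_properties(0).gcnArchName  # e.g. "gfx942:sramecc+:xnack-"
    return (hip_major, hip_minor) >= (6, 4) and any(g in arch for g in ("gfx942", "gfx950"))

print(hipsparselt_probably_usable())
```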

### Codebase Updates:

* [`aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp`](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R1-R6): Added hipSPARSELt support checks and initialization functions. Updated various methods to conditionally handle hipSPARSELt. [[1]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R1-R6) [[2]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R22-R67) [[3]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R78-R85) [[4]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R97-R109) [[5]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R183-R188) [[6]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3L134-R200) [[7]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R213-R222) [[8]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3L217-R285)
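
For reference, the two ops touched here can be driven directly from Python; the sketch below (not from this PR) mirrors what the test suite does. `torch._cslt_compress` and `torch._cslt_sparse_mm` are internal APIs whose signatures may change, and the shapes and tolerances are illustrative.

```python
import torch

# 2:4-sparse fp16 operand: keep 2 of every 4 values in each row -> shape (128, 256).
A = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile(128, 64)
B = torch.randn(256, 128, dtype=torch.float16, device="cuda")

A_compressed = torch._cslt_compress(A)        # compressed values + metadata
out = torch._cslt_sparse_mm(A_compressed, B)  # (128, 256) x (256, 128) -> (128, 128)
ref = torch.mm(A, B)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```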

### Test Suite Updates:

* [`test/test_sparse_semi_structured.py`](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR50-R65): Added checks for hipSPARSELt availability and updated test conditions to skip tests not supported on ROCm. [[1]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR50-R65) [[2]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR228) [[3]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR239) [[4]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR250) [[5]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR579) [[6]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR624) [[7]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR661) [[8]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR695) [[9]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR730) [[10]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR755) [[11]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR771) [[12]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR809) [[13]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR844) [[14]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cL840-R854) [[15]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR1005)
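
The skips follow the standard `unittest` pattern used throughout the file; below is a stand-alone sketch (not from this PR) of the same gating style. `TEST_WITH_ROCM`, `TestCase`, and `run_tests` come from `torch.testing._internal.common_utils`, and the version comparison is only a rough mirror of the probe added at the top of the test file.

```python
import unittest
import torch
from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase, run_tests

# Rough mirror of the availability probe added in test_sparse_semi_structured.py.
_HIP_6_4_PLUS = torch.version.hip is not None and tuple(
    int(v) for v in torch.version.hip.split(".")[:2]
) >= (6, 4)

class ExampleSemiStructuredTest(TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "needs a GPU")
    @unittest.skipIf(TEST_WITH_ROCM and not _HIP_6_4_PLUS, "hipSPARSELt needs ROCm >= 6.4")
    def test_cusparselt_backend_reported(self):
        # Per the gating above, is_available() is expected to cover hipSPARSELt builds too.
        self.assertIsInstance(torch.backends.cusparselt.is_available(), bool)

if __name__ == "__main__":
    run_tests()
```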

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150578
Approved by: https://github.com/jeffdaily
Author: Peter Y. Yeh
Committed by: PyTorch MergeBot
Date: 2025-05-31 02:03:36 +00:00
Parent: ad26ec6abe
Commit: 43390d8b13

9 changed files with 130 additions and 10 deletions

View File

@@ -290,6 +290,7 @@ header_template_rule(
     substitutions = {
         "@AT_CUDNN_ENABLED@": "1",
         "@AT_CUSPARSELT_ENABLED@": "0",
+        "@AT_HIPSPARSELT_ENABLED@": "0",
         "@AT_ROCM_ENABLED@": "0",
         "@AT_MAGMA_ENABLED@": "0",
         "@NVCC_FLAGS_EXTRA@": "",

View File

@@ -101,6 +101,13 @@ else()
   set(AT_CUSPARSELT_ENABLED 1)
 endif()
 
+# Add hipSPARSELt support flag
+if(USE_ROCM AND ROCM_VERSION VERSION_GREATER_EQUAL "6.4.0")
+  set(AT_HIPSPARSELT_ENABLED 1)
+else()
+  set(AT_HIPSPARSELT_ENABLED 0)
+endif()
+
 list(APPEND ATen_CPU_INCLUDE
   ${CMAKE_CURRENT_SOURCE_DIR}/src)
 add_subdirectory(src/ATen)

View File

@@ -34,6 +34,7 @@ set_bool(AT_MAGMA_ENABLED USE_MAGMA)
 set_bool(CAFFE2_STATIC_LINK_CUDA_INT CAFFE2_STATIC_LINK_CUDA)
 set_bool(AT_CUDNN_ENABLED CAFFE2_USE_CUDNN)
 set_bool(AT_CUSPARSELT_ENABLED CAFFE2_USE_CUSPARSELT)
+set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT)
 
 configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
 # TODO: Do not generate CUDAConfig.h for ROCm BUILDS

View File

@@ -8,6 +8,7 @@
 // only be included from C++ files.
 #define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
 #define AT_CUSPARSELT_ENABLED() @AT_CUSPARSELT_ENABLED@
+#define AT_HIPSPARSELT_ENABLED() @AT_HIPSPARSELT_ENABLED@
 #define AT_ROCM_ENABLED() @AT_ROCM_ENABLED@
 #define AT_MAGMA_ENABLED() @AT_MAGMA_ENABLED@

View File

@@ -1,5 +1,7 @@
 #include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
+#include <unordered_map>
+#include <mutex>
+#include <string_view>
 
 #if AT_CUSPARSELT_ENABLED()
 
 namespace at::native {
@@ -15,6 +17,45 @@ namespace at::native {
 thread_local cusparseLtHandle_t handle;
 thread_local bool handle_initialized = false;
 
+#ifdef USE_ROCM
+// Single global flag for platform-wide hipSparseLt support
+c10::once_flag g_hipSparseLtSupportInitFlag;
+static bool g_hipSparseLtSupported = false;
+
+// Initialize the hipSparseLt support status once for the platform
+static void initHipSparseLtSupport() {
+  // Default to not supported
+  g_hipSparseLtSupported = false;
+
+  // Check only the first available device
+  try {
+    if (at::cuda::device_count() > 0) {
+      g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0);
+    }
+  } catch (const std::exception&) {
+    // If an exception occurs during device property check, we assume hipSparseLt is not supported
+    // This could happen due to driver issues, device access problems, or other runtime errors
+    g_hipSparseLtSupported = false;
+    TORCH_WARN("Exception occurred while checking hipSparseLt support. Assuming not supported.");
+  }
+}
+
+static bool isHipSparseLtSupported() {
+  // Initialize support check only once
+  c10::call_once(g_hipSparseLtSupportInitFlag, initHipSparseLtSupport);
+
+  // Return cached result (platform-wide)
+  if (!g_hipSparseLtSupported) {
+    TORCH_CHECK(
+        false,
+        "hipSparseLt not supported on this device, supported architectures: "
+        "gfx950, gfx942. "
+        "required ROCM version: 6.4.0 or later.");
+  }
+  return g_hipSparseLtSupported;
+}
+#endif
+
 at::Tensor _cslt_compress(const Tensor& sparse_input) {
   if (!handle_initialized) {
     TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
@@ -25,6 +66,10 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
   cudaDataType type;
   auto compression_factor = 9;
 
+#ifdef USE_ROCM
+  TORCH_CHECK(isHipSparseLtSupported());
+#endif
+
   switch (sparse_input.scalar_type()) {
     case at::ScalarType::Char:
       type = CUDA_R_8I;
@@ -36,17 +81,19 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
     case at::ScalarType::BFloat16:
       type = CUDA_R_16BF;
       break;
+#ifndef USE_ROCM
     case at::ScalarType::Float:
       type = CUDA_R_32F;
      break;
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#endif
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
     case at::ScalarType::Float8_e4m3fn:
       type = CUDA_R_8F_E4M3;
       compression_factor = 10;
       break;
 #endif
     default:
-      TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix");
+      TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt/hipSparseLt compressed matrix");
       break;
   }
@@ -120,6 +167,10 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
   cusparseComputeType compute_type;
   auto compression_factor = 9;
 
+#ifdef USE_ROCM
+  TORCH_CHECK(isHipSparseLtSupported());
+#endif
+
   switch (compressed_A.scalar_type()) {
     case at::ScalarType::Char:
       input_type = CUDA_R_8I;
@@ -131,7 +182,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
 // cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F
 // to CUSPARSE_COMPUTE_32F
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 || defined(USE_ROCM)
     case at::ScalarType::Half:
       input_type = CUDA_R_16F;
       output_type = CUDA_R_16F;
@@ -144,14 +195,16 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
       C_type = CUDA_R_16BF;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+#ifndef USE_ROCM
     case at::ScalarType::Float:
       input_type = CUDA_R_32F;
       output_type = CUDA_R_32F;
       C_type = CUDA_R_32F;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+#endif
 // if cuSPARSELt >= 6.2.3, we can add Float8 support
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
     case at::ScalarType::Float8_e4m3fn:
       input_type = CUDA_R_8F_E4M3;
       output_type = CUDA_R_8F_E4M3;
@@ -214,7 +267,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
     }
   }
   // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
   else if (input_type == CUDA_R_8F_E4M3) {
     switch (out_dtype) {
       case at::ScalarType::Float8_e4m3fn:

View File

@@ -1063,7 +1063,7 @@ if(USE_ROCM)
   # Math libraries
   list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
-    roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver roc::hipblaslt)
+    roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsparselt roc::hipsolver roc::hipblaslt)
 
   # ---[ Kernel asserts
   # Kernel asserts is disabled for ROCm by default.

View File

@@ -151,6 +151,7 @@ if(HIP_FOUND)
   find_package_and_print_version(miopen REQUIRED)
   find_package_and_print_version(hipfft REQUIRED)
   find_package_and_print_version(hipsparse REQUIRED)
+  find_package_and_print_version(hipsparselt REQUIRED)
   find_package_and_print_version(rocprim REQUIRED)
   find_package_and_print_version(hipcub REQUIRED)
   find_package_and_print_version(rocthrust REQUIRED)

View File

@@ -47,17 +47,18 @@ SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict()
 _IS_SM8X = False
 _IS_SM9X = False
+_IS_HIPSPARSELT_AVAILABLE = False
 
 if torch.cuda.is_available():
     _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
     _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
+    _IS_HIPSPARSELT_AVAILABLE = torch.version.hip is not None and tuple(int(v) for v in torch.version.hip.split('.')[:2]) > (6, 4)
 
     # CUTLASS kernels only work for Ampere
     if _IS_SM8X:
         SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
 
     # add cuSPASRELt tests if available
-    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X):
+    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X or _IS_HIPSPARSELT_AVAILABLE):
         SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
 
 inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8)
@@ -223,6 +224,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
     @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_mlp_contiguous_relu_compile_cusparselt(self):
         """
         test for cuSPASRELt meta registrations (_cslt_sparse_mm) + torch.compile
@@ -233,6 +235,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 
     @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_mlp_contiguous_relu_compile_cutlass(self):
         """
         test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
@@ -243,6 +246,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
    @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_compile(self) -> None:
         x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
@@ -571,6 +575,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_prune_dense_static_sort(self, dtype) -> None:
         # Ideally we would like to clone and compare, but that won't work because the sorting order will be different
         # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@@ -615,6 +620,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
     @training_dtypes
     @parametrize_backends
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
         inp = torch.tensor(
             [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@@ -651,6 +657,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
     @training_dtypes
     @parametrize_backends
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
         M, N = 128, 256
         # Construct x to make sure we always have exactly 8 elements per 4x4 tile
@@ -684,6 +691,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pack_both_ways_id(self, dtype) -> None:
         N = 512
         torch.manual_seed(0)
@@ -718,6 +726,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pack_both_ways_edge_case1(self, dtype) -> None:
         # In this case, the heuristic will keep 7 values out of 16
         # instead of 8. let's see how the kernel handles this
@@ -742,6 +751,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         assert packed_t[0, 1].item() == 0
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_apply(self, dtype) -> None:
         M, N = 256, 1024
         x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -757,6 +767,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         torch.testing.assert_close(packed_t, packed_t2)
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_apply_dense(self, dtype) -> None:
         M, N = 256, 1024
         x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -794,6 +805,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_matmuls(self, dtype) -> None:
         M, N, K = 64, 256, 1024
         a = torch.randn([M, K], device="cuda", dtype=dtype)
@@ -828,6 +840,7 @@ class TestSparseSemiStructuredTraining(TestCase):
             a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
         )
 
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_matmuls_mat_vec(self) -> None:
         a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
         b = torch.randn([128], device="cuda", dtype=torch.float16)
@@ -837,7 +850,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         with pytest.raises(NotImplementedError):
             torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
 
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_matmuls_bmm(self) -> None:
         a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
         b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
@@ -988,6 +1001,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @inference_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_conversions(self, device, dtype):
 
         def run_test(r, c, device, dtype):
@@ -1016,6 +1030,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @inference_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_conversions_all_patterns(self, device, dtype):
         r, c = 32, 128
@@ -1135,6 +1150,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
 
     @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling")
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_cslt_sparse_mm_alpha(self, dtype, device):
         A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
         B = torch.ones((256, 128), device=device).to(dtype)
@@ -1151,6 +1167,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_cslt_sparse_mm_alpha_compile_autotune(self, device, out_dtype):
         A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).to(device)
         B = torch.ones((128, 256), device=device, dtype=torch.int8).t()
@@ -1172,6 +1189,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         torch.testing.assert_close(sparse_result.cpu(), get_dense_result(), rtol=1e-3, atol=1e-3)
 
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
         A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
         B = torch.ones((128, 256), device=device).to(torch.int8).t()

View File

@@ -607,6 +607,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
         ("curand_precalc.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
         ("curand_uniform.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
         ("cusparse.h", ("hipsparse/hipsparse.h", CONV_INCLUDE, API_RAND)),
+        ("cusparseLt.h", ("hipsparselt/hipsparselt.h", CONV_INCLUDE, API_RAND)),
         ("cufft.h", ("hipfft/hipfft.h", CONV_INCLUDE, API_BLAS)),
         ("cufftXt.h", ("hipfft/hipfftXt.h", CONV_INCLUDE, API_BLAS)),
         # PyTorch also has a source file named "nccl.h", so we need to "<"">" to differentiate
@@ -8256,6 +8257,43 @@ CUDA_SPECIAL_MAP = collections.OrderedDict(
             "CUSPARSE_MATRIX_TYPE_GENERAL",
             ("HIPSPARSE_MATRIX_TYPE_GENERAL", CONV_NUMERIC_LITERAL, API_SPECIAL),
         ),
+        # SparseLt
+        ("cuSPARSELt", ("hipSPARSELt", CONV_TYPE, API_SPECIAL)),
+        ("AT_CUSPARSELT_ENABLED", ("AT_HIPSPARSELT_ENABLED", CONV_TYPE, API_SPECIAL)),
+        ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_SPARSITY_50_PERCENT", ("HIPSPARSELT_SPARSITY_50_PERCENT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("cusparseComputeType", ("hipsparseLtComputetype_t", CONV_TYPE, API_SPECIAL)),
+        ("CUSPARSE_COMPUTE_32F", ("HIPSPARSELT_COMPUTE_32F", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSE_COMPUTE_16F", ("HIPSPARSELT_COMPUTE_16F", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSE_COMPUTE_32I", ("HIPSPARSELT_COMPUTE_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSE_COMPUTE_TF32", ("HIPSPARSELT_COMPUTE_TF32", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_MATMUL_BIAS_POINTER", ("HIPSPARSELT_MATMUL_BIAS_POINTER", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_MATMUL_ALG_DEFAULT", ("HIPSPARSELT_MATMUL_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_MATMUL_ALG_CONFIG_ID", ("HIPSPARSELT_MATMUL_ALG_CONFIG_ID", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", ("HIPSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("cusparseLtHandle_t", ("hipsparseLtHandle_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtMatDescriptor_t", ("hipsparseLtMatDescriptor_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtInit", ("hipsparseLtInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtSpMMACompressedSize2", ("hipsparseLtSpMMACompressedSize2", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtSpMMACompress2", ("hipsparseLtSpMMACompress2", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulDescriptor_t", ("hipsparseLtMatmulDescriptor_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtMatmulPlan_t", ("hipsparseLtMatmulPlan_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtMatmulAlgSelection_t", ("hipsparseLtMatmulAlgSelection_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtDenseDescriptorInit", ("hipsparseLtDenseDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulDescriptorInit", ("hipsparseLtMatmulDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulDescSetAttribute", ("hipsparseLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulAlgSelectionInit", ("hipsparseLtMatmulAlgSelectionInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulAlgSetAttribute", ("hipsparseLtMatmulAlgSetAttribute", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulPlanInit", ("hipsparseLtMatmulPlanInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulGetWorkspace", ("hipsparseLtMatmulGetWorkspace", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulSearch", ("hipsparseLtMatmulSearch", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulAlgGetAttribute", ("hipsparseLtMatmulAlgGetAttribute", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmul", ("hipsparseLtMatmul", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatDescriptorDestroy", ("hipsparseLtMatDescriptorDestroy", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulPlanDestroy", ("hipsparseLtMatmulPlanDestroy", CONV_MATH_FUNC, API_SPECIAL)),
         # SOLVER
         ("cublasOperation_t", ("hipsolverOperation_t", CONV_TYPE, API_SPECIAL)),
         ("CUBLAS_OP_N", ("HIPSOLVER_OP_N", CONV_NUMERIC_LITERAL, API_SPECIAL)),