From 43390d8b1339f6c438dff8798e6f9de0c5561724 Mon Sep 17 00:00:00 2001 From: "Peter Y. Yeh" Date: Sat, 31 May 2025 02:03:36 +0000 Subject: [PATCH] ROCm Sparsity through HipSparseLT (#150578) TLDR: - This pull request introduces support for hipSPARSELt in ROCm, current usage would be semi-structure sparsity. - Require **ROCm 6.4** && **gfx942/gfx950**. - The average performance uplift (compare to dense operation) is ~ 20% in ROCm 6.4 but expect further performance lift along the way. ### Dense vs. Sparse Performance Comparison #### **NT (Row-major)** **Average Uplift**: `1.20` | M | N | K | hipsparselt-bench (us) | hipblaslt-bench get all (us) | Uplift | |-------|--------|--------|-------------------------|-------------------------------|--------| | 14336 | 8 | 4096 | 20.05 | 25.3 | 1.26 | | 4096 | 8 | 14336 | 21.07 | 25.28 | 1.20 | | 3072 | 3072 | 10240 | 299.05 | 351.82 | 1.18 | | 3072 | 1536 | 768 | 18.56 | 20.05 | 1.08 | | 3072 | 17664 | 768 | 163.13 | 173.91 | 1.07 | | 3072 | 196608 | 768 | 1717.30 | 1949.63 | 1.14 | | 3072 | 24576 | 768 | 206.84 | 242.98 | 1.17 | | 3072 | 6144 | 768 | 53.90 | 56.88 | 1.06 | | 3072 | 98304 | 768 | 833.77 | 962.28 | 1.15 | | 768 | 1536 | 768 | 8.53 | 19.65 | 2.30 | | 768 | 17664 | 768 | 46.02 | 46.84 | 1.02 | | 768 | 196608 | 768 | 463.15 | 540.46 | 1.17 | | 768 | 24576 | 768 | 54.32 | 59.55 | 1.10 | | 768 | 6144 | 768 | 19.47 | 20.15 | 1.03 | | 768 | 98304 | 768 | 231.88 | 258.73 | 1.12 | --- #### **NN (Row-major)** **Average Uplift**: `1.13` | M | N | K | hipsparselt-bench (us) | hipblaslt-bench get all (us) | Uplift | |-----|--------|-------|-------------------------|-------------------------------|--------| | 768 | 1536 | 3072 | 27.50 | 28.78 | 1.05 | | 768 | 17664 | 3072 | 125.06 | 158.94 | 1.27 | | 768 | 196608 | 3072 | 1568.38 | 1767.12 | 1.13 | | 768 | 24576 | 3072 | 171.05 | 203.49 | 1.19 | | 768 | 6144 | 3072 | 58.72 | 60.39 | 1.03 | | 768 | 98304 | 3072 | 787.15 | 887.60 | 1.13 | ------------------------- This pull request introduces support for hipSPARSELt in ROCm, alongside various updates and improvements to the codebase and test suite. The changes primarily involve adding configuration flags, updating conditional checks, and ensuring compatibility with hipSPARSELt. ### ROCm and hipSPARSELt Support: * [`BUILD.bazel`](diffhunk://#diff-7fc57714ef13c3325ce2a1130202edced92fcccc0c6db34a72f7b57f60d552a3R292): Added `@AT_HIPSPARSELT_ENABLED@` substitution to enable hipSPARSELt support. * [`aten/CMakeLists.txt`](diffhunk://#diff-0604597797bb21d7c39150f9429d6b2ace10b79ab308514ad03f76153ae8249bR104-R110): Introduced a conditional flag to enable hipSPARSELt support based on ROCm version. * [`aten/src/ATen/CMakeLists.txt`](diffhunk://#diff-ce80f3115ab2f6be5142f0678a1fc92c6b2d7727766ce44f48726c99e720f777R37): Added `AT_HIPSPARSELT_ENABLED` configuration. * [`aten/src/ATen/cuda/CUDAConfig.h.in`](diffhunk://#diff-8bb82da825ca87c28233abacffa1b0566c73a54990b7a77f3f5108d3718fea15R11): Defined `AT_HIPSPARSELT_ENABLED` macro. * `caffe2/CMakeLists.txt`, `cmake/Dependencies.cmake`, `cmake/public/LoadHIP.cmake`: Included hipSPARSELt in the ROCm dependencies. [[1]](diffhunk://#diff-c5ee05f1e918772792ff6f2a3f579fc2f182e57b1709fd786ef6dc711fd68b27R1380) [[2]](diffhunk://#diff-12e8125164bbfc7556b1781a8ed516e333cc0bf058acb7197f7415be44606c72L1084-R1084) [[3]](diffhunk://#diff-b98e27b9a5f196a6965a99ee5a7bb15b3fc633d6375b767635b1b04ccb2fd3d5R153) ### Codebase Updates: * [`aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp`](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R1-R6): Added hipSPARSELt support checks and initialization functions. Updated various methods to conditionally handle hipSPARSELt. [[1]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R1-R6) [[2]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R22-R67) [[3]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R78-R85) [[4]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R97-R109) [[5]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R183-R188) [[6]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3L134-R200) [[7]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3R213-R222) [[8]](diffhunk://#diff-ae921dd1584ab98fdd9c25a3521047795de702223f5b65fdaa45a5bd92b4d1f3L217-R285) ### Test Suite Updates: * [`test/test_sparse_semi_structured.py`](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR50-R65): Added checks for hipSPARSELt availability and updated test conditions to skip tests not supported on ROCm. [[1]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR50-R65) [[2]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR228) [[3]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR239) [[4]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR250) [[5]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR579) [[6]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR624) [[7]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR661) [[8]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR695) [[9]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR730) [[10]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR755) [[11]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR771) [[12]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR809) [[13]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR844) [[14]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cL840-R854) [[15]](diffhunk://#diff-b7b57bc1e34145ef89c7929751d5d26aeecc8edfb37da9c60e9d3f0a1335133cR1005) Pull Request resolved: https://github.com/pytorch/pytorch/pull/150578 Approved by: https://github.com/jeffdaily --- BUILD.bazel | 1 + aten/CMakeLists.txt | 7 ++ aten/src/ATen/CMakeLists.txt | 1 + aten/src/ATen/cuda/CUDAConfig.h.in | 1 + .../ATen/native/sparse/cuda/cuSPARSELtOps.cpp | 65 +++++++++++++++++-- cmake/Dependencies.cmake | 2 +- cmake/public/LoadHIP.cmake | 1 + test/test_sparse_semi_structured.py | 24 ++++++- torch/utils/hipify/cuda_to_hip_mappings.py | 38 +++++++++++ 9 files changed, 130 insertions(+), 10 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 2d3e1d7cdf72..1a12c5609c37 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -290,6 +290,7 @@ header_template_rule( substitutions = { "@AT_CUDNN_ENABLED@": "1", "@AT_CUSPARSELT_ENABLED@": "0", + "@AT_HIPSPARSELT_ENABLED@": "0", "@AT_ROCM_ENABLED@": "0", "@AT_MAGMA_ENABLED@": "0", "@NVCC_FLAGS_EXTRA@": "", diff --git a/aten/CMakeLists.txt b/aten/CMakeLists.txt index bda6aea32706..b653ab3ec210 100644 --- a/aten/CMakeLists.txt +++ b/aten/CMakeLists.txt @@ -101,6 +101,13 @@ else() set(AT_CUSPARSELT_ENABLED 1) endif() +# Add hipSPARSELt support flag +if(USE_ROCM AND ROCM_VERSION VERSION_GREATER_EQUAL "6.4.0") + set(AT_HIPSPARSELT_ENABLED 1) +else() + set(AT_HIPSPARSELT_ENABLED 0) +endif() + list(APPEND ATen_CPU_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/src) add_subdirectory(src/ATen) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 6eea0b214759..becfcee442d1 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -34,6 +34,7 @@ set_bool(AT_MAGMA_ENABLED USE_MAGMA) set_bool(CAFFE2_STATIC_LINK_CUDA_INT CAFFE2_STATIC_LINK_CUDA) set_bool(AT_CUDNN_ENABLED CAFFE2_USE_CUDNN) set_bool(AT_CUSPARSELT_ENABLED CAFFE2_USE_CUSPARSELT) +set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT) configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h") # TODO: Do not generate CUDAConfig.h for ROCm BUILDS diff --git a/aten/src/ATen/cuda/CUDAConfig.h.in b/aten/src/ATen/cuda/CUDAConfig.h.in index 7c7f2cc7470a..6263e8455eaf 100644 --- a/aten/src/ATen/cuda/CUDAConfig.h.in +++ b/aten/src/ATen/cuda/CUDAConfig.h.in @@ -8,6 +8,7 @@ // only be included from C++ files. #define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@ #define AT_CUSPARSELT_ENABLED() @AT_CUSPARSELT_ENABLED@ +#define AT_HIPSPARSELT_ENABLED() @AT_HIPSPARSELT_ENABLED@ #define AT_ROCM_ENABLED() @AT_ROCM_ENABLED@ #define AT_MAGMA_ENABLED() @AT_MAGMA_ENABLED@ diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp index 5f6633593ed7..de73ce612f10 100644 --- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp +++ b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp @@ -1,5 +1,7 @@ #include - +#include +#include +#include #if AT_CUSPARSELT_ENABLED() namespace at::native { @@ -15,6 +17,45 @@ namespace at::native { thread_local cusparseLtHandle_t handle; thread_local bool handle_initialized = false; +#ifdef USE_ROCM +// Single global flag for platform-wide hipSparseLt support +c10::once_flag g_hipSparseLtSupportInitFlag; +static bool g_hipSparseLtSupported = false; + +// Initialize the hipSparseLt support status once for the platform +static void initHipSparseLtSupport() { + // Default to not supported + g_hipSparseLtSupported = false; + + // Check only the first available device + try { + if (at::cuda::device_count() > 0) { + g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0); + } + } catch (const std::exception&) { + // If an exception occurs during device property check, we assume hipSparseLt is not supported + // This could happen due to driver issues, device access problems, or other runtime errors + g_hipSparseLtSupported = false; + TORCH_WARN("Exception occurred while checking hipSparseLt support. Assuming not supported."); + } +} + +static bool isHipSparseLtSupported() { + // Initialize support check only once + c10::call_once(g_hipSparseLtSupportInitFlag, initHipSparseLtSupport); + + // Return cached result (platform-wide) + if (!g_hipSparseLtSupported) { + TORCH_CHECK( + false, + "hipSparseLt not supported on this device, supported architectures: " + "gfx950, gfx942. " + "required ROCM version: 6.4.0 or later."); + } + return g_hipSparseLtSupported; +} +#endif + at::Tensor _cslt_compress(const Tensor& sparse_input) { if (!handle_initialized) { TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle)); @@ -25,6 +66,10 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) { cudaDataType type; auto compression_factor = 9; + #ifdef USE_ROCM + TORCH_CHECK(isHipSparseLtSupported()); + #endif + switch (sparse_input.scalar_type()) { case at::ScalarType::Char: type = CUDA_R_8I; @@ -36,17 +81,19 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) { case at::ScalarType::BFloat16: type = CUDA_R_16BF; break; +#ifndef USE_ROCM case at::ScalarType::Float: type = CUDA_R_32F; break; -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 +#endif +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) case at::ScalarType::Float8_e4m3fn: type = CUDA_R_8F_E4M3; compression_factor = 10; break; #endif default: - TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix"); + TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt/hipSparseLt compressed matrix"); break; } @@ -120,6 +167,10 @@ std::tuple _cslt_sparse_mm_impl( cusparseComputeType compute_type; auto compression_factor = 9; + #ifdef USE_ROCM + TORCH_CHECK(isHipSparseLtSupported()); + #endif + switch (compressed_A.scalar_type()) { case at::ScalarType::Char: input_type = CUDA_R_8I; @@ -131,7 +182,7 @@ std::tuple _cslt_sparse_mm_impl( // cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F // to CUSPARSE_COMPUTE_32F -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 || defined(USE_ROCM) case at::ScalarType::Half: input_type = CUDA_R_16F; output_type = CUDA_R_16F; @@ -144,14 +195,16 @@ std::tuple _cslt_sparse_mm_impl( C_type = CUDA_R_16BF; compute_type = CUSPARSE_COMPUTE_32F; break; +#ifndef USE_ROCM case at::ScalarType::Float: input_type = CUDA_R_32F; output_type = CUDA_R_32F; C_type = CUDA_R_32F; compute_type = CUSPARSE_COMPUTE_32F; break; +#endif // if cuSPARSELt >= 6.2.3, we can add Float8 support -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) case at::ScalarType::Float8_e4m3fn: input_type = CUDA_R_8F_E4M3; output_type = CUDA_R_8F_E4M3; @@ -214,7 +267,7 @@ std::tuple _cslt_sparse_mm_impl( } } // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support -#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 +#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM) else if (input_type == CUDA_R_8F_E4M3) { switch (out_dtype) { case at::ScalarType::Float8_e4m3fn: diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 15bbfaa82ddd..c8db1fe9ccbd 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1063,7 +1063,7 @@ if(USE_ROCM) # Math libraries list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS - roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver roc::hipblaslt) + roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsparselt roc::hipsolver roc::hipblaslt) # ---[ Kernel asserts # Kernel asserts is disabled for ROCm by default. diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 1080b7bc2525..03389217e928 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -151,6 +151,7 @@ if(HIP_FOUND) find_package_and_print_version(miopen REQUIRED) find_package_and_print_version(hipfft REQUIRED) find_package_and_print_version(hipsparse REQUIRED) + find_package_and_print_version(hipsparselt REQUIRED) find_package_and_print_version(rocprim REQUIRED) find_package_and_print_version(hipcub REQUIRED) find_package_and_print_version(rocthrust REQUIRED) diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py index 142eff2b3ae4..5078649bb006 100644 --- a/test/test_sparse_semi_structured.py +++ b/test/test_sparse_semi_structured.py @@ -47,17 +47,18 @@ SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict() _IS_SM8X = False _IS_SM9X = False +_IS_HIPSPARSELT_AVAILABLE = False if torch.cuda.is_available(): _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9 - + _IS_HIPSPARSELT_AVAILABLE = torch.version.hip is not None and tuple(int(v) for v in torch.version.hip.split('.')[:2]) > (6, 4) # CUTLASS kernels only work for Ampere if _IS_SM8X: SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS # add cuSPASRELt tests if available - if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X): + if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X or _IS_HIPSPARSELT_AVAILABLE): SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8) @@ -223,6 +224,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase): @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_mlp_contiguous_relu_compile_cusparselt(self): """ test for cuSPASRELt meta registrations (_cslt_sparse_mm) + torch.compile @@ -233,6 +235,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase): @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine") @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_mlp_contiguous_relu_compile_cutlass(self): """ test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile @@ -243,6 +246,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase): @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows") @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine") + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_compile(self) -> None: x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True) @@ -571,6 +575,7 @@ class TestSparseSemiStructuredTraining(TestCase): @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_prune_dense_static_sort(self, dtype) -> None: # Ideally we would like to clone and compare, but that won't work because the sorting order will be different # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern. @@ -615,6 +620,7 @@ class TestSparseSemiStructuredTraining(TestCase): @training_dtypes @parametrize_backends + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None: inp = torch.tensor( [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]], @@ -651,6 +657,7 @@ class TestSparseSemiStructuredTraining(TestCase): @training_dtypes @parametrize_backends + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None: M, N = 128, 256 # Construct x to make sure we always have exactly 8 elements per 4x4 tile @@ -684,6 +691,7 @@ class TestSparseSemiStructuredTraining(TestCase): torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype]) @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_pack_both_ways_id(self, dtype) -> None: N = 512 torch.manual_seed(0) @@ -718,6 +726,7 @@ class TestSparseSemiStructuredTraining(TestCase): ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})" @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_pack_both_ways_edge_case1(self, dtype) -> None: # In this case, the heuristic will keep 7 values out of 16 # instead of 8. let's see how the kernel handles this @@ -742,6 +751,7 @@ class TestSparseSemiStructuredTraining(TestCase): assert packed_t[0, 1].item() == 0 @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_apply(self, dtype) -> None: M, N = 256, 1024 x = torch.randn([M, N], dtype=dtype, device="cuda") @@ -757,6 +767,7 @@ class TestSparseSemiStructuredTraining(TestCase): torch.testing.assert_close(packed_t, packed_t2) @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_apply_dense(self, dtype) -> None: M, N = 256, 1024 x = torch.randn([M, N], dtype=dtype, device="cuda") @@ -794,6 +805,7 @@ class TestSparseSemiStructuredTraining(TestCase): @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_matmuls(self, dtype) -> None: M, N, K = 64, 256, 1024 a = torch.randn([M, K], device="cuda", dtype=dtype) @@ -828,6 +840,7 @@ class TestSparseSemiStructuredTraining(TestCase): a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1 ) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_matmuls_mat_vec(self) -> None: a = torch.randn([64, 128], device="cuda", dtype=torch.float16) b = torch.randn([128], device="cuda", dtype=torch.float16) @@ -837,7 +850,7 @@ class TestSparseSemiStructuredTraining(TestCase): with pytest.raises(NotImplementedError): torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype]) - + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_sp24_matmuls_bmm(self) -> None: a = torch.randn([64, 128], device="cuda", dtype=torch.float16) b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16) @@ -988,6 +1001,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @inference_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_conversions(self, device, dtype): def run_test(r, c, device, dtype): @@ -1016,6 +1030,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @inference_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_conversions_all_patterns(self, device, dtype): r, c = 32, 128 @@ -1135,6 +1150,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase): @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling") @training_dtypes + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_cslt_sparse_mm_alpha(self, dtype, device): A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda() B = torch.ones((256, 128), device=device).to(dtype) @@ -1151,6 +1167,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase): torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3) @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32]) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_cslt_sparse_mm_alpha_compile_autotune(self, device, out_dtype): A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).to(device) B = torch.ones((128, 256), device=device, dtype=torch.int8).t() @@ -1172,6 +1189,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase): torch.testing.assert_close(sparse_result.cpu(), get_dense_result(), rtol=1e-3, atol=1e-3) @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32]) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm") def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device): A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda() B = torch.ones((128, 256), device=device).to(torch.int8).t() diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index ae3863c8ec09..0caa4801756c 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -607,6 +607,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict( ("curand_precalc.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), ("curand_uniform.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)), ("cusparse.h", ("hipsparse/hipsparse.h", CONV_INCLUDE, API_RAND)), + ("cusparseLt.h", ("hipsparselt/hipsparselt.h", CONV_INCLUDE, API_RAND)), ("cufft.h", ("hipfft/hipfft.h", CONV_INCLUDE, API_BLAS)), ("cufftXt.h", ("hipfft/hipfftXt.h", CONV_INCLUDE, API_BLAS)), # PyTorch also has a source file named "nccl.h", so we need to "<"">" to differentiate @@ -8256,6 +8257,43 @@ CUDA_SPECIAL_MAP = collections.OrderedDict( "CUSPARSE_MATRIX_TYPE_GENERAL", ("HIPSPARSE_MATRIX_TYPE_GENERAL", CONV_NUMERIC_LITERAL, API_SPECIAL), ), + # SparseLt + ("cuSPARSELt", ("hipSPARSELt", CONV_TYPE, API_SPECIAL)), + ("AT_CUSPARSELT_ENABLED", ("AT_HIPSPARSELT_ENABLED", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_SPARSITY_50_PERCENT", ("HIPSPARSELT_SPARSITY_50_PERCENT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseComputeType", ("hipsparseLtComputetype_t", CONV_TYPE, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32F", ("HIPSPARSELT_COMPUTE_32F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_16F", ("HIPSPARSELT_COMPUTE_16F", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_32I", ("HIPSPARSELT_COMPUTE_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSE_COMPUTE_TF32", ("HIPSPARSELT_COMPUTE_TF32", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_BIAS_POINTER", ("HIPSPARSELT_MATMUL_BIAS_POINTER", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_DEFAULT", ("HIPSPARSELT_MATMUL_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALG_CONFIG_ID", ("HIPSPARSELT_MATMUL_ALG_CONFIG_ID", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", ("HIPSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", CONV_NUMERIC_LITERAL, API_SPECIAL)), + ("cusparseLtHandle_t", ("hipsparseLtHandle_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatDescriptor_t", ("hipsparseLtMatDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtInit", ("hipsparseLtInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSpMMACompressedSize2", ("hipsparseLtSpMMACompressedSize2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtSpMMACompress2", ("hipsparseLtSpMMACompress2", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescriptor_t", ("hipsparseLtMatmulDescriptor_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulPlan_t", ("hipsparseLtMatmulPlan_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtMatmulAlgSelection_t", ("hipsparseLtMatmulAlgSelection_t", CONV_TYPE, API_SPECIAL)), + ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtDenseDescriptorInit", ("hipsparseLtDenseDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescriptorInit", ("hipsparseLtMatmulDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulDescSetAttribute", ("hipsparseLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSelectionInit", ("hipsparseLtMatmulAlgSelectionInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgSetAttribute", ("hipsparseLtMatmulAlgSetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanInit", ("hipsparseLtMatmulPlanInit", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulGetWorkspace", ("hipsparseLtMatmulGetWorkspace", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulSearch", ("hipsparseLtMatmulSearch", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulAlgGetAttribute", ("hipsparseLtMatmulAlgGetAttribute", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmul", ("hipsparseLtMatmul", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatDescriptorDestroy", ("hipsparseLtMatDescriptorDestroy", CONV_MATH_FUNC, API_SPECIAL)), + ("cusparseLtMatmulPlanDestroy", ("hipsparseLtMatmulPlanDestroy", CONV_MATH_FUNC, API_SPECIAL)), # SOLVER ("cublasOperation_t", ("hipsolverOperation_t", CONV_TYPE, API_SPECIAL)), ("CUBLAS_OP_N", ("HIPSOLVER_OP_N", CONV_NUMERIC_LITERAL, API_SPECIAL)),