ROCm Sparsity through HipSparseLT (#150578)
**TL;DR:**
- This pull request introduces support for hipSPARSELt in ROCm; the current use case is semi-structured sparsity.
- Requires **ROCm 6.4** and **gfx942/gfx950**.
- The average performance uplift (compared to the dense operation) is ~20% on ROCm 6.4, with further gains expected over time.

### Dense vs. Sparse Performance Comparison

Uplift is the dense time divided by the sparse time, so values above 1.0 favor the sparse path.

#### **NT (Row-major)**

**Average Uplift**: `1.20`

| M     | N      | K     | hipsparselt-bench (µs) | hipblaslt-bench get all (µs) | Uplift |
|-------|--------|-------|------------------------|------------------------------|--------|
| 14336 | 8      | 4096  | 20.05                  | 25.3                         | 1.26   |
| 4096  | 8      | 14336 | 21.07                  | 25.28                        | 1.20   |
| 3072  | 3072   | 10240 | 299.05                 | 351.82                       | 1.18   |
| 3072  | 1536   | 768   | 18.56                  | 20.05                        | 1.08   |
| 3072  | 17664  | 768   | 163.13                 | 173.91                       | 1.07   |
| 3072  | 196608 | 768   | 1717.30                | 1949.63                      | 1.14   |
| 3072  | 24576  | 768   | 206.84                 | 242.98                       | 1.17   |
| 3072  | 6144   | 768   | 53.90                  | 56.88                        | 1.06   |
| 3072  | 98304  | 768   | 833.77                 | 962.28                       | 1.15   |
| 768   | 1536   | 768   | 8.53                   | 19.65                        | 2.30   |
| 768   | 17664  | 768   | 46.02                  | 46.84                        | 1.02   |
| 768   | 196608 | 768   | 463.15                 | 540.46                       | 1.17   |
| 768   | 24576  | 768   | 54.32                  | 59.55                        | 1.10   |
| 768   | 6144   | 768   | 19.47                  | 20.15                        | 1.03   |
| 768   | 98304  | 768   | 231.88                 | 258.73                       | 1.12   |

#### **NN (Row-major)**

**Average Uplift**: `1.13`

| M   | N      | K    | hipsparselt-bench (µs) | hipblaslt-bench get all (µs) | Uplift |
|-----|--------|------|------------------------|------------------------------|--------|
| 768 | 1536   | 3072 | 27.50                  | 28.78                        | 1.05   |
| 768 | 17664  | 3072 | 125.06                 | 158.94                       | 1.27   |
| 768 | 196608 | 3072 | 1568.38                | 1767.12                      | 1.13   |
| 768 | 24576  | 3072 | 171.05                 | 203.49                       | 1.19   |
| 768 | 6144   | 3072 | 58.72                  | 60.39                        | 1.03   |
| 768 | 98304  | 3072 | 787.15                 | 887.60                       | 1.13   |

---

This pull request introduces support for hipSPARSELt in ROCm, alongside various updates and improvements to the codebase and test suite. The changes primarily involve adding configuration flags, updating conditional checks, and ensuring compatibility with hipSPARSELt.

### ROCm and hipSPARSELt Support:
* `BUILD.bazel`: Added the `@AT_HIPSPARSELT_ENABLED@` substitution to enable hipSPARSELt support.
* `aten/CMakeLists.txt`: Introduced a conditional flag that enables hipSPARSELt support based on the ROCm version.
* `aten/src/ATen/CMakeLists.txt`: Added the `AT_HIPSPARSELT_ENABLED` configuration.
* `aten/src/ATen/cuda/CUDAConfig.h.in`: Defined the `AT_HIPSPARSELT_ENABLED` macro.
* `caffe2/CMakeLists.txt`, `cmake/Dependencies.cmake`, `cmake/public/LoadHIP.cmake`: Included hipSPARSELt in the ROCm dependencies.

### Codebase Updates:
* `aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp`: Added hipSPARSELt support checks and initialization functions, and updated various methods to conditionally handle hipSPARSELt.

### Test Suite Updates:
* `test/test_sparse_semi_structured.py`: Added checks for hipSPARSELt availability and updated test conditions to skip tests not supported on ROCm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150578
Approved by: https://github.com/jeffdaily
Committed by: PyTorch MergeBot
Parent: ad26ec6abe
Commit: 43390d8b13
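For context on what this enables, here is a minimal sketch of driving 2:4 semi-structured sparsity through the public `torch.sparse` API. The shapes, the hand-built 2:4 mask, and the tolerances are illustrative and not taken from this PR; on a ROCm 6.4 build with a gfx942/gfx950 GPU, this is the path that would route through hipSPARSELt.

```python
import torch
from torch.sparse import to_sparse_semi_structured

# Weight with an exact 2:4 pattern: every contiguous group of 4 values keeps 2 nonzeros.
w = torch.randn(128, 128, dtype=torch.float16, device="cuda")
mask = torch.tensor([1, 1, 0, 0], dtype=torch.bool, device="cuda").tile(128, 32)

# Compress for the active backend (cuSPARSELt/hipSPARSELt when available, else CUTLASS).
w_sparse = to_sparse_semi_structured(w * mask)

x = torch.randn(128, 128, dtype=torch.float16, device="cuda")
dense_out = (w * mask) @ x
sparse_out = w_sparse @ x  # dispatches to the semi-structured sparse matmul
torch.testing.assert_close(dense_out, sparse_out, rtol=1e-3, atol=1e-3)
```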
#### `BUILD.bazel`

```diff
@@ -290,6 +290,7 @@ header_template_rule(
     substitutions = {
         "@AT_CUDNN_ENABLED@": "1",
         "@AT_CUSPARSELT_ENABLED@": "0",
+        "@AT_HIPSPARSELT_ENABLED@": "0",
         "@AT_ROCM_ENABLED@": "0",
         "@AT_MAGMA_ENABLED@": "0",
         "@NVCC_FLAGS_EXTRA@": "",
```
#### `aten/CMakeLists.txt`

```diff
@@ -101,6 +101,13 @@ else()
   set(AT_CUSPARSELT_ENABLED 1)
 endif()
 
+# Add hipSPARSELt support flag
+if(USE_ROCM AND ROCM_VERSION VERSION_GREATER_EQUAL "6.4.0")
+  set(AT_HIPSPARSELT_ENABLED 1)
+else()
+  set(AT_HIPSPARSELT_ENABLED 0)
+endif()
+
 list(APPEND ATen_CPU_INCLUDE
      ${CMAKE_CURRENT_SOURCE_DIR}/src)
 add_subdirectory(src/ATen)
```
#### `aten/src/ATen/CMakeLists.txt`

```diff
@@ -34,6 +34,7 @@ set_bool(AT_MAGMA_ENABLED USE_MAGMA)
 set_bool(CAFFE2_STATIC_LINK_CUDA_INT CAFFE2_STATIC_LINK_CUDA)
 set_bool(AT_CUDNN_ENABLED CAFFE2_USE_CUDNN)
 set_bool(AT_CUSPARSELT_ENABLED CAFFE2_USE_CUSPARSELT)
+set_bool(AT_HIPSPARSELT_ENABLED CAFFE2_USE_HIPSPARSELT)
 
 configure_file(Config.h.in "${CMAKE_CURRENT_SOURCE_DIR}/Config.h")
 # TODO: Do not generate CUDAConfig.h for ROCm BUILDS
```
#### `aten/src/ATen/cuda/CUDAConfig.h.in`

```diff
@@ -8,6 +8,7 @@
 // only be included from C++ files.
 #define AT_CUDNN_ENABLED() @AT_CUDNN_ENABLED@
 #define AT_CUSPARSELT_ENABLED() @AT_CUSPARSELT_ENABLED@
+#define AT_HIPSPARSELT_ENABLED() @AT_HIPSPARSELT_ENABLED@
 #define AT_ROCM_ENABLED() @AT_ROCM_ENABLED@
 #define AT_MAGMA_ENABLED() @AT_MAGMA_ENABLED@
 
```
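The macro above is a compile-time gate; Python code can only probe the outcome indirectly at runtime. A hedged sketch of such a probe, composed from the same signals the test suite below relies on (`hipsparselt_ready` is a hypothetical helper, not an API added by this PR):

```python
import torch

def hipsparselt_ready() -> bool:
    """Best-effort runtime probe mirroring the build-time gate (illustrative only)."""
    if not torch.cuda.is_available():
        return False
    if torch.version.hip is None:  # not a ROCm build
        return False
    major, minor = (int(v) for v in torch.version.hip.split(".")[:2])
    # The PR requires ROCm >= 6.4; the cusparselt backend flag covers hipSPARSELt on ROCm.
    return (major, minor) >= (6, 4) and torch.backends.cusparselt.is_available()

print(hipsparselt_ready())
```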
#### `aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp`

```diff
@@ -1,5 +1,7 @@
 #include <ATen/native/sparse/cuda/cuSPARSELtOps.h>
-
+#include <unordered_map>
+#include <mutex>
+#include <string_view>
 #if AT_CUSPARSELT_ENABLED()
 
 namespace at::native {
@@ -15,6 +17,45 @@ namespace at::native {
 thread_local cusparseLtHandle_t handle;
 thread_local bool handle_initialized = false;
 
+#ifdef USE_ROCM
+// Single global flag for platform-wide hipSparseLt support
+c10::once_flag g_hipSparseLtSupportInitFlag;
+static bool g_hipSparseLtSupported = false;
+
+// Initialize the hipSparseLt support status once for the platform
+static void initHipSparseLtSupport() {
+  // Default to not supported
+  g_hipSparseLtSupported = false;
+
+  // Check only the first available device
+  try {
+    if (at::cuda::device_count() > 0) {
+      g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0);
+    }
+  } catch (const std::exception&) {
+    // If an exception occurs during device property check, we assume hipSparseLt is not supported
+    // This could happen due to driver issues, device access problems, or other runtime errors
+    g_hipSparseLtSupported = false;
+    TORCH_WARN("Exception occurred while checking hipSparseLt support. Assuming not supported.");
+  }
+}
+
+static bool isHipSparseLtSupported() {
+  // Initialize support check only once
+  c10::call_once(g_hipSparseLtSupportInitFlag, initHipSparseLtSupport);
+
+  // Return cached result (platform-wide)
+  if (!g_hipSparseLtSupported) {
+    TORCH_CHECK(
+        false,
+        "hipSparseLt not supported on this device, supported architectures: "
+        "gfx950, gfx942. "
+        "required ROCM version: 6.4.0 or later.");
+  }
+  return g_hipSparseLtSupported;
+}
+#endif
+
 at::Tensor _cslt_compress(const Tensor& sparse_input) {
   if (!handle_initialized) {
     TORCH_CUDASPARSE_CHECK(cusparseLtInit(&handle));
@@ -25,6 +66,10 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
   cudaDataType type;
   auto compression_factor = 9;
 
+#ifdef USE_ROCM
+  TORCH_CHECK(isHipSparseLtSupported());
+#endif
+
   switch (sparse_input.scalar_type()) {
     case at::ScalarType::Char:
       type = CUDA_R_8I;
@@ -36,17 +81,19 @@ at::Tensor _cslt_compress(const Tensor& sparse_input) {
     case at::ScalarType::BFloat16:
       type = CUDA_R_16BF;
       break;
+#ifndef USE_ROCM
     case at::ScalarType::Float:
       type = CUDA_R_32F;
       break;
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#endif
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
     case at::ScalarType::Float8_e4m3fn:
       type = CUDA_R_8F_E4M3;
       compression_factor = 10;
       break;
 #endif
     default:
-      TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt compressed matrix");
+      TORCH_CHECK(false, "Unsupported dtype for cuSPARSELt/hipSparseLt compressed matrix");
       break;
   }
 
@@ -120,6 +167,10 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
   cusparseComputeType compute_type;
   auto compression_factor = 9;
 
+#ifdef USE_ROCM
+  TORCH_CHECK(isHipSparseLtSupported());
+#endif
+
   switch (compressed_A.scalar_type()) {
     case at::ScalarType::Char:
       input_type = CUDA_R_8I;
@@ -131,7 +182,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
 
   // cuSPARSELt v0.5.2 onwards changes CUSPARSE_COMPUTE_TF32, CUSPARSE_COMPUT_16F
   // to CUSPARSE_COMPUTE_32F
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 502 || defined(USE_ROCM)
     case at::ScalarType::Half:
       input_type = CUDA_R_16F;
       output_type = CUDA_R_16F;
@@ -144,14 +195,16 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
       C_type = CUDA_R_16BF;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+#ifndef USE_ROCM
     case at::ScalarType::Float:
       input_type = CUDA_R_32F;
       output_type = CUDA_R_32F;
       C_type = CUDA_R_32F;
       compute_type = CUSPARSE_COMPUTE_32F;
       break;
+#endif
     // if cuSPARSELt >= 6.2.3, we can add Float8 support
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
     case at::ScalarType::Float8_e4m3fn:
       input_type = CUDA_R_8F_E4M3;
       output_type = CUDA_R_8F_E4M3;
@@ -214,7 +267,7 @@ std::tuple<at::Tensor, int64_t, int64_t, int64_t, int64_t> _cslt_sparse_mm_impl(
     }
   }
   // cslt 0.6.2+: fp8 fp8 -> {fp8, fp16, bf16, fp32} support
-#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602
+#if defined(CUSPARSELT_VERSION) && CUSPARSELT_VERSION >= 602 && !defined(USE_ROCM)
   else if (input_type == CUDA_R_8F_E4M3) {
     switch (out_dtype) {
       case at::ScalarType::Float8_e4m3fn:
```
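The `isGPUArch({"gfx950", "gfx942"})` gate above lives on the C++ side; when debugging why the `TORCH_CHECK` fires, the same condition can be approximated from Python via the device properties (`is_hipsparselt_arch` is a hypothetical helper, not part of this PR):

```python
import torch

def is_hipsparselt_arch(device_index: int = 0) -> bool:
    """Approximate the C++ arch check for hipSPARSELt support (illustrative only)."""
    if not torch.cuda.is_available() or torch.version.hip is None:
        return False
    # On ROCm, gcnArchName looks like "gfx942:sramecc+:xnack-"; match on the gfx prefix.
    arch = torch.cuda.get_device_properties(device_index).gcnArchName
    return any(arch.startswith(g) for g in ("gfx942", "gfx950"))
```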
#### `caffe2/CMakeLists.txt` (ROCm math libraries)

```diff
@@ -1063,7 +1063,7 @@ if(USE_ROCM)
 
   # Math libraries
   list(APPEND Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS
-    roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsolver roc::hipblaslt)
+    roc::hipblas roc::rocblas hip::hipfft hip::hiprand roc::hipsparse roc::hipsparselt roc::hipsolver roc::hipblaslt)
 
   # ---[ Kernel asserts
   # Kernel asserts is disabled for ROCm by default.
```
#### `cmake/public/LoadHIP.cmake`

```diff
@@ -151,6 +151,7 @@ if(HIP_FOUND)
   find_package_and_print_version(miopen REQUIRED)
   find_package_and_print_version(hipfft REQUIRED)
   find_package_and_print_version(hipsparse REQUIRED)
+  find_package_and_print_version(hipsparselt REQUIRED)
   find_package_and_print_version(rocprim REQUIRED)
   find_package_and_print_version(hipcub REQUIRED)
   find_package_and_print_version(rocthrust REQUIRED)
```
#### `test/test_sparse_semi_structured.py`

```diff
@@ -47,17 +47,18 @@ SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict()
 
 _IS_SM8X = False
 _IS_SM9X = False
+_IS_HIPSPARSELT_AVAILABLE = False
 
 if torch.cuda.is_available():
     _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
     _IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
+    _IS_HIPSPARSELT_AVAILABLE = torch.version.hip is not None and tuple(int(v) for v in torch.version.hip.split('.')[:2]) > (6, 4)
 
     # CUTLASS kernels only work for Ampere
     if _IS_SM8X:
         SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
 
     # add cuSPASRELt tests if available
-    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X):
+    if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X or _IS_HIPSPARSELT_AVAILABLE):
         SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
 
 inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8)
```
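One detail in the gate above: `torch.version.hip` carries a patch component (and sometimes a build suffix), and the comparison keeps only the major/minor pair. A standalone illustration with a hardcoded sample string (not queried from a live build):

```python
hip_version = "6.4.43482"  # sample value in the format torch.version.hip reports
parsed = tuple(int(v) for v in hip_version.split(".")[:2])
print(parsed)            # (6, 4)
print(parsed > (6, 4))   # False: the strict > in the test gate excludes 6.4.x itself
print(parsed >= (6, 4))  # True
```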
The remaining hunks in `test/test_sparse_semi_structured.py` add ROCm skip decorators to tests not yet supported on hipSPARSELt:

```diff
@@ -223,6 +224,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
     @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_mlp_contiguous_relu_compile_cusparselt(self):
         """
         test for cuSPASRELt meta registrations (_cslt_sparse_mm) + torch.compile
@@ -233,6 +235,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 
     @unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_mlp_contiguous_relu_compile_cutlass(self):
         """
         test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
@@ -243,6 +246,7 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
     @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_compile(self) -> None:
         x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
 
@@ -571,6 +575,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_prune_dense_static_sort(self, dtype) -> None:
         # Ideally we would like to clone and compare, but that won't work because the sorting order will be different
         # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@@ -615,6 +620,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
     @training_dtypes
     @parametrize_backends
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
         inp = torch.tensor(
             [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@@ -651,6 +657,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
     @training_dtypes
     @parametrize_backends
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
         M, N = 128, 256
         # Construct x to make sure we always have exactly 8 elements per 4x4 tile
@@ -684,6 +691,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pack_both_ways_id(self, dtype) -> None:
         N = 512
         torch.manual_seed(0)
@@ -718,6 +726,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         ), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_pack_both_ways_edge_case1(self, dtype) -> None:
         # In this case, the heuristic will keep 7 values out of 16
         # instead of 8. let's see how the kernel handles this
@@ -742,6 +751,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         assert packed_t[0, 1].item() == 0
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_apply(self, dtype) -> None:
         M, N = 256, 1024
         x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -757,6 +767,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         torch.testing.assert_close(packed_t, packed_t2)
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_apply_dense(self, dtype) -> None:
         M, N = 256, 1024
         x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -794,6 +805,7 @@ class TestSparseSemiStructuredTraining(TestCase):
 
 
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_matmuls(self, dtype) -> None:
         M, N, K = 64, 256, 1024
         a = torch.randn([M, K], device="cuda", dtype=dtype)
@@ -828,6 +840,7 @@ class TestSparseSemiStructuredTraining(TestCase):
             a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
         )
 
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_matmuls_mat_vec(self) -> None:
         a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
         b = torch.randn([128], device="cuda", dtype=torch.float16)
@@ -837,7 +850,7 @@ class TestSparseSemiStructuredTraining(TestCase):
         with pytest.raises(NotImplementedError):
             torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
 
-
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_sp24_matmuls_bmm(self) -> None:
         a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
         b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
@@ -988,6 +1001,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @inference_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_conversions(self, device, dtype):
 
         def run_test(r, c, device, dtype):
@@ -1016,6 +1030,7 @@ class TestSparseSemiStructuredCUTLASS(TestCase):
 
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     @inference_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_conversions_all_patterns(self, device, dtype):
         r, c = 32, 128
 
@@ -1135,6 +1150,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
 
     @unittest.skip("cuSPARSELt v0.6.x does not support bfloat/float16 alpha scaling")
     @training_dtypes
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_cslt_sparse_mm_alpha(self, dtype, device):
         A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(dtype).cuda()
         B = torch.ones((256, 128), device=device).to(dtype)
@@ -1151,6 +1167,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
 
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_cslt_sparse_mm_alpha_compile_autotune(self, device, out_dtype):
         A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).to(device)
         B = torch.ones((128, 256), device=device, dtype=torch.int8).t()
@@ -1172,6 +1189,7 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         torch.testing.assert_close(sparse_result.cpu(), get_dense_result(), rtol=1e-3, atol=1e-3)
 
     @parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
+    @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
     def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
         A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
         B = torch.ones((128, 256), device=device).to(torch.int8).t()
```
#### `torch/utils/hipify/cuda_to_hip_mappings.py`

```diff
@@ -607,6 +607,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
         ("curand_precalc.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
         ("curand_uniform.h", ("hiprand/hiprand_kernel.h", CONV_INCLUDE, API_RAND)),
         ("cusparse.h", ("hipsparse/hipsparse.h", CONV_INCLUDE, API_RAND)),
+        ("cusparseLt.h", ("hipsparselt/hipsparselt.h", CONV_INCLUDE, API_RAND)),
         ("cufft.h", ("hipfft/hipfft.h", CONV_INCLUDE, API_BLAS)),
         ("cufftXt.h", ("hipfft/hipfftXt.h", CONV_INCLUDE, API_BLAS)),
         # PyTorch also has a source file named "nccl.h", so we need to "<"">" to differentiate
@@ -8256,6 +8257,43 @@ CUDA_SPECIAL_MAP = collections.OrderedDict(
             "CUSPARSE_MATRIX_TYPE_GENERAL",
             ("HIPSPARSE_MATRIX_TYPE_GENERAL", CONV_NUMERIC_LITERAL, API_SPECIAL),
         ),
+        # SparseLt
+        ("cuSPARSELt", ("hipSPARSELt", CONV_TYPE, API_SPECIAL)),
+        ("AT_CUSPARSELT_ENABLED", ("AT_HIPSPARSELT_ENABLED", CONV_TYPE, API_SPECIAL)),
+        ("CUSPARSE_ORDER_ROW", ("HIPSPARSE_ORDER_ROW", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSE_ORDER_COL", ("HIPSPARSE_ORDER_COL", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_SPARSITY_50_PERCENT", ("HIPSPARSELT_SPARSITY_50_PERCENT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("cusparseComputeType", ("hipsparseLtComputetype_t", CONV_TYPE, API_SPECIAL)),
+        ("CUSPARSE_COMPUTE_32F", ("HIPSPARSELT_COMPUTE_32F", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSE_COMPUTE_16F", ("HIPSPARSELT_COMPUTE_16F", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSE_COMPUTE_32I", ("HIPSPARSELT_COMPUTE_32I", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSE_COMPUTE_TF32", ("HIPSPARSELT_COMPUTE_TF32", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_MATMUL_BIAS_POINTER", ("HIPSPARSELT_MATMUL_BIAS_POINTER", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_MATMUL_ALG_DEFAULT", ("HIPSPARSELT_MATMUL_ALG_DEFAULT", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_MATMUL_ALG_CONFIG_ID", ("HIPSPARSELT_MATMUL_ALG_CONFIG_ID", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("CUSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", ("HIPSPARSELT_MATMUL_ALPHA_VECTOR_SCALING", CONV_NUMERIC_LITERAL, API_SPECIAL)),
+        ("cusparseLtHandle_t", ("hipsparseLtHandle_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtMatDescriptor_t", ("hipsparseLtMatDescriptor_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtInit", ("hipsparseLtInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtSpMMACompressedSize2", ("hipsparseLtSpMMACompressedSize2", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtSpMMACompress2", ("hipsparseLtSpMMACompress2", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulDescriptor_t", ("hipsparseLtMatmulDescriptor_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtMatmulPlan_t", ("hipsparseLtMatmulPlan_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtMatmulAlgSelection_t", ("hipsparseLtMatmulAlgSelection_t", CONV_TYPE, API_SPECIAL)),
+        ("cusparseLtStructuredDescriptorInit", ("hipsparseLtStructuredDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtDenseDescriptorInit", ("hipsparseLtDenseDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulDescriptorInit", ("hipsparseLtMatmulDescriptorInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulDescSetAttribute", ("hipsparseLtMatmulDescSetAttribute", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulAlgSelectionInit", ("hipsparseLtMatmulAlgSelectionInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulAlgSetAttribute", ("hipsparseLtMatmulAlgSetAttribute", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulPlanInit", ("hipsparseLtMatmulPlanInit", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulGetWorkspace", ("hipsparseLtMatmulGetWorkspace", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulSearch", ("hipsparseLtMatmulSearch", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulAlgGetAttribute", ("hipsparseLtMatmulAlgGetAttribute", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmul", ("hipsparseLtMatmul", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatDescriptorDestroy", ("hipsparseLtMatDescriptorDestroy", CONV_MATH_FUNC, API_SPECIAL)),
+        ("cusparseLtMatmulPlanDestroy", ("hipsparseLtMatmulPlanDestroy", CONV_MATH_FUNC, API_SPECIAL)),
         # SOLVER
         ("cublasOperation_t", ("hipsolverOperation_t", CONV_TYPE, API_SPECIAL)),
         ("CUBLAS_OP_N", ("HIPSOLVER_OP_N", CONV_NUMERIC_LITERAL, API_SPECIAL)),
```
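These tables feed PyTorch's hipify pass, which is essentially a source-to-source token rewrite over the CUDA sources. A toy sketch of the idea; `SPARSELT_MAP` and `toy_hipify` are simplified stand-ins, while the real engine in `torch/utils/hipify` handles matching and ordering far more carefully:

```python
import re

# A few target pairs in the spirit of the CUDA_SPECIAL_MAP entries above.
SPARSELT_MAP = {
    "cusparseLtInit": "hipsparseLtInit",
    "cusparseLtHandle_t": "hipsparseLtHandle_t",
    "CUSPARSE_COMPUTE_32F": "HIPSPARSELT_COMPUTE_32F",
}

def toy_hipify(source: str) -> str:
    # Longest-first replacement avoids clobbering prefixes of longer tokens.
    for cuda_name in sorted(SPARSELT_MAP, key=len, reverse=True):
        source = re.sub(rf"\b{re.escape(cuda_name)}\b", SPARSELT_MAP[cuda_name], source)
    return source

print(toy_hipify("cusparseLtHandle_t h; cusparseLtInit(&h);"))
# -> hipsparseLtHandle_t h; hipsparseLtInit(&h);
```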