[sparse] Add cuSPARSELt as a backend (#128534)

Summary:

This PR adds in cuSPARSELt as a backend to PyTorch.

It is now possible to see if cuSPARSELt is available and the version if
it is with
```
torch.backends.cusparselt.is_available()
torch.backends.cusparselt.version()
```

Test Plan:
```
python test/test_sparse_semi_structured.py -k test_cusparselt_backend
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128534
Approved by: https://github.com/cpuhrsch, https://github.com/eqy, https://github.com/syed-ahmed
This commit is contained in:
Jesse Cai
2024-08-21 10:13:14 -07:00
committed by PyTorch MergeBot
parent 0870398fa8
commit 255cd75a97
11 changed files with 118 additions and 8 deletions

View File

@ -780,6 +780,7 @@ libtorch_python_cuda_core_sources = [
libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [ libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
"torch/csrc/cuda/python_nccl.cpp", "torch/csrc/cuda/python_nccl.cpp",
"torch/csrc/cuda/shared/cudnn.cpp", "torch/csrc/cuda/shared/cudnn.cpp",
"torch/csrc/cuda/shared/cusparselt.cpp",
"torch/csrc/cuda/Tensor.cpp", "torch/csrc/cuda/Tensor.cpp",
] ]

View File

@ -12,6 +12,7 @@ These backends include:
- ``torch.backends.cpu`` - ``torch.backends.cpu``
- ``torch.backends.cuda`` - ``torch.backends.cuda``
- ``torch.backends.cudnn`` - ``torch.backends.cudnn``
- ``torch.backends.cusparselt``
- ``torch.backends.mha`` - ``torch.backends.mha``
- ``torch.backends.mps`` - ``torch.backends.mps``
- ``torch.backends.mkl`` - ``torch.backends.mkl``
@ -135,6 +136,13 @@ torch.backends.cudnn
.. py:module:: torch.backends.cudnn.rnn .. py:module:: torch.backends.cudnn.rnn
torch.backends.cusparselt
^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: torch.backends.cusparselt
.. autofunction:: torch.backends.cusparselt.version
.. autofunction:: torch.backends.cusparselt.is_available
torch.backends.mha torch.backends.mha
^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^

View File

@ -21,7 +21,7 @@ from torch.sparse._semi_structured_conversions import (
) )
from torch.testing import make_tensor from torch.testing import make_tensor
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_device_type import (
dtypes, dtypes,
instantiate_device_type_tests, instantiate_device_type_tests,
@ -29,7 +29,6 @@ from torch.testing._internal.common_device_type import (
from torch.testing._internal.common_dtype import all_types_and_complex from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case import torch._dynamo.test_case
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
parametrize, parametrize,
run_tests, run_tests,
@ -52,13 +51,9 @@ if torch.cuda.is_available():
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8 _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
# check if cslt is available for now using this: # add cuSPASRELt tests if available
# TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available() if torch.backends.cusparselt.is_available():
try:
torch._cslt_compress(torch.ones(128, 256).cuda())
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
except Exception:
pass
inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8) inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8)
training_dtypes = dtypes(torch.float16, torch.bfloat16) training_dtypes = dtypes(torch.float16, torch.bfloat16)
@ -1113,6 +1108,22 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
# in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update. # in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1) assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
def test_cusparselt_backend(self):
version = _get_torch_cuda_version()
assert torch.backends.cusparselt.is_available()
# CUDA 11.8 has cuSPARSELt v0.4.0 support
if version == (11, 8):
assert torch.backends.cusparselt.version() == 400
# CUDA 12.1+ has cuSPARSELt v0.5.2 support added here: https://github.com/pytorch/builder/pull/1672/files
elif version == (12, 1):
assert torch.backends.cusparselt.version() == 502
elif version > (12, 1):
assert torch.backends.cusparselt.version() == 502
else:
assert torch.backends.cusparselt.version() is None
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda") instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda") instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda") instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")

View File

@ -135,6 +135,11 @@ if(USE_CUDA)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::cudnn) list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::cudnn)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUDNN) list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUDNN)
endif() endif()
if(USE_CUSPARSELT)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::cusparselt)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUSPARSELT)
endif()
if(TARGET torch::nvtx3) if(TARGET torch::nvtx3)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtx3) list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtx3)
else() else()
@ -174,6 +179,10 @@ if(USE_CUDNN OR USE_ROCM)
endif() endif()
endif() endif()
if(USE_CUSPARSELT)
list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/cuda/shared/cusparselt.cpp)
endif()
if(USE_MPS) if(USE_MPS)
list(APPEND TORCH_PYTHON_SRCS ${MPS_PYTHON_SRCS}) list(APPEND TORCH_PYTHON_SRCS ${MPS_PYTHON_SRCS})
endif() endif()

View File

@ -1299,6 +1299,7 @@ _has_magma: _bool
_has_xpu: _bool _has_xpu: _bool
_has_mkldnn: _bool _has_mkldnn: _bool
_has_cudnn: _bool _has_cudnn: _bool
_has_cusparselt: _bool
has_spectral: _bool has_spectral: _bool
_GLIBCXX_USE_CXX11_ABI: _bool _GLIBCXX_USE_CXX11_ABI: _bool
default_generator: Generator default_generator: Generator

1
torch/_C/_cusparselt.pyi Normal file
View File

@ -0,0 +1 @@
def getVersionInt() -> int: ...

View File

@ -61,6 +61,7 @@ from torch.backends import (
cpu as cpu, cpu as cpu,
cuda as cuda, cuda as cuda,
cudnn as cudnn, cudnn as cudnn,
cusparselt as cusparselt,
mha as mha, mha as mha,
mkl as mkl, mkl as mkl,
mkldnn as mkldnn, mkldnn as mkldnn,

View File

@ -0,0 +1,42 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
__all__ = [
"version",
"is_available",
]
try:
from torch._C import _cusparselt
except ImportError:
_cusparselt = None # type: ignore[assignment]
__cusparselt_version: Optional[int] = None
if _cusparselt is not None:
def _init():
global __cusparselt_version
if __cusparselt_version is None:
__cusparselt_version = _cusparselt.getVersionInt()
return True
else:
def _init():
return False
def version() -> Optional[int]:
"""Return the version of cuSPARSELt"""
if not _init():
return None
return __cusparselt_version
def is_available() -> bool:
r"""Return a bool indicating if cuSPARSELt is currently available."""
return torch._C._has_cusparselt

View File

@ -1740,6 +1740,13 @@ PyObject* initModule() {
#endif #endif
ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn)); ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn));
#if defined(USE_CUSPARSELT)
PyObject* has_cusparselt = Py_True;
#else
PyObject* has_cusparselt = Py_False;
#endif
ASSERT_TRUE(set_module_attr("_has_cusparselt", has_cusparselt));
#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED() #if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
PyObject* has_spectral = Py_True; PyObject* has_spectral = Py_True;
#else #else

View File

@ -1968,6 +1968,9 @@ void initGdsBindings(PyObject* module);
#if defined(USE_CUDNN) || defined(USE_ROCM) #if defined(USE_CUDNN) || defined(USE_ROCM)
void initCudnnBindings(PyObject* module); void initCudnnBindings(PyObject* module);
#endif #endif
#if defined(USE_CUSPARSELT)
void initCusparseltBindings(PyObject* module);
#endif
} // namespace shared } // namespace shared
@ -1979,6 +1982,9 @@ void initModule(PyObject* module) {
shared::initNvtxBindings(module); shared::initNvtxBindings(module);
#if defined(USE_CUDNN) || defined(USE_ROCM) #if defined(USE_CUDNN) || defined(USE_ROCM)
shared::initCudnnBindings(module); shared::initCudnnBindings(module);
#endif
#if defined(USE_CUSPARSELT)
shared::initCusparseltBindings(module);
#endif #endif
shared::initGdsBindings(module); shared::initGdsBindings(module);
registerCudaDeviceProperties(module); registerCudaDeviceProperties(module);

View File

@ -0,0 +1,23 @@
#include <torch/csrc/utils/pybind.h>
#ifdef USE_CUSPARSELT
#include <cusparseLt.h>
namespace {
size_t getVersionInt() {
return CUSPARSELT_VERSION;
}
} // namespace
namespace torch::cuda::shared {
void initCusparseltBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
cusparselt.def("getVersionInt", getVersionInt);
}
} // namespace torch::cuda::shared
#endif