[sparse] Add cuSPARSELt as a backend (#128534)

Summary:

This PR adds cuSPARSELt as a backend to PyTorch.

It is now possible to check whether cuSPARSELt is available and, if so, which version is in use:
```
torch.backends.cusparselt.is_available()
torch.backends.cusparselt.version()
```
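
For example, callers can gate cuSPARSELt-specific code paths on these checks. A minimal sketch (the fallback message is illustrative, not part of this PR):
```
import torch

if torch.backends.cusparselt.is_available():
    # version() returns an integer such as 502 for cuSPARSELt v0.5.2,
    # or None when the bindings cannot be initialized.
    print(f"cuSPARSELt version: {torch.backends.cusparselt.version()}")
else:
    print("cuSPARSELt not available; falling back to the CUTLASS kernels")
```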

Test Plan:
```
python test/test_sparse_semi_structured.py -k test_cusparselt_backend
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128534
Approved by: https://github.com/cpuhrsch, https://github.com/eqy, https://github.com/syed-ahmed
Author: Jesse Cai
Date: 2024-08-21 10:13:14 -07:00
Committed by: PyTorch MergeBot
Parent: 0870398fa8
Commit: 255cd75a97
11 changed files with 118 additions and 8 deletions

View File

@@ -780,6 +780,7 @@ libtorch_python_cuda_core_sources = [
libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
"torch/csrc/cuda/python_nccl.cpp",
"torch/csrc/cuda/shared/cudnn.cpp",
"torch/csrc/cuda/shared/cusparselt.cpp",
"torch/csrc/cuda/Tensor.cpp",
]

View File

@@ -12,6 +12,7 @@ These backends include:
- ``torch.backends.cpu``
- ``torch.backends.cuda``
- ``torch.backends.cudnn``
- ``torch.backends.cusparselt``
- ``torch.backends.mha``
- ``torch.backends.mps``
- ``torch.backends.mkl``
@@ -135,6 +136,13 @@ torch.backends.cudnn
.. py:module:: torch.backends.cudnn.rnn
torch.backends.cusparselt
^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: torch.backends.cusparselt
.. autofunction:: torch.backends.cusparselt.version
.. autofunction:: torch.backends.cusparselt.is_available
torch.backends.mha
^^^^^^^^^^^^^^^^^^

View File

@@ -21,7 +21,7 @@ from torch.sparse._semi_structured_conversions import (
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
@@ -29,7 +29,6 @@ from torch.testing._internal.common_device_type import (
from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
parametrize,
run_tests,
@@ -52,13 +51,9 @@ if torch.cuda.is_available():
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
# check if cslt is available for now using this:
# TODO when we add cusparselt as a backend, we can update this to use torch.cusparselt.is_available()
try:
torch._cslt_compress(torch.ones(128, 256).cuda())
# add cuSPARSELt tests if available
if torch.backends.cusparselt.is_available():
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
except Exception:
pass
inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8)
training_dtypes = dtypes(torch.float16, torch.bfloat16)
@@ -1113,6 +1108,22 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
# in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
def test_cusparselt_backend(self):
version = _get_torch_cuda_version()
assert torch.backends.cusparselt.is_available()
# CUDA 11.8 has cuSPARSELt v0.4.0 support
if version == (11, 8):
assert torch.backends.cusparselt.version() == 400
# CUDA 12.1+ has cuSPARSELt v0.5.2 support added here: https://github.com/pytorch/builder/pull/1672/files
elif version == (12, 1):
assert torch.backends.cusparselt.version() == 502
elif version > (12, 1):
assert torch.backends.cusparselt.version() == 502
else:
assert torch.backends.cusparselt.version() is None
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")

View File

@@ -135,6 +135,11 @@ if(USE_CUDA)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::cudnn)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUDNN)
endif()
if(USE_CUSPARSELT)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::cusparselt)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUSPARSELT)
endif()
if(TARGET torch::nvtx3)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtx3)
else()
@@ -174,6 +179,10 @@ if(USE_CUDNN OR USE_ROCM)
endif()
endif()
if(USE_CUSPARSELT)
list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/cuda/shared/cusparselt.cpp)
endif()
if(USE_MPS)
list(APPEND TORCH_PYTHON_SRCS ${MPS_PYTHON_SRCS})
endif()

View File

@@ -1299,6 +1299,7 @@ _has_magma: _bool
_has_xpu: _bool
_has_mkldnn: _bool
_has_cudnn: _bool
_has_cusparselt: _bool
has_spectral: _bool
_GLIBCXX_USE_CXX11_ABI: _bool
default_generator: Generator

torch/_C/_cusparselt.pyi Normal file
View File

@@ -0,0 +1 @@
def getVersionInt() -> int: ...

View File

@@ -61,6 +61,7 @@ from torch.backends import (
cpu as cpu,
cuda as cuda,
cudnn as cudnn,
cusparselt as cusparselt,
mha as mha,
mkl as mkl,
mkldnn as mkldnn,

View File

@@ -0,0 +1,42 @@
# mypy: allow-untyped-defs
from typing import Optional
import torch
__all__ = [
"version",
"is_available",
]
try:
from torch._C import _cusparselt
except ImportError:
_cusparselt = None # type: ignore[assignment]
__cusparselt_version: Optional[int] = None
if _cusparselt is not None:
def _init():
global __cusparselt_version
if __cusparselt_version is None:
__cusparselt_version = _cusparselt.getVersionInt()
return True
else:
def _init():
return False
def version() -> Optional[int]:
"""Return the version of cuSPARSELt"""
if not _init():
return None
return __cusparselt_version
def is_available() -> bool:
r"""Return a bool indicating if cuSPARSELt is currently available."""
return torch._C._has_cusparselt

View File

@@ -1740,6 +1740,13 @@ PyObject* initModule() {
#endif
ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn));
#if defined(USE_CUSPARSELT)
PyObject* has_cusparselt = Py_True;
#else
PyObject* has_cusparselt = Py_False;
#endif
ASSERT_TRUE(set_module_attr("_has_cusparselt", has_cusparselt));
#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
PyObject* has_spectral = Py_True;
#else

View File

@@ -1968,6 +1968,9 @@ void initGdsBindings(PyObject* module);
#if defined(USE_CUDNN) || defined(USE_ROCM)
void initCudnnBindings(PyObject* module);
#endif
#if defined(USE_CUSPARSELT)
void initCusparseltBindings(PyObject* module);
#endif
} // namespace shared
@@ -1979,6 +1982,9 @@ void initModule(PyObject* module) {
shared::initNvtxBindings(module);
#if defined(USE_CUDNN) || defined(USE_ROCM)
shared::initCudnnBindings(module);
#endif
#if defined(USE_CUSPARSELT)
shared::initCusparseltBindings(module);
#endif
shared::initGdsBindings(module);
registerCudaDeviceProperties(module);

View File

@@ -0,0 +1,23 @@
#include <torch/csrc/utils/pybind.h>
#ifdef USE_CUSPARSELT
#include <cusparseLt.h>
namespace {
size_t getVersionInt() {
return CUSPARSELT_VERSION;
}
} // namespace
namespace torch::cuda::shared {
void initCusparseltBindings(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
cusparselt.def("getVersionInt", getVersionInt);
}
} // namespace torch::cuda::shared
#endif
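
For reference, once PyTorch is built with USE_CUSPARSELT, the submodule registered above is exposed as torch._C._cusparselt, and torch.backends.cusparselt.version() lazily caches the result of getVersionInt(). A minimal sketch of exercising the binding directly (assuming such a build):
```
import torch

# Only query the raw binding when the backend was compiled in;
# version() wraps the same call and caches the integer result.
if torch.backends.cusparselt.is_available():
    raw = torch._C._cusparselt.getVersionInt()  # e.g. 502 for cuSPARSELt v0.5.2
    assert raw == torch.backends.cusparselt.version()
```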