mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] Add cuSPARSELt as a backend (#128534)
Summary: This PR adds cuSPARSELt as a backend to PyTorch. It is now possible to check whether cuSPARSELt is available and, if it is, to query its version with ``` torch.backends.cusparselt.is_available() torch.backends.cusparselt.version() ``` Test Plan: ``` python test/test_sparse_semi_structured.py -k test_cusparselt_backend ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/128534 Approved by: https://github.com/cpuhrsch, https://github.com/eqy, https://github.com/syed-ahmed
This commit is contained in:
committed by
PyTorch MergeBot
parent
0870398fa8
commit
255cd75a97
@ -780,6 +780,7 @@ libtorch_python_cuda_core_sources = [
|
|||||||
libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
|
libtorch_python_cuda_sources = libtorch_python_cuda_core_sources + [
|
||||||
"torch/csrc/cuda/python_nccl.cpp",
|
"torch/csrc/cuda/python_nccl.cpp",
|
||||||
"torch/csrc/cuda/shared/cudnn.cpp",
|
"torch/csrc/cuda/shared/cudnn.cpp",
|
||||||
|
"torch/csrc/cuda/shared/cusparselt.cpp",
|
||||||
"torch/csrc/cuda/Tensor.cpp",
|
"torch/csrc/cuda/Tensor.cpp",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ These backends include:
|
|||||||
- ``torch.backends.cpu``
|
- ``torch.backends.cpu``
|
||||||
- ``torch.backends.cuda``
|
- ``torch.backends.cuda``
|
||||||
- ``torch.backends.cudnn``
|
- ``torch.backends.cudnn``
|
||||||
|
- ``torch.backends.cusparselt``
|
||||||
- ``torch.backends.mha``
|
- ``torch.backends.mha``
|
||||||
- ``torch.backends.mps``
|
- ``torch.backends.mps``
|
||||||
- ``torch.backends.mkl``
|
- ``torch.backends.mkl``
|
||||||
@ -135,6 +136,13 @@ torch.backends.cudnn
|
|||||||
|
|
||||||
.. py:module:: torch.backends.cudnn.rnn
|
.. py:module:: torch.backends.cudnn.rnn
|
||||||
|
|
||||||
|
torch.backends.cusparselt
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
.. automodule:: torch.backends.cusparselt
|
||||||
|
|
||||||
|
.. autofunction:: torch.backends.cusparselt.version
|
||||||
|
|
||||||
|
.. autofunction:: torch.backends.cusparselt.is_available
|
||||||
|
|
||||||
torch.backends.mha
|
torch.backends.mha
|
||||||
^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
@ -21,7 +21,7 @@ from torch.sparse._semi_structured_conversions import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from torch.testing import make_tensor
|
from torch.testing import make_tensor
|
||||||
|
from torch.testing._internal.common_cuda import _get_torch_cuda_version
|
||||||
from torch.testing._internal.common_device_type import (
|
from torch.testing._internal.common_device_type import (
|
||||||
dtypes,
|
dtypes,
|
||||||
instantiate_device_type_tests,
|
instantiate_device_type_tests,
|
||||||
@ -29,7 +29,6 @@ from torch.testing._internal.common_device_type import (
|
|||||||
|
|
||||||
from torch.testing._internal.common_dtype import all_types_and_complex
|
from torch.testing._internal.common_dtype import all_types_and_complex
|
||||||
import torch._dynamo.test_case
|
import torch._dynamo.test_case
|
||||||
|
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
parametrize,
|
parametrize,
|
||||||
run_tests,
|
run_tests,
|
||||||
@ -52,13 +51,9 @@ if torch.cuda.is_available():
|
|||||||
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
|
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
|
||||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
|
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
|
||||||
|
|
||||||
# check if cslt is available for now using this:
|
# add cuSPARSELt tests if available
|
||||||
# TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available()
|
if torch.backends.cusparselt.is_available():
|
||||||
try:
|
|
||||||
torch._cslt_compress(torch.ones(128, 256).cuda())
|
|
||||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
|
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8)
|
inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8)
|
||||||
training_dtypes = dtypes(torch.float16, torch.bfloat16)
|
training_dtypes = dtypes(torch.float16, torch.bfloat16)
|
||||||
@ -1113,6 +1108,22 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
|||||||
# in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
|
# in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
|
||||||
assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
|
assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
|
||||||
|
|
||||||
|
def test_cusparselt_backend(self):
    """Check that the cuSPARSELt backend reports the version expected for this CUDA build."""
    cuda_version = _get_torch_cuda_version()
    assert torch.backends.cusparselt.is_available()

    # CUDA 11.8 has cuSPARSELt v0.4.0 support
    if cuda_version == (11, 8):
        assert torch.backends.cusparselt.version() == 400
    # CUDA 12.1+ has cuSPARSELt v0.5.2 support added here: https://github.com/pytorch/builder/pull/1672/files
    # (12.1 itself and anything newer expect the same version, so one branch covers both)
    elif cuda_version >= (12, 1):
        assert torch.backends.cusparselt.version() == 502
    else:
        assert torch.backends.cusparselt.version() is None
|
||||||
|
|
||||||
|
|
||||||
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
|
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
|
||||||
instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
|
instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
|
||||||
instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")
|
instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")
|
||||||
|
@ -135,6 +135,11 @@ if(USE_CUDA)
|
|||||||
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::cudnn)
|
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::cudnn)
|
||||||
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUDNN)
|
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUDNN)
|
||||||
endif()
|
endif()
|
||||||
|
if(USE_CUSPARSELT)
|
||||||
|
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::cusparselt)
|
||||||
|
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_CUSPARSELT)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(TARGET torch::nvtx3)
|
if(TARGET torch::nvtx3)
|
||||||
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtx3)
|
list(APPEND TORCH_PYTHON_LINK_LIBRARIES torch::nvtx3)
|
||||||
else()
|
else()
|
||||||
@ -174,6 +179,10 @@ if(USE_CUDNN OR USE_ROCM)
|
|||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(USE_CUSPARSELT)
|
||||||
|
list(APPEND TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/cuda/shared/cusparselt.cpp)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(USE_MPS)
|
if(USE_MPS)
|
||||||
list(APPEND TORCH_PYTHON_SRCS ${MPS_PYTHON_SRCS})
|
list(APPEND TORCH_PYTHON_SRCS ${MPS_PYTHON_SRCS})
|
||||||
endif()
|
endif()
|
||||||
|
@ -1299,6 +1299,7 @@ _has_magma: _bool
|
|||||||
_has_xpu: _bool
|
_has_xpu: _bool
|
||||||
_has_mkldnn: _bool
|
_has_mkldnn: _bool
|
||||||
_has_cudnn: _bool
|
_has_cudnn: _bool
|
||||||
|
_has_cusparselt: _bool
|
||||||
has_spectral: _bool
|
has_spectral: _bool
|
||||||
_GLIBCXX_USE_CXX11_ABI: _bool
|
_GLIBCXX_USE_CXX11_ABI: _bool
|
||||||
default_generator: Generator
|
default_generator: Generator
|
||||||
|
1
torch/_C/_cusparselt.pyi
Normal file
1
torch/_C/_cusparselt.pyi
Normal file
@ -0,0 +1 @@
|
|||||||
|
def getVersionInt() -> int: ...
|
@ -61,6 +61,7 @@ from torch.backends import (
|
|||||||
cpu as cpu,
|
cpu as cpu,
|
||||||
cuda as cuda,
|
cuda as cuda,
|
||||||
cudnn as cudnn,
|
cudnn as cudnn,
|
||||||
|
cusparselt as cusparselt,
|
||||||
mha as mha,
|
mha as mha,
|
||||||
mkl as mkl,
|
mkl as mkl,
|
||||||
mkldnn as mkldnn,
|
mkldnn as mkldnn,
|
||||||
|
42
torch/backends/cusparselt/__init__.py
Normal file
42
torch/backends/cusparselt/__init__.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
# mypy: allow-untyped-defs
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"version",
|
||||||
|
"is_available",
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
    from torch._C import _cusparselt
except ImportError:
    # The `_cusparselt` extension submodule could not be imported; fall back to
    # a sentinel so the rest of this module can still be imported.
    _cusparselt = None  # type: ignore[assignment]

# Cached cuSPARSELt version; populated by the first successful `_init()` call.
__cusparselt_version: Optional[int] = None

if _cusparselt is not None:

    def _init():
        """Cache the cuSPARSELt version on first use and return True."""
        global __cusparselt_version
        if __cusparselt_version is None:
            __cusparselt_version = _cusparselt.getVersionInt()
        return True

else:

    def _init():
        """Return False because the `_cusparselt` bindings are not importable."""
        return False


def version() -> Optional[int]:
    """Return the version of cuSPARSELt

    Returns ``None`` when the ``_cusparselt`` bindings could not be imported.
    """
    if not _init():
        return None
    return __cusparselt_version
|
||||||
|
|
||||||
|
|
||||||
|
def is_available() -> bool:
|
||||||
|
r"""Return a bool indicating if cuSPARSELt is currently available."""
|
||||||
|
return torch._C._has_cusparselt
|
@ -1740,6 +1740,13 @@ PyObject* initModule() {
|
|||||||
#endif
|
#endif
|
||||||
ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn));
|
ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn));
|
||||||
|
|
||||||
|
#if defined(USE_CUSPARSELT)
|
||||||
|
PyObject* has_cusparselt = Py_True;
|
||||||
|
#else
|
||||||
|
PyObject* has_cusparselt = Py_False;
|
||||||
|
#endif
|
||||||
|
ASSERT_TRUE(set_module_attr("_has_cusparselt", has_cusparselt));
|
||||||
|
|
||||||
#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
|
#if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
|
||||||
PyObject* has_spectral = Py_True;
|
PyObject* has_spectral = Py_True;
|
||||||
#else
|
#else
|
||||||
|
@ -1968,6 +1968,9 @@ void initGdsBindings(PyObject* module);
|
|||||||
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
||||||
void initCudnnBindings(PyObject* module);
|
void initCudnnBindings(PyObject* module);
|
||||||
#endif
|
#endif
|
||||||
|
#if defined(USE_CUSPARSELT)
|
||||||
|
void initCusparseltBindings(PyObject* module);
|
||||||
|
#endif
|
||||||
|
|
||||||
} // namespace shared
|
} // namespace shared
|
||||||
|
|
||||||
@ -1979,6 +1982,9 @@ void initModule(PyObject* module) {
|
|||||||
shared::initNvtxBindings(module);
|
shared::initNvtxBindings(module);
|
||||||
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
#if defined(USE_CUDNN) || defined(USE_ROCM)
|
||||||
shared::initCudnnBindings(module);
|
shared::initCudnnBindings(module);
|
||||||
|
#endif
|
||||||
|
#if defined(USE_CUSPARSELT)
|
||||||
|
shared::initCusparseltBindings(module);
|
||||||
#endif
|
#endif
|
||||||
shared::initGdsBindings(module);
|
shared::initGdsBindings(module);
|
||||||
registerCudaDeviceProperties(module);
|
registerCudaDeviceProperties(module);
|
||||||
|
23
torch/csrc/cuda/shared/cusparselt.cpp
Normal file
23
torch/csrc/cuda/shared/cusparselt.cpp
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
|
#ifdef USE_CUSPARSELT
|
||||||
|
#include <cusparseLt.h>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
size_t getVersionInt() {
|
||||||
|
return CUSPARSELT_VERSION;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace torch::cuda::shared {
|
||||||
|
|
||||||
|
void initCusparseltBindings(PyObject* module) {
|
||||||
|
auto m = py::handle(module).cast<py::module>();
|
||||||
|
auto cusparselt = m.def_submodule("_cusparselt", "libcusparselt.so bindings");
|
||||||
|
cusparselt.def("getVersionInt", getVersionInt);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace torch::cuda::shared
|
||||||
|
#endif
|
Reference in New Issue
Block a user