[sparse] Add cuSPARSELt as a backend (#128534)

Summary:

This PR adds in cuSPARSELt as a backend to PyTorch.

It is now possible to see if cuSPARSELt is available and the version if
it is with
```
torch.backends.cusparselt.is_available()
torch.backends.cusparselt.version()
```

Test Plan:
```
python test/test_sparse_semi_structured.py -k test_cusparselt_backend
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128534
Approved by: https://github.com/cpuhrsch, https://github.com/eqy, https://github.com/syed-ahmed
This commit is contained in:
Jesse Cai
2024-08-21 10:13:14 -07:00
committed by PyTorch MergeBot
parent 0870398fa8
commit 255cd75a97
11 changed files with 118 additions and 8 deletions

View File

@ -21,7 +21,7 @@ from torch.sparse._semi_structured_conversions import (
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
@ -29,7 +29,6 @@ from torch.testing._internal.common_device_type import (
from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
parametrize,
run_tests,
@ -52,13 +51,9 @@ if torch.cuda.is_available():
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
# check if cslt is available for now using this:
# TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available()
try:
torch._cslt_compress(torch.ones(128, 256).cuda())
# add cuSPASRELt tests if available
if torch.backends.cusparselt.is_available():
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
except Exception:
pass
inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8)
training_dtypes = dtypes(torch.float16, torch.bfloat16)
@ -1113,6 +1108,22 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
# in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
def test_cusparselt_backend(self):
version = _get_torch_cuda_version()
assert torch.backends.cusparselt.is_available()
# CUDA 11.8 has cuSPARSELt v0.4.0 support
if version == (11, 8):
assert torch.backends.cusparselt.version() == 400
# CUDA 12.1+ has cuSPARSELt v0.5.2 support added here: https://github.com/pytorch/builder/pull/1672/files
elif version == (12, 1):
assert torch.backends.cusparselt.version() == 502
elif version > (12, 1):
assert torch.backends.cusparselt.version() == 502
else:
assert torch.backends.cusparselt.version() is None
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")