mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] Add cuSPARSELt as a backend (#128534)
Summary: This PR adds in cuSPARSELt as a backend to PyTorch. It is now possible to see if cuSPARSELt is available and the version if it is with ``` torch.backends.cusparselt.is_available() torch.backends.cusparselt.version() ``` Test Plan: ``` python test/test_sparse_semi_structured.py -k test_cusparselt_backend ``` Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/128534 Approved by: https://github.com/cpuhrsch, https://github.com/eqy, https://github.com/syed-ahmed
This commit is contained in:
committed by
PyTorch MergeBot
parent
0870398fa8
commit
255cd75a97
@ -21,7 +21,7 @@ from torch.sparse._semi_structured_conversions import (
|
||||
)
|
||||
|
||||
from torch.testing import make_tensor
|
||||
|
||||
from torch.testing._internal.common_cuda import _get_torch_cuda_version
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
instantiate_device_type_tests,
|
||||
@ -29,7 +29,6 @@ from torch.testing._internal.common_device_type import (
|
||||
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex
|
||||
import torch._dynamo.test_case
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
parametrize,
|
||||
run_tests,
|
||||
@ -52,13 +51,9 @@ if torch.cuda.is_available():
|
||||
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
|
||||
|
||||
# check if cslt is available for now using this:
|
||||
# TODO when we add cusparselt as a backend, we can update this to be use torch.cusparselt.is_available()
|
||||
try:
|
||||
torch._cslt_compress(torch.ones(128, 256).cuda())
|
||||
# add cuSPASRELt tests if available
|
||||
if torch.backends.cusparselt.is_available():
|
||||
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.float32, torch.int8)
|
||||
training_dtypes = dtypes(torch.float16, torch.bfloat16)
|
||||
@ -1113,6 +1108,22 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
# in cuSPARSELt v0.5.0 there are only 4 alg_ids total, so we should remove the +1 here when we update.
|
||||
assert alg_id in range(CUSPARSELT_NUM_ALG_IDS + 1)
|
||||
|
||||
def test_cusparselt_backend(self):
|
||||
version = _get_torch_cuda_version()
|
||||
assert torch.backends.cusparselt.is_available()
|
||||
|
||||
# CUDA 11.8 has cuSPARSELt v0.4.0 support
|
||||
if version == (11, 8):
|
||||
assert torch.backends.cusparselt.version() == 400
|
||||
# CUDA 12.1+ has cuSPARSELt v0.5.2 support added here: https://github.com/pytorch/builder/pull/1672/files
|
||||
elif version == (12, 1):
|
||||
assert torch.backends.cusparselt.version() == 502
|
||||
elif version > (12, 1):
|
||||
assert torch.backends.cusparselt.version() == 502
|
||||
else:
|
||||
assert torch.backends.cusparselt.version() is None
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestSparseSemiStructured, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestSparseSemiStructuredCUTLASS, globals(), only_for="cuda")
|
||||
instantiate_device_type_tests(TestSparseSemiStructuredCUSPARSELT, globals(), only_for="cuda")
|
||||
|
Reference in New Issue
Block a user