[Sparsity] add support for H100 compute capability 9.x (#121768)

Summary: Extend the semi-structured sparsity test gating to accept compute capability 9.x (H100/SM90) in addition to 8.x (SM8x), so the test suite no longer skips on Hopper GPUs.

Test Plan: buck test mode/opt //caffe2/test/...

Differential Revision: D54792168

@diff-train-skip-merge

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121768
Approved by: https://github.com/SherlockNoMad
Author: Luoshang Pan
Date: 2024-03-20 19:00:52 +00:00
Committed by: PyTorch MergeBot
Parent: d1e8b97387
Commit: 91fdaa1b41
2 changed files with 12 additions and 12 deletions


@@ -41,9 +41,9 @@ CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
 SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
 SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
 
-_IS_SM8X = False
+_IS_SM8X_SM9X = False
 if torch.cuda.is_available():
-    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
+    _IS_SM8X_SM9X = torch.cuda.get_device_capability(0)[0] in {8, 9}
     SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cutlass")
 
     # check if cslt is available for now using this:
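For context, a minimal standalone sketch of the capability gate this hunk changes; the variable name follows the diff, while the surrounding script scaffolding is illustrative only:

    import torch

    # torch.cuda.get_device_capability(device) returns (major, minor),
    # e.g. (8, 0) on A100 (SM80) and (9, 0) on H100 (SM90).
    _IS_SM8X_SM9X = False
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(0)
        # Accept both Ampere (SM8x) and Hopper (SM9x) parts.
        _IS_SM8X_SM9X = major in {8, 9}

    print(f"semi-structured sparsity tests enabled: {_IS_SM8X_SM9X}")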
@@ -146,8 +146,8 @@ def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
 class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 
     def setUp(self):
-        if not _IS_SM8X:
-            self.skipTest('Only runs on SM80')
+        if not _IS_SM8X_SM9X:
+            self.skipTest('Only runs on SM80 and SM90')
         super().setUp()
 
     def tearDown(self):
@@ -219,8 +219,8 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 class TestSparseSemiStructured(TestCase):
 
     def setUp(self):
-        if not _IS_SM8X:
-            self.skipTest('Only runs on SM80')
+        if not _IS_SM8X_SM9X:
+            self.skipTest('Only runs on SM80 and SM90')
 
     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
@@ -632,8 +632,8 @@ class TestCUSPARSELT(TestCase):
     """
 
     def setUp(self):
-        if not _IS_SM8X:
-            self.skipTest('Only runs on SM80')
+        if not _IS_SM8X_SM9X:
+            self.skipTest('Only runs on SM80 and SM90')
         if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
             self.skipTest('cuSPARSELt not enabled')
         else:
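The three hunks above apply the same guard to each test class. A minimal sketch of that skip pattern, using plain unittest rather than PyTorch's internal TestCase; the class and test names here are hypothetical:

    import unittest

    import torch

    _IS_SM8X_SM9X = (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability(0)[0] in {8, 9}
    )

    class ExampleSemiStructuredTest(unittest.TestCase):
        def setUp(self):
            # Skip every test in this class unless running on SM8x/SM9x,
            # mirroring the setUp guards in the diff.
            if not _IS_SM8X_SM9X:
                self.skipTest('Only runs on SM80 and SM90')
            super().setUp()

        def test_placeholder(self):
            # Hypothetical test body; the real tests exercise semi-structured
            # sparse matmul via the cutlass/cuSPARSELt backends.
            self.assertTrue(torch.cuda.is_available())

    if __name__ == "__main__":
        unittest.main()

Skipping in setUp rather than decorating each test keeps the hardware check in one place per class, which is why the PR only needs to touch these three guards.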