diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu
index e997f49f3f43..35ae7dc102e7 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu
@@ -628,12 +628,12 @@ Tensor _sparse_semi_structured_linear(
                 "_sparse_semi_structured_linear: Setting out_dtype is only "
                 "supported for int8 input and int32 output");
 
-    // For now, only CC 8.x devices are supported.
+    // For now, only CC 8.x and 9.x devices are supported.
     const auto dprops = at::cuda::getCurrentDeviceProperties();
-    const auto is_sm8x = dprops->major == 8;
-    TORCH_CHECK(is_sm8x,
+    const auto is_sm8x_sm9x = dprops->major == 8 || dprops->major == 9;
+    TORCH_CHECK(is_sm8x_sm9x,
                 "_sparse_semi_structured_linear: Supported only on GPUs with "
-                "compute capability 8.x");
+                "compute capability 8.x and 9.x");
 
     // Validate datatypes of input tensors.
     TORCH_CHECK(tensor_a.dtype() == at::kChar ||
diff --git a/test/test_sparse_semi_structured.py b/test/test_sparse_semi_structured.py
index aa9e0a97e168..6426273820dd 100644
--- a/test/test_sparse_semi_structured.py
+++ b/test/test_sparse_semi_structured.py
@@ -41,9 +41,9 @@ CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
 SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
 SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
 
-_IS_SM8X = False
+_IS_SM8X_SM9X = False
 if torch.cuda.is_available():
-    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
+    _IS_SM8X_SM9X = torch.cuda.get_device_capability(0)[0] in {8, 9}
     SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cutlass")
 
     # check if cslt is available for now using this:
@@ -146,8 +146,8 @@ def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
 class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 
     def setUp(self):
-        if not _IS_SM8X:
-            self.skipTest('Only runs on SM80')
+        if not _IS_SM8X_SM9X:
+            self.skipTest('Only runs on SM80 and SM90')
         super().setUp()
 
     def tearDown(self):
@@ -219,8 +219,8 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 class TestSparseSemiStructured(TestCase):
 
     def setUp(self):
-        if not _IS_SM8X:
-            self.skipTest('Only runs on SM80')
+        if not _IS_SM8X_SM9X:
+            self.skipTest('Only runs on SM80 and SM90')
 
     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
@@ -632,8 +632,8 @@ class TestCUSPARSELT(TestCase):
     """
 
     def setUp(self):
-        if not _IS_SM8X:
-            self.skipTest('Only runs on SM80')
+        if not _IS_SM8X_SM9X:
+            self.skipTest('Only runs on SM80 and SM90')
         if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
             self.skipTest('cuSPARSELt not enabled')
         else: