Revert "[Sparsity] add support for H100 compute capability 9.x (#121768)"
This reverts commit 91fdaa1b416ab8ac8be30f3c3428751e236657cd. Reverted https://github.com/pytorch/pytorch/pull/121768 on behalf of https://github.com/jeanschmidt due to Agreed on reverting and fixing rocm tests ([comment](https://github.com/pytorch/pytorch/pull/121768#issuecomment-2011893826))
@@ -628,12 +628,12 @@ Tensor _sparse_semi_structured_linear(
                 "_sparse_semi_structured_linear: Setting out_dtype is only "
                 "supported for int8 input and int32 output");
 
-    // For now, only CC 8.x and 9.x devices are supported.
+    // For now, only CC 8.x devices are supported.
     const auto dprops = at::cuda::getCurrentDeviceProperties();
-    const auto is_sm8x_sm9x = dprops->major == 8 || dprops->major == 9;
-    TORCH_CHECK(is_sm8x_sm9x,
+    const auto is_sm8x = dprops->major == 8;
+    TORCH_CHECK(is_sm8x,
                 "_sparse_semi_structured_linear: Supported only on GPUs with "
-                "compute capability 8.x and 9.x");
+                "compute capability 8.x");
 
     // Validate datatypes of input tensors.
     TORCH_CHECK(tensor_a.dtype() == at::kChar ||
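The C++ hunk above restores the compute-capability guard inside _sparse_semi_structured_linear. The same condition is visible from Python through torch.cuda.get_device_capability, which the test file below uses as well. A minimal sketch of the post-revert behavior (the helper name is illustrative, not code from this commit):

    import torch

    def check_semi_structured_support() -> None:
        # Mirrors the restored C++ guard: after this revert, only compute
        # capability 8.x devices pass; 9.x (H100) is rejected again.
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA device required")
        major, _minor = torch.cuda.get_device_capability(0)
        if major != 8:
            raise RuntimeError(
                "_sparse_semi_structured_linear: Supported only on GPUs "
                "with compute capability 8.x")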
@@ -41,9 +41,9 @@ CUSPARSELT_MIXED_DTYPE_SUPPORT = [torch.float16, torch.bfloat16, torch.int32]
 SEMI_STRUCTURED_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16, torch.float32, torch.int8]
 SEMI_STRUCTURED_SUPPORTED_BACKENDS = []
 
-_IS_SM8X_SM9X = False
+_IS_SM8X = False
 if torch.cuda.is_available():
-    _IS_SM8X_SM9X = torch.cuda.get_device_capability(0)[0] in {8, 9}
+    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
     SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cutlass")
 
 # check if cslt is available for now using this:
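The "cutlass" backend gated on _IS_SM8X above is exercised through the public to_sparse_semi_structured API. A minimal usage sketch, assuming an SM 8.x device and a weight already in the 2:4 sparsity pattern:

    import torch
    from torch.sparse import to_sparse_semi_structured

    # fp16 weight where two of every four contiguous elements are zero (2:4).
    w = torch.tensor([[0, 0, 1, 1]]).tile((128, 32)).half().cuda()
    w_sparse = to_sparse_semi_structured(w)  # compressed representation
    x = torch.rand(128, 128).half().cuda()
    y = torch.mm(w_sparse, x)  # dispatches to the semi-structured kernel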
@@ -146,8 +146,8 @@ def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
 class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 
     def setUp(self):
-        if not _IS_SM8X_SM9X:
-            self.skipTest('Only runs on SM80 and SM90')
+        if not _IS_SM8X:
+            self.skipTest('Only runs on SM80')
         super().setUp()
 
     def tearDown(self):
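SparseSemiStructuredTensorCompileTest covers the interaction between the sparse tensor subclass and torch.compile. A hedged sketch of the kind of scenario such a test exercises (names are illustrative, reusing the setup from the sketch above):

    import torch
    from torch.sparse import to_sparse_semi_structured

    def dense_mm(w, x):
        return torch.mm(w, x)

    w_sparse = to_sparse_semi_structured(
        torch.tensor([[0, 0, 1, 1]]).tile((128, 32)).half().cuda())
    x = torch.rand(128, 128).half().cuda()

    compiled_mm = torch.compile(dense_mm)  # dynamo traces the tensor subclass
    assert torch.allclose(compiled_mm(w_sparse, x), dense_mm(w_sparse, x))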
@@ -219,8 +219,8 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
 class TestSparseSemiStructured(TestCase):
 
     def setUp(self):
-        if not _IS_SM8X_SM9X:
-            self.skipTest('Only runs on SM80 and SM90')
+        if not _IS_SM8X:
+            self.skipTest('Only runs on SM80')
 
     @dtypes(*SEMI_STRUCTURED_SUPPORTED_DTYPES)
     @parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
@@ -632,8 +632,8 @@ class TestCUSPARSELT(TestCase):
     """
 
     def setUp(self):
-        if not _IS_SM8X_SM9X:
-            self.skipTest('Only runs on SM80 and SM90')
+        if not _IS_SM8X:
+            self.skipTest('Only runs on SM80')
         if "cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
             self.skipTest('cuSPARSELt not enabled')
         else:
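The truncated comment in the first Python hunk ("check if cslt is available for now using this:") refers to probing for cuSPARSELt before appending it to SEMI_STRUCTURED_SUPPORTED_BACKENDS, which TestCUSPARSELT.setUp then checks. A hedged sketch of such a probe; torch._cslt_compress is a private, unstable op, and its availability here is an assumption:

    import torch

    SEMI_STRUCTURED_SUPPORTED_BACKENDS = ["cutlass"]  # as built up above

    try:
        # Attempt a compression with the private cuSPARSELt op (assumption:
        # present in this build); any failure means the backend is unusable.
        torch._cslt_compress(
            torch.tensor([[0, 0, 1, 1]]).tile((128, 32)).half().cuda())
        SEMI_STRUCTURED_SUPPORTED_BACKENDS.append("cusparselt")
    except Exception:
        pass  # TestCUSPARSELT.setUp will skip with 'cuSPARSELt not enabled'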