mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] add extra options to _cslt_spare_mm (#137427)
Summary: Splitting this PR into two, one for the cuSPARSELt improvements, and one for the inductor lowering. This PR adds in the additional cuSPARSELt bindings into pytorch. * `torch._cslt_sparse_mm_search` will be deprecated in a future PR, so a warning has been added * Added a header file for cuSPARSELtOps.cpp * max_id is now available in `torch.backends.cusparselt` via `torch.backends.cusparselt.get_max_alg_id()` * fixed meta registrations for float8 Test Plan: python test/test_sparse_semi_structured.py Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427 Approved by: https://github.com/cpuhrsch, https://github.com/eqy
This commit is contained in:
committed by
PyTorch MergeBot
parent
02990fe36b
commit
f1451163ec
@ -244,18 +244,17 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
|
||||
def test_sp24_compile(self) -> None:
|
||||
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
|
||||
e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)
|
||||
|
||||
def fn(x, e):
|
||||
def fn(x):
|
||||
y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
|
||||
y = y.t()
|
||||
return x @ y
|
||||
|
||||
# Eager
|
||||
output = fn(x, e)
|
||||
output = fn(x)
|
||||
output.backward(output)
|
||||
# Torch compile
|
||||
output = torch.compile(fn)(x, e)
|
||||
output = torch.compile(fn)(x)
|
||||
output.backward(output)
|
||||
|
||||
class TestSparseSemiStructured(TestCase):
|
||||
@ -1133,6 +1132,21 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
|
||||
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_cslt_sparse_mm_alpha_compile_autotune(self, device):
|
||||
A = torch.Tensor([0, 0, 1, 1]).tile((128, 64)).to(torch.int8).cuda()
|
||||
B = torch.ones((128, 256), device=device).to(torch.int8).t()
|
||||
alpha = torch.Tensor([2**(-i) for i in range(128)]).cuda()
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
compiled_sparse_mm = torch.compile(torch._cslt_sparse_mm, mode="max-autotune")
|
||||
sparse_result = compiled_sparse_mm(A_compressed, B, alpha=alpha, out_dtype=torch.int32)
|
||||
|
||||
alpha_scaled = torch.stack([alpha] * 128).t().cpu().float()
|
||||
dense_result = alpha_scaled * torch.mm(A.to(torch.int64).cpu(), B.to(torch.int64).cpu())
|
||||
dense_result = dense_result.to(torch.int32)
|
||||
|
||||
torch.testing.assert_close(sparse_result.cpu(), dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@parametrize("out_dtype", [torch.float16, torch.bfloat16, torch.int32])
|
||||
def test_cslt_sparse_mm_alpha_mixed_dtype(self, out_dtype, device):
|
||||
A = torch.Tensor([0, 0, 10, 10]).tile((128, 64)).to(torch.int8).cuda()
|
||||
@ -1149,21 +1163,6 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
|
||||
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@inference_dtypes
|
||||
def test_cslt_sparse_mm_alg_id(self, device, dtype):
|
||||
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype)
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
B = torch.ones((128, 128), device=device).to(dtype)
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
|
||||
|
||||
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
|
||||
dense_result = dense_result.to(dtype)
|
||||
|
||||
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@inference_dtypes
|
||||
def test_cslt_sparse_mm_search(self, device, dtype):
|
||||
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
|
||||
@ -1172,7 +1171,26 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
|
||||
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
|
||||
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
|
||||
dense_result = dense_result.to(dtype)
|
||||
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
@inference_dtypes
|
||||
def test_csrc_cslt_sparse_mm_search(self, device, dtype):
|
||||
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
B = torch.ones((128, 128), device=device).to(dtype)
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
|
||||
alg_id=alg_id,
|
||||
split_k=split_k,
|
||||
split_k_one_kernel=split_k_one_kernel)
|
||||
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
|
||||
dense_result = dense_result.to(dtype)
|
||||
torch.testing.assert_close(sparse_result, dense_result, rtol=1e-3, atol=1e-3)
|
||||
|
||||
def test_cusparselt_backend(self):
|
||||
version = _get_torch_cuda_version()
|
||||
|
Reference in New Issue
Block a user