[sparse] add search for optimal alg_id to torch.compile (#137427)

Summary:

This PR adds a lowering for `torch._cslt_sparse_mm` that searches for the
optimal `alg_id` and caches it when running under `torch.compile`.
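
As a rough sketch of the path this targets (the shapes, dtypes, and 2:4 pattern below are illustrative, not taken from this PR):

```python
import torch

# Build a 2:4 semi-structured sparse A: each group of four elements per
# row has exactly two non-zeros. Requires a cuSPARSELt-capable CUDA GPU.
A = torch.tensor([0, 0, 1, 1], dtype=torch.bfloat16, device="cuda").tile((256, 32))
B = torch.ones(128, 128, dtype=torch.bfloat16, device="cuda")

# Compress A into cuSPARSELt's sparse format.
A_compressed = torch._cslt_compress(A)

def fn(A_compressed, B):
    # Passing a transposed view of the dense operand mirrors the tests in this PR.
    return torch._cslt_sparse_mm(A_compressed, B.t())

# With this PR, the inductor lowering searches for the optimal alg_id
# (and split-k parameters) once and caches the result for reuse.
compiled_fn = torch.compile(fn)
out = compiled_fn(A_compressed, B)
```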

We see speedups for both the bfloat16 and float8 dtypes:
<img width="641" alt="Screenshot 2024-10-17 at 2 10 38 PM" src="https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b">
<img width="1274" alt="Screenshot 2024-10-17 at 1 39 03 PM" src="https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6">

* `torch._cslt_sparse_mm_search` has been modified to return the optimal
  split-k parameters as well as the max `alg_id` (see the sketch after this list).

* The max `alg_id` is now exposed in `torch.backends.cusparselt` via
  `torch.backends.cusparselt.get_max_alg_id()`.

* Fixed the meta registrations for float8.
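
A minimal sketch of the updated search API (the `mm_search` signature and the keyword arguments mirror the test changes in the diff below; the input construction is illustrative):

```python
import torch

# Illustrative 2:4-sparse A ([0, 0, 1, 1] per group of four) on a
# cuSPARSELt-capable CUDA GPU.
A = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile((256, 32))
B = torch.ones(128, 128, dtype=torch.float16, device="cuda")
A_compressed = torch._cslt_compress(A)

# The search now returns the optimal split-k parameters alongside alg_id.
alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(
    A_compressed, B.t(), None, None, None, False
)

# The searched alg_id always falls within [0, max_alg_id).
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())

# Reuse the tuned configuration explicitly.
sparse_result = torch._cslt_sparse_mm(
    A_compressed, B.t(),
    alg_id=alg_id, split_k=split_k, split_k_one_kernel=split_k_one_kernel,
)
```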

Test Plan:

python test/test_sparse_semi_structured.py


Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch
commit 39bfba3f56 (parent b4cfb9c014)
Author: Jesse Cai
Date: 2024-10-22 12:02:44 -07:00
Committed by: PyTorch MergeBot

10 changed files with 343 additions and 314 deletions

test/test_sparse_semi_structured.py

@@ -244,18 +244,17 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
     @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
     def test_sp24_compile(self) -> None:
         x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
-        e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)
 
-        def fn(x, e):
+        def fn(x):
             y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
             y = y.t()
             return x @ y
 
         # Eager
-        output = fn(x, e)
+        output = fn(x)
         output.backward(output)
 
         # Torch compile
-        output = torch.compile(fn)(x, e)
+        output = torch.compile(fn)(x)
         output.backward(output)
 
 class TestSparseSemiStructured(TestCase):
@@ -1156,8 +1155,9 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         B = torch.ones((128, 128), device=device).to(dtype)
         A_compressed = torch._cslt_compress(A)
 
-        alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
-        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
+        alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
+        sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
+                                              alg_id=alg_id, split_k=split_k, split_k_one_kernel=split_k_one_kernel)
 
         dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
         dense_result = dense_result.to(dtype)
@@ -1174,6 +1174,16 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
         assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
 
+    @inference_dtypes
+    def test_csrc_cslt_sparse_mm_search(self, device, dtype):
+        A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
+        A_compressed = torch._cslt_compress(A)
+        B = torch.ones((128, 128), device=device).to(dtype)
+
+        A_compressed = torch._cslt_compress(A)
+        alg_id, _, _, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
+        assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
+
     def test_cusparselt_backend(self):
         version = _get_torch_cuda_version()
         assert torch.backends.cusparselt.is_available()
@@ -1181,9 +1191,11 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
         # CUDA 11.8 has cuSPARSELt v0.4.0 support
         if version == (11, 8):
             assert torch.backends.cusparselt.version() == 400
+            assert torch.backends.cusparselt.get_max_alg_id() == 4
         # CUDA 12.1 has cuSPARSELt v0.5.2 support
         elif version == (12, 1):
             assert torch.backends.cusparselt.version() == 502
+            assert torch.backends.cusparselt.get_max_alg_id() == 4
         # CUDA 12.4+ has cuSPARSELt v0.6.2 support
         elif version >= (12, 4):
             assert torch.backends.cusparselt.version() == 602