mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[sparse] add search for optimal alg_id to torch.compile (#137427)
Summary: This PR adds a lowering for `torch._cslt_sparse_mm` to find the optimal alg_id and cache it when running with `torch.compile`.

Seeing speedups on both bfloat16 and float8 dtypes:

<img width="641" alt="Screenshot 2024-10-17 at 2 10 38 PM" src="https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b">
<img width="1274" alt="Screenshot 2024-10-17 at 1 39 03 PM" src="https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6">

* `torch._cslt_sparse_mm_search` has been modified to return optimal split-k parameters as well as max alg_id.
* max_id is now available in `torch.backends.cusparselt` via `torch.backends.cusparselt.get_max_alg_id()`.
* Fixed meta registrations for float8.

Test Plan:
python test/test_sparse_semi_structured.py

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by
PyTorch MergeBot
parent
b4cfb9c014
commit
39bfba3f56
@ -244,18 +244,17 @@ class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
|
||||
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
|
||||
def test_sp24_compile(self) -> None:
|
||||
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
|
||||
e = torch.eye(x.shape[0], x.shape[0], device="cuda", dtype=torch.float16)
|
||||
|
||||
def fn(x, e):
|
||||
def fn(x):
|
||||
y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
|
||||
y = y.t()
|
||||
return x @ y
|
||||
|
||||
# Eager
|
||||
output = fn(x, e)
|
||||
output = fn(x)
|
||||
output.backward(output)
|
||||
# Torch compile
|
||||
output = torch.compile(fn)(x, e)
|
||||
output = torch.compile(fn)(x)
|
||||
output.backward(output)
|
||||
|
||||
class TestSparseSemiStructured(TestCase):
|
||||
@ -1156,8 +1155,9 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
B = torch.ones((128, 128), device=device).to(dtype)
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(), alg_id=alg_id)
|
||||
alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
|
||||
sparse_result = torch._cslt_sparse_mm(A_compressed, B.t(),
|
||||
alg_id=alg_id, split_k=split_k, split_k_one_kernel=split_k_one_kernel)
|
||||
|
||||
dense_result = torch.mm(A.to(torch.float32), B.to(torch.float32))
|
||||
dense_result = dense_result.to(dtype)
|
||||
@ -1174,6 +1174,16 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
alg_id = torch._cslt_sparse_mm_search(A_compressed, B.t())
|
||||
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
|
||||
|
||||
@inference_dtypes
|
||||
def test_csrc_cslt_sparse_mm_search(self, device, dtype):
|
||||
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
B = torch.ones((128, 128), device=device).to(dtype)
|
||||
|
||||
A_compressed = torch._cslt_compress(A)
|
||||
alg_id, _, _, _ = torch._C._cusparselt.mm_search(A_compressed, B.t(), None, None, None, False)
|
||||
assert alg_id in range(torch.backends.cusparselt.get_max_alg_id())
|
||||
|
||||
def test_cusparselt_backend(self):
|
||||
version = _get_torch_cuda_version()
|
||||
assert torch.backends.cusparselt.is_available()
|
||||
@ -1181,9 +1191,11 @@ class TestSparseSemiStructuredCUSPARSELT(TestCase):
|
||||
# CUDA 11.8 has cuSPARSELt v0.4.0 support
|
||||
if version == (11, 8):
|
||||
assert torch.backends.cusparselt.version() == 400
|
||||
assert torch.backends.cusparselt.get_max_alg_id() == 4
|
||||
# CUDA 12.1 has cuSPARSELt v0.5.2 support
|
||||
elif version == (12, 1):
|
||||
assert torch.backends.cusparselt.version() == 502
|
||||
assert torch.backends.cusparselt.get_max_alg_id() == 4
|
||||
# CUDA 12.4+ has cuSPARSELt v0.6.2 support
|
||||
elif version >= (12, 4):
|
||||
assert torch.backends.cusparselt.version() == 602
|
||||
|
||||
Reference in New Issue
Block a user