mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This PR adds a lowering for `torch._cslt_sparse_mm` to find the optimal alg_id and cache it when running with `torch.compile` Seeing speedups on both bfloat16 and float8 dtypes: <img width="641" alt="Screenshot 2024-10-17 at 2 10 38 PM" src="https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b"> <img width="1274" alt="Screenshot 2024-10-17 at 1 39 03 PM" src="https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6"> * `torch._cslt_sparse_mm_search` has been modified to return optimal split-k parameters as well as max alg_id. * max_id is now available in `torch.backends.cusparselt` via `torch.backends.cusparselt.get_max_alg_id()` * fixed meta registrations for float8 Test Plan: python test/test_sparse_semi_structured.py Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427 Approved by: https://github.com/cpuhrsch