Files
pytorch/torch/sparse
Jesse Cai 39bfba3f56 [sparse] add search for optimal alg_id to torch.compile (#137427)
Summary:

This PR adds a lowering for `torch._cslt_sparse_mm` to find the optimal
alg_id and cache it when running with `torch.compile`

Seeing speedups on both bfloat16 and float8 dtypes:
<img width="641" alt="Screenshot 2024-10-17 at 2 10 38 PM" src="https://github.com/user-attachments/assets/b928cd11-32a3-43e5-b209-8e4028896f0b">
<img width="1274" alt="Screenshot 2024-10-17 at 1 39 03 PM" src="https://github.com/user-attachments/assets/d9edd684-a8ec-46fd-b3da-2e76dbcb7bb6">

* `torch._cslt_sparse_mm_search` has been modified to return optimal
  split-k parameters as well as max alg_id.

* max_id is now available in `torch.backends.cusparselt` via
  `torch.backends.cusparselt.get_max_alg_id()`

* fixed meta registrations for float8

Test Plan:

python test/test_sparse_semi_structured.py

Reviewers:

Subscribers:

Tasks:

Tags:

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137427
Approved by: https://github.com/cpuhrsch
2024-10-22 22:39:42 +00:00
..