Summary:
This PR adds `torch.float8_e4m3fn` support to cuSPARSELt and `to_sparse_semi_structured`. This lets users run fp8 + 2:4 sparse matmuls on Hopper GPUs with cuSPARSELt >= 0.6.2 via the `torch._scaled_mm` API:

```
import torch
from torch.sparse import to_sparse_semi_structured

# rand_sparse_semi_structured_mask and to_float8 are test helpers, and
# device, dense_input_shape and out_dtype are test parameters (not shown here).
A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
# .t() gives B the column-major layout that _scaled_mm expects for its second operand
B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()
A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)

# Dense fp8 reference
dense_result = torch._scaled_mm(
    A_fp8, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
)

# Same matmul with A compressed to the 2:4 semi-structured format
A_fp8_sparse = to_sparse_semi_structured(A_fp8)
sparse_result = torch._scaled_mm(
    A_fp8_sparse, B_fp8, scale_a=A_scale, scale_b=B_scale, out_dtype=out_dtype
)
```

Note that, to keep this consistent with normal torch behavior (fp8 matmuls go through `torch._scaled_mm`, not `torch.mm`), calling `torch.mm(A_fp8_sparse, B_fp8)` raises a `NotImplementedError`.

I also turned on cuSPARSELt by default and added `CUSPARSELT_MAX_ID` to the backend to make the tests a bit cleaner.

Test Plan:
```
python test/test_sparse_semi_structured.py -k scaled_mm
python test/test_sparse_semi_structured.py -k fp8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136397
Approved by: https://github.com/drisspg
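For reference, the 2:4 ("semi-structured") pattern that `to_sparse_semi_structured` accepts means that every contiguous group of four elements along a row has at most two nonzero entries. The pure-Python sketch below illustrates that constraint. `is_2_4_sparse` and `prune_2_4` are hypothetical helpers written for illustration and are not `torch.sparse` APIs. Magnitude pruning is only one assumed way to produce a valid matrix, and the sketch does not model the compressed layout cuSPARSELt actually stores.

```python
from typing import List

Matrix = List[List[float]]


def is_2_4_sparse(rows: Matrix) -> bool:
    """Return True if every group of 4 consecutive entries in each row
    has at most 2 nonzero values (the 2:4 semi-structured pattern)."""
    for row in rows:
        if len(row) % 4 != 0:
            return False  # rows must split evenly into groups of 4
        for start in range(0, len(row), 4):
            group = row[start:start + 4]
            if sum(1 for v in group if v != 0) > 2:
                return False
    return True


def prune_2_4(rows: Matrix) -> Matrix:
    """Hypothetical magnitude pruning to 2:4: in each group of 4, keep the
    2 entries with the largest absolute value and zero out the other 2."""
    pruned = []
    for row in rows:
        if len(row) % 4 != 0:
            raise ValueError("row length must be a multiple of 4")
        new_row = list(row)
        for start in range(0, len(row), 4):
            # indices of this group, sorted by decreasing magnitude
            idx = sorted(range(start, start + 4), key=lambda i: abs(row[i]), reverse=True)
            for i in idx[2:]:
                new_row[i] = 0.0
        pruned.append(new_row)
    return pruned


dense = [[0.5, -2.0, 1.0, 3.0, 0.1, 0.0, -0.2, 0.3]]
print(is_2_4_sparse(dense))             # False: the first group has 4 nonzeros
print(prune_2_4(dense))                 # [[0.0, -2.0, 0.0, 3.0, 0.0, 0.0, -0.2, 0.3]]
print(is_2_4_sparse(prune_2_4(dense)))  # True
```

Half of each group is dropped, which is where the 2:4 format gets its compression and its sparse tensor-core speedup.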
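The `to_float8` helper in the summary's example is not shown. A common per-tensor recipe, assumed here rather than taken from this PR, works in three steps. It scales each tensor so that its largest magnitude maps to the `float8_e4m3fn` maximum of 448, casts the result, and returns the reciprocal of that scale. For per-tensor scales, `torch._scaled_mm` then multiplies the fp8 product by `scale_a * scale_b`. The pure-Python sketch below simulates this arithmetic with a hand-written round-to-nearest-even e4m3fn quantizer. `quantize_e4m3fn`, `to_float8_sim` and `scaled_matmul_ref` are hypothetical names, and the sketch says nothing about how the cuSPARSELt kernels compute the product.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]
E4M3FN_MAX = 448.0  # largest finite float8_e4m3fn value


def quantize_e4m3fn(x: float) -> float:
    """Round x to the nearest float8_e4m3fn value (ties to even), saturating
    at +/-448. Simulation only: the result is still a Python float."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), E4M3FN_MAX)
    # frexp gives mag = m * 2**e with 0.5 <= m < 1, so floor(log2(mag)) = e - 1
    _, e = math.frexp(mag)
    exp = max(e - 1, -6)        # -6 is the smallest normal exponent (subnormals share it)
    step = 2.0 ** (exp - 3)     # 3 mantissa bits -> spacing of 2**(exp - 3)
    return sign * round(mag / step) * step  # Python's round() ties to even


def to_float8_sim(m: Matrix) -> Tuple[Matrix, float]:
    """Per-tensor scaling: map max |value| to E4M3FN_MAX, quantize, and return
    the quantized matrix plus the inverse (dequantization) scale."""
    amax = max(abs(v) for row in m for v in row)
    scale = E4M3FN_MAX / max(amax, 1e-12)  # guard against an all-zero input
    q = [[quantize_e4m3fn(v * scale) for v in row] for row in m]
    return q, 1.0 / scale


def scaled_matmul_ref(a_q: Matrix, b_q: Matrix, scale_a: float, scale_b: float) -> Matrix:
    """Reference for a per-tensor scaled matmul: (a_q @ b_q) * scale_a * scale_b."""
    n, k, p = len(a_q), len(b_q), len(b_q[0])
    return [[sum(a_q[i][t] * b_q[t][j] for t in range(k)) * scale_a * scale_b
             for j in range(p)] for i in range(n)]


a = [[0.5, -2.0, 0.0, 3.0]]
b = [[1.0], [0.25], [2.0], [-1.5]]
a_q, sa = to_float8_sim(a)
b_q, sb = to_float8_sim(b)
print(scaled_matmul_ref(a_q, b_q, sa, sb))  # roughly -4.29 versus the exact -4.5
```

The gap from the exact result comes entirely from rounding to 3 mantissa bits, which is why the PR compares the sparse result against a dense fp8 `_scaled_mm` rather than against a higher-precision matmul.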