pytorch/torch/sparse
Jesse Cai bc21689136 [sparse][semi-structured] Add float8 dtype support to 2:4 sparsity (#136397)
Summary:

This PR adds `torch.float8_e4m3fn` support to cuSPARSELt and `to_sparse_semi_structured`.
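
For context, the 2:4 (semi-structured) pattern allows at most two nonzero values in every contiguous group of four elements along a row. The pure-Python sketch below illustrates that constraint only; the helper names `prune_2_4` and `is_2_4_sparse` are hypothetical, and magnitude-based pruning is an assumption chosen for illustration, not something this PR prescribes.

```python
def prune_2_4(row):
    """Zero the two smallest-magnitude entries in each group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        # Positions of the two largest magnitudes within this group.
        keep = sorted(range(4), key=lambda j: abs(group[j]), reverse=True)[:2]
        out.extend(v if j in keep else 0.0 for j, v in enumerate(group))
    return out


def is_2_4_sparse(row):
    """True if every group of four holds at most two nonzero values."""
    if len(row) % 4 != 0:
        return False
    return all(
        sum(v != 0 for v in row[start:start + 4]) <= 2
        for start in range(0, len(row), 4)
    )


row = [0.1, -3.0, 2.0, 0.5, 1.5, 1.0, -4.0, 0.0]
pruned = prune_2_4(row)
print(pruned)  # [0.0, -3.0, 2.0, 0.0, 1.5, 0.0, -4.0, 0.0]
```

A tensor whose rows satisfy this check is the kind of input the conversion to the semi-structured layout is meant for.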

This lets users run fp8 + 2:4 sparse matmuls on Hopper GPUs with
cuSPARSELt >= 0.6.2 via the `scaled_mm` API.

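```
import torch
from torch.sparse import to_sparse_semi_structured

# rand_sparse_semi_structured_mask and to_float8 are helper functions, and
# dense_input_shape, device, and out_dtype are values, all assumed to be
# defined by the caller; none of them are part of the public torch API.
A = rand_sparse_semi_structured_mask(256, 128, dtype=torch.float16)
B = torch.rand(dense_input_shape, device=device).to(torch.float16).t()

# Quantize both operands to fp8, keeping the per-tensor scales.
A_fp8, A_scale = to_float8(A)
B_fp8, B_scale = to_float8(B)

# Dense fp8 matmul, used as the reference result.
dense_result = torch._scaled_mm(
    A_fp8, B_fp8,
    scale_a=A_scale, scale_b=B_scale,
    out_dtype=out_dtype
)
# Same matmul with A converted to the 2:4 sparse layout.
A_fp8_sparse = to_sparse_semi_structured(A_fp8)
sparse_result = torch._scaled_mm(
    A_fp8_sparse, B_fp8,
    scale_a=A_scale, scale_b=B_scale,
    out_dtype=out_dtype
)
```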

Note that to keep this consistent with normal torch behavior, calling
`torch.mm(A_fp8_sparse, B_fp8)` will raise a NotImplementedError.

I also turned on cuSPARSELt by default and added CUSPARSELT_MAX_ID to the
backend to make the tests a bit cleaner.

Test Plan:
```
python test/test_sparse_semi_structured.py -k scaled_mm
python test/test_sparse_semi_structured.py -k fp8
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136397
Approved by: https://github.com/drisspg
2024-09-27 21:37:34 +00:00