80 Commits

Author SHA1 Message Date
341c4227a8 Update F32 sparse semi-structured support for CUTLASS back-end (#116017)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116017
Approved by: https://github.com/jcaip
2023-12-22 16:53:04 +00:00
a8e354a9a0 [sparse][semi-structured] enable fp32 support, separate sparse and dense constraints (#115550)
Summary:

Both cuSPASRELt and CUTLASS support 1:2 semi-structured sparsity for
fp32, which this PR enables.(thanks @alexsamardzic).

Furthermore, this PR also updates the sparse_config to take into account
the different shape constraints for sparse and dense matrices.

Technically, cuSPARSELt supports smaller sparse matrix constraints as it
seens to pad to the CUTLASS constraints under the hood. However, in
practice small sparse matrices are not commonly used and we care more
about the dense constraints for LLM inference.

For now, we keep the CUTLASS constraints in place for both cuSPARSELt
and CUTLASS tensors

This PR also reconnects the _FUSE_TRANSPOSE flag for cuSPARSELt tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115550
Approved by: https://github.com/cpuhrsch
2023-12-15 02:28:17 +00:00
4471fe6c39 [sparse][semi-structured] add alg_id to _cslt_sparse_mm and _cslt_sparse_mm_search (#115178)
Summary:

cuSPARSELt has support for different alg_id, which are set via

`cusparseLTMatmulAlgSetAttribute`, in total there are 4 different
alg_ids, 0 - 3.

Previously we were just using the default alg_id, as from our initial
experiments we found that for most shapes the default alg_id is the
fastest and that they made no difference on numerical correctness, just
performance. From our previous experiments the fastest alg_id seemed to
differ only on small matmul shapes.

danthe3rd found a performance regression when running with
cuSPARSELt v0.4.0 vs v0.5.0, on LLM shapes, which match these
characteristics (activations are small, weights are large).

However it's likely that this is due to the alg_id ordering changing, as
mentioned in the release notes for v0.5.0.
```
cusparseLtMatmulAlgSelectionInit() does not ensure the same ordering of
algorithm id alg as in v0.4.0.
```

This PR adds in the following:
- support for passing in alg_id to _cslt_sparse_mm
- a new op, _cslt_sparse_mm_search, which returns the optimal alg_id for
  a given matmul

_cslt_sparse_mm_search has the same function signature as
_cslt_sparse_mm, minus the alg_id parameter.
We are able to achieve v0.4.0 performance with alg_id=1 on the shapes
that daniel provided.

We will address autoselecting the best alg_id in a future PR, possibly
with torch.compile.

Test Plan:
```
python test/test_sparse_semi_structured -k cslt
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115178
Approved by: https://github.com/cpuhrsch
2023-12-11 23:08:51 +00:00
40a14e07ef Revert "[sparse][semi-structured] add alg_id to _cslt_sparse_mm and _cslt_spasre_mm_search (#115178)"
This reverts commit 1e5636f7915035b09dce22ad1d2170a65f344214.

Reverted https://github.com/pytorch/pytorch/pull/115178 on behalf of https://github.com/huydhn due to Sorry for reverting your change, but the Window build failure looks legit 1e5636f791 ([comment](https://github.com/pytorch/pytorch/pull/115178#issuecomment-1850605711))
2023-12-11 18:07:17 +00:00
1e5636f791 [sparse][semi-structured] add alg_id to _cslt_sparse_mm and _cslt_spasre_mm_search (#115178)
Summary:

cuSPARSELt has support for different alg_id, which are set via

`cusparseLTMatmulAlgSetAttribute`, in total there are 4 different
alg_ids, 0 - 3.

Previously we were just using the default alg_id, as from our initial
experiments we found that for most shapes the default alg_id is the
fastest and that they made no difference on numerical correctness, just
performance. From our previous experiments the fastest alg_id seemed to
differ only on small matmul shapes.

danthe3rd found a performance regression when running with
cuSPARSELt v0.4.0 vs v0.5.0, on LLM shapes, which match these
characteristics (activations are small, weights are large).

However it's likely that this is due to the alg_id ordering changing, as
mentioned in the release notes for v0.5.0.
```
cusparseLtMatmulAlgSelectionInit() does not ensure the same ordering of
algorithm id alg as in v0.4.0.
```

This PR adds in the following:
- support for passing in alg_id to _cslt_sparse_mm
- a new op, _cslt_sparse_mm_search, which returns the optimal alg_id for
  a given matmul

_cslt_sparse_mm_search has the same function signature as
_cslt_sparse_mm, minus the alg_id parameter.
We are able to achieve v0.4.0 performance with alg_id=1 on the shapes
that daniel provided.

We will address autoselecting the best alg_id in a future PR, possibly
with torch.compile.

Test Plan:
```
python test/test_sparse_semi_structured -k cslt
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115178
Approved by: https://github.com/cpuhrsch
2023-12-11 15:47:28 +00:00
4cb7dd0fc9 [sparse][quant] Add support for vector alpha in cusparselt mm (#112056)
Summary:

This PR adds in support for passing in a alpha Tensor, which represents
a tensor of alpha values to fuse into the matmul.

```
cusparselt_sparse_mm = alpha A @ B + bias
```

This operation is necessary for quantization, where we would like to
fuse one of the dequant matmuls into the sparse op.

Test Plan:

```
python test/test_sparse_semi_structured -k alpha
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112056
Approved by: https://github.com/cpuhrsch
2023-12-04 16:56:06 +00:00
ae593d0393 [sparse][semi-structured][inductor] meta registrations for _cslt_sparse_mm + additional stride checking in test. (#114685)
_cslt_sparse_mm + additional stride checking in test.

Summary:

This PR adds in meta registrations for _cslt_sparse_mm.

Based on the work @drisspg did
in #114370.

Additionally, it updates the tests by checking that the strides of the
spare result and the result returned by sparse+compile are the same, to
avoid errors like those found in

https://github.com/pytorch/pytorch/pull/114477.

Test Plan:
```
python test/test_sparse_semi_structred -k compile_cusparselt
python test/test_sparse_semi_structred -k compile_cutlass
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114685
Approved by: https://github.com/alexsamardzic, https://github.com/drisspg
2023-11-29 00:31:52 +00:00
cef79c0df4 [inductor] _sparse_semi_structured_linear fallback - no meta registration; not on testing path (#114477)
Test was wrong in original PR and merged changes were never tested. Further, the sparse op was never actually compiled due to missing `fullgraph=True` and missing meta registration.

When meta is added as per this PR, it gives wrong answers when input needs to be padded and when input needs to be reshaped.

Is this something to do with the generated inductor code for:
```
 constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0)
...
slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1);  _sparse_semi_structured_linear = None
```
and

```
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         mul: "Sym(s0*s1)" = primals_4 * primals_5
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]);  primals_6 = mul = None
...
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]);  slice_1 = None
```

Failing graphs:
Padded:
```
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO] TRACED GRAPH
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]  ===== Forward graph 5 =====
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.66 class GraphModule(torch.nn.Module):
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[1, 128]"):
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         constant_pad_nd: "f16[32, 128]" = torch.ops.aten.constant_pad_nd.default(primals_3, [0, 0, 0, 31], 0.0)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         _sparse_semi_structured_linear: "f16[32, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(constant_pad_nd, primals_1, primals_2);  constant_pad_nd = primals_1 = primals_2 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         slice_1: "f16[1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 0, 0, 1);  _sparse_semi_structured_linear = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         slice_2: "f16[1, 128]" = torch.ops.aten.slice.Tensor(slice_1, 1, 0, 9223372036854775807);  slice_1 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         relu: "f16[1, 128]" = torch.ops.aten.relu.default(slice_2);  slice_2 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias: "f16[1, 128]" = torch.ops.aten.alias.default(relu)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias_1: "f16[1, 128]" = torch.ops.aten.alias.default(alias);  alias = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         le: "b8[1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0);  alias_1 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x)
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         permute: "f16[128, 1]" = torch.ops.aten.permute.default(primals_3, [1, 0]);  primals_3 = None
[2023-11-23 13:59:51,102] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return [relu, le, permute]

```

Reshape:

```
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]  <eval_with_key>.69 class GraphModule(torch.nn.Module):
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]     def forward(self, primals_1: "f16[128, 64]", primals_2: "i16[128, 8]", primals_3: "f16[128]", primals_4: "Sym(s0)", primals_5: "Sym(s1)", primals_6: "f16[s0, s1, 128]"):
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:145, code: x = self.linear(x)
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         mul: "Sym(s0*s1)" = primals_4 * primals_5
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view: "f16[s0*s1, 128]" = torch.ops.aten.view.default(primals_6, [mul, 128]);  primals_6 = mul = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         _sparse_semi_structured_linear: "f16[s0*s1, 128]" = torch.ops.aten._sparse_semi_structured_linear.default(view, primals_1, primals_2, bias = primals_3);  primals_1 = primals_2 = primals_3 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         slice_1: "f16[s0*s1, 128]" = torch.ops.aten.slice.Tensor(_sparse_semi_structured_linear, 1, 0, 9223372036854775807);  _sparse_semi_structured_linear = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         view_1: "f16[s0, s1, 128]" = torch.ops.aten.view.default(slice_1, [primals_4, primals_5, 128]);  slice_1 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         # File: /home/jonch/Desktop/Programming/mlsys/pytorch/test/test_sparse_semi_structured.py:147, code: return torch.nn.functional.relu(x)
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         relu: "f16[s0, s1, 128]" = torch.ops.aten.relu.default(view_1);  view_1 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(relu)
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         alias_1: "f16[s0, s1, 128]" = torch.ops.aten.alias.default(alias);  alias = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         le: "b8[s0, s1, 128]" = torch.ops.aten.le.Scalar(alias_1, 0);  alias_1 = None
[2023-11-23 14:01:03,463] [0/2] torch._functorch.aot_autograd.__aot_graphs: [INFO]         return [relu, view, le, primals_4, primals_5]

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114477
Approved by: https://github.com/jcaip
2023-11-28 19:35:05 +00:00
702aaf8aea [sparse] semi-structured sparse + torch.compile support (#111049)
Summary:

This PR adds in torch.compile support for semi-structured sparsity,
using the subclass tracing @bdhirsh added.

Based on wether we are using cuSPARSELt or CUTLASS, we return a
different representation of the inner tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py -k compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111049
Approved by: https://github.com/cpuhrsch
2023-10-24 02:23:20 +00:00
41490119f2 Revert "[sparse] semi-structured sparse + torch.compile support (#111049)"
This reverts commit 408f210938176870133a3dde5e8fbc4926cafbc0.

Reverted https://github.com/pytorch/pytorch/pull/111049 on behalf of https://github.com/clee2000 due to Sorry I'm pretty sure this caused a memory leak 408f210938 https://github.com/pytorch/pytorch/actions/runs/6550388354/job/17790615103 `test_sparse_semi_structured.py::TestSparseSemiStructuredCUDA::test_mlp_contiguous_relu_compile_backend_cutlass_dense_input_shape_(1, 128)_cuda - RuntimeError: CUDA driver API confirmed a leak in __main__.TestSparseSemiStructuredCUDA.test_mlp_contiguous_relu_compile_backend_cutlass_dense_input_shape_(1, 128)_cuda! Caching allocator allocated memory was 235008 and is now reported as 352256 on device 0. CUDA driver allocated memory was 359333888 and is now 361431040.` ([comment](https://github.com/pytorch/pytorch/pull/111049#issuecomment-1767186569))
2023-10-17 21:11:09 +00:00
408f210938 [sparse] semi-structured sparse + torch.compile support (#111049)
Summary:

This PR adds in torch.compile support for semi-structured sparsity,
using the subclass tracing @bdhirsh added.

Based on wether we are using cuSPARSELt or CUTLASS, we return a
different representation of the inner tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py -k compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111049
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #110583
2023-10-16 23:07:26 +00:00
b4745d476c Revert "[sparse] semi-structured sparse + torch.compile support (#111049)"
This reverts commit ac02531babab028cb260d2225ff9e91e92df063b.

Reverted https://github.com/pytorch/pytorch/pull/111049 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/111049#issuecomment-1763795957))
2023-10-16 06:16:59 +00:00
ac02531bab [sparse] semi-structured sparse + torch.compile support (#111049)
Summary:

This PR adds in torch.compile support for semi-structured sparsity,
using the subclass tracing @bdhirsh added.

Based on wether we are using cuSPARSELt or CUTLASS, we return a
different representation of the inner tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py -k compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111049
Approved by: https://github.com/cpuhrsch
ghstack dependencies: #110583
2023-10-14 01:13:01 +00:00
8db72a430d [sparse] Add padding for dense matrices in semi-structured sparse (#110583)
Summary:

Currently we have shape constraints in semi-structured sparsity for both
CUTLASS and cuSPARSELt

These shape constraints unfortunately apply to both the dense and sparse
matrices in sparsedense matmul.

This PR adds in support for calling `F.pad` in order to pad dense
matrices to the right size with zeros and then pull out the
corresponding rows from the resultant result matrix.

We also throw a warning in this case.
The tests have also been updated to take in a dense_input_shape
parameter.

Test Plan:
```
python test/test_sparse_semi_structured.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110583
Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch
2023-10-13 20:04:23 +00:00
cae3a2e6eb Revert "[sparse] Add i8i8->i32 support for cuSPARSELt (#110499)"
This reverts commit 33da6c89516d9d9067f7181826826224a4cf5afe.

Reverted https://github.com/pytorch/pytorch/pull/110499 on behalf of https://github.com/jcaip due to cslt v0.5.0 requires a newer linker and we will be using v0.4.0 as the base version ([comment](https://github.com/pytorch/pytorch/pull/110499#issuecomment-1758039953))
2023-10-11 16:14:59 +00:00
f10aab03c4 [sparse] Fix semi-structured sparse shape mismatch bug (#110420)
Summary:

Currently, PyTorch incorrectly calculates the size of the returned
matrix when we pass a non-contiguous batched (>2d) input to the
semi-structured sparse subclass.

This is most common in MLP layers, where we have 2 linear layers back to back.

This will lead to an error like the following:
```
RuntimeError: shape '[20, 64, 64, 3072]' is invalid for input of size
62914560

```
Where the size of the sparse matmul result is off because we infer the
output shape with the wrong tensor shape.

This happens because of a bug where we did not update the subclass
tensor shape when doing transpose.
For semi-structured sparsity, transposing is a no-op where we just set
the boolean flag, but we forgot to also update the tensor shape.

Note that this error goes away in inference mode, since we avoid
decomposing the aten.linear op and handle shape folding ourselves,
which changes the execution path.

An alternative way to fix this issue is to set
TORCH_FLATTEN_LINEAR_3D=True, which will also fix this error.

Test Plan:
```
python test/test_sparse_semi_structured.py -k test_mlp

```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110420
Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch
2023-10-10 03:07:31 +00:00
33da6c8951 [sparse] Add i8i8->i32 support for cuSPARSELt (#110499)
Summary:

With the release of cuSPARSELt v0.5.0, we now have support for
int8 int8 -> int32 matmul.

This PR adds support for this via out_dtype.

Test Plan:
```
python test/test_sparse_semi_structured.py -k int32
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110499
Approved by: https://github.com/cpuhrsch
2023-10-06 18:32:47 +00:00
6a202c36af Minor fixes in semi-structured sparse code (#105595)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105595
Approved by: https://github.com/jcaip
2023-09-25 14:06:08 +00:00
1df14f1bf8 Move has_triton to top level triton utils so that dynamo can also access (#109832)
it without creating cyclic dependencies

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109832
Approved by: https://github.com/zou3519
2023-09-22 19:33:41 +00:00
369a84e5c4 [core][sparse][pruning] Add (i8i8)-> fp16 support to cuSPARSELt matmul (#109214)
Summary:

This PR adds in support for sparse matmul using cuSPASRELt with int8
inputs and fp16 outputs.

It does so by adding a out_dtype flag to `torch_cslt_sparse_mm`.
Because the only mixed_dtype support present in cuSPARSELt is for int8
input and fp16 output, we error out if:

* out_dtype is set and the input tensors are not int8.
* out_dtype is set to any value other than fp16

Test Plan:

python test/test_sparse_semi_structured -k int8_in_fp16_out

Reviewers:

@cphursh

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109214
Approved by: https://github.com/cpuhrsch
2023-09-15 18:14:40 +00:00
925d71e72e [core][sparse][pruning] cuSPARSELt Kernels and ops. (#107398)
Summary:
This is a duplicate PR of 102133, which was reverted because it was
failing internal tests.

It seems like that internal builds did not like my guard to check if
cuSPARSELt was available or not.

Test Plan: python test/test_sparse_semi_structured.py

Differential Revision: D48440330

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107398
Approved by: https://github.com/cpuhrsch
2023-08-25 07:04:15 +00:00
fe594ab323 Revert "[core][pruning][feature] cuSPARSELt kernels and ops (#102133)"
This reverts commit ad22f0ffb456fc3f967ad32e09376f7c9cf94a56.

Reverted https://github.com/pytorch/pytorch/pull/102133 on behalf of https://github.com/jcaip due to breaking lots of internal builds, see D48144534 ([comment](https://github.com/pytorch/pytorch/pull/102133#issuecomment-1671707821))
2023-08-09 16:03:14 +00:00
ad22f0ffb4 [core][pruning][feature] cuSPARSELt kernels and ops (#102133)
This PR contains two new private ops, added for cuSPARSELt support.

These ops call into the cuSPASRELt kernels using the bindings they
provide. For more information, see the documentation
[here](https://docs.nvidia.com/cuda/cusparselt/index.html).

The two new private ops added are:
```
_cslt_compress()
_cslt_sparse_mm()
```

_cslt_compress is an op that reuturns the compressesed matrix given a
sparse matrix that is passed in.

_cslt_sparse_mm is an op that expects a compressed matrix (the result of
_cslt_compress) and a dense matrix and performs sparse-dense matmul

These ops will throw runtime errors if they cusparselt is not present.

This PR also modifies the test and tensor sublass to reflect the new
cuSPARSELt support.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102133
Approved by: https://github.com/cpuhrsch
2023-08-08 06:59:22 +00:00
d7e6040efa Update sparse semi-structured linear operator (#104608)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104608
Approved by: https://github.com/cpuhrsch
2023-07-13 23:52:39 +00:00
fc2f87b281 Add semi-structured sparse conversions (#103830)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103830
Approved by: https://github.com/amjames, https://github.com/jcaip, https://github.com/cpuhrsch
2023-07-13 21:09:09 +00:00
2da6cae43c [core][pruning][sparse][feature] SparseSemiStructured tensor subclass (#102135)
This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.

In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.

This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor` to store the
  sparse tensor in copmressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
  semi-structured sparse bool mask and creates an instance of the
  subclass.

**SparseSemiStructuredTensor**

The subclass stores the dense tensor in a contiguous flattened tensor
for future compatability with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings
[here](https://github.com/pytorch/pytorch/pull/103700) for faster matmul, better dtype converage, and relaxed shape
constraints.

Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().

Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.

**to_sparse_semi_structured**

This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it becuase there may be
additional zero elements in the dense tensor, so `tensor !=0` is not 2:4
sparse.

Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's
own helper functions to create the metadata mask.

**User Details**

We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```

The end user interface to accelerate a nn.Linaer module with the
subclass would look like this:

```
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = Model(128, 128).half().cuda()

linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
                                                       mask=linear.weight.bool())

```

This also updates tests and the `torch.sparse` module docstring to
reflect these changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102135
Approved by: https://github.com/albanD
2023-06-27 19:21:06 +00:00
b76a040b18 Revert "[core][pruning][sparse][feature] SparseSemiStructured tensor subclass (#102135)"
This reverts commit aea771de30427998e83010459b69da1ab66f0879.

Reverted https://github.com/pytorch/pytorch/pull/102135 on behalf of https://github.com/huydhn due to test_sparse_semi_structured.py::TestSparseSemiStructuredCUDA::test_mm_sparse_first_NT_cuda_int8 is still failing CUDA trunk jobs aea771de30 ([comment](https://github.com/pytorch/pytorch/pull/102135#issuecomment-1608744110))
2023-06-27 03:49:31 +00:00
aea771de30 [core][pruning][sparse][feature] SparseSemiStructured tensor subclass (#102135)
This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.

In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.

This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor` to store the
  sparse tensor in copmressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
  semi-structured sparse bool mask and creates an instance of the
  subclass.

**SparseSemiStructuredTensor**

The subclass stores the dense tensor in a contiguous flattened tensor
for future compatability with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings
[here](https://github.com/pytorch/pytorch/pull/103700) for faster matmul, better dtype converage, and relaxed shape
constraints.

Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().

Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.

**to_sparse_semi_structured**

This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it becuase there may be
additional zero elements in the dense tensor, so `tensor !=0` is not 2:4
sparse.

Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's
own helper functions to create the metadata mask.

**User Details**

We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```

The end user interface to accelerate a nn.Linaer module with the
subclass would look like this:

```
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = Model(128, 128).half().cuda()

linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
                                                       mask=linear.weight.bool())

```

This also updates tests and the `torch.sparse` module docstring to
reflect these changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102135
Approved by: https://github.com/albanD
2023-06-27 02:37:00 +00:00
bfa08a1c67 Revert "[core][pruning][sparse][feature] SparseSemiStructured tensor subclass (#102135)"
This reverts commit cf5262a84f815c1e574883bc244333d0d211c7a2.

Reverted https://github.com/pytorch/pytorch/pull/102135 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but test_sparse_semi_structured.py::TestSparseSemiStructuredCUDA::test_mm_sparse_first_NT_cuda_int8 is failing CUDA trunk jobs cf5262a84f. This looks like a landrace ([comment](https://github.com/pytorch/pytorch/pull/102135#issuecomment-1608423849))
2023-06-26 22:54:16 +00:00
cf5262a84f [core][pruning][sparse][feature] SparseSemiStructured tensor subclass (#102135)
This PR adds in support for semi-structured sparsity via a tensor
subclass. It currently uses the CUTLASS kernels merged in PR #100881.

In the future we plan to add in cuSPARSELt support (see the other PRs in
the stack), which will give us larger performance gains.

This PR adds in 2 things:
- a Tensor subclass, `SparseSemiStructuredTensor` to store the
  sparse tensor in copmressed form and override `__torch_dispatch__`.
- a conversion function that takes in a dense tensor and a
  semi-structured sparse bool mask and creates an instance of the
  subclass.

**SparseSemiStructuredTensor**

The subclass stores the dense tensor in a contiguous flattened tensor
for future compatability with cuSPARSELt, which expects this format.
Note that the CUTLASS kernels do not have this limitation, as the
specified values and the metadata are passed separately in
`_structured_sparse_linear`. In the future we can use the cuSPARSELT bindings
[here](https://github.com/pytorch/pytorch/pull/103700) for faster matmul, better dtype converage, and relaxed shape
constraints.

Since we currently don't have a way to go back from the sparse
representation to the dense representation, and we store the weights in
compressed form, we don't have a great way to handle .t().

Instead, we keep track of how often we've called transpose on our
tensor, and if it's an unexpected number we throw an error. When the first
argument is sparse, we expect an even number of calls to transpose,
while when the second argument is sparse, we expect an odd number of
calls. This is because we support second argument sparse matrix
multiplications by using transpose properties.

**to_sparse_semi_structured**

This is a conversion function to convert a dense tensor and a
semi-structured sparse bool mask into a subclass. Currently, we must
pass in a bool mask, since we can't infer it becuase there may be
additional zero elements in the dense tensor, so `tensor !=0` is not 2:4
sparse.

Once we add either a method to derive the mask from the dense tensor or
cuSPARSELt, we no longer need to pass in the mask. cuSPARSELt has it's
own helper functions to create the metadata mask.

**User Details**

We have implemented support for the following ops for `torch.float16`
and `torch.int8`:
```
torch.addmm(bias, dense, sparse.t())
torch.mm(dense, sparse)
torch.mm(sparse, dense)
aten.linear.default
aten.t.default
aten.t.detach
```

The end user interface to accelerate a nn.Linaer module with the
subclass would look like this:

```
from torch.sparse import to_sparse_semi_structured

mask = torch.Tensor([0, 0, 1, 1]).tile(128, 32).cuda().bool()
linear = Model(128, 128).half().cuda()

linear.weight = nn.Parameter(to_sparse_semi_structured(linear.weight,
                                                       mask=linear.weight.bool())

```

This also updates tests and the `torch.sparse` module docstring to
reflect these changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102135
Approved by: https://github.com/albanD
2023-06-26 21:30:43 +00:00