[sparse] semi-structured sparse + torch.compile support (#111049)

Summary:

This PR adds in torch.compile support for semi-structured sparsity,
using the subclass tracing @bdhirsh added.

Based on wether we are using cuSPARSELt or CUTLASS, we return a
different representation of the inner tensors.

Test Plan:
```
python test/test_sparse_semi_structured.py -k compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111049
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Jesse Cai
2023-10-23 13:46:11 -07:00
committed by PyTorch MergeBot
parent 5eac44bc72
commit 702aaf8aea
3 changed files with 89 additions and 5 deletions

View File

@ -213,6 +213,37 @@ class SparseSemiStructuredTensor(torch.Tensor):
self.sparse_tensor_cutlass = sparse_tensor_cutlass
self.meta_tensor_cutlass = meta_tensor_cutlass
self.transposed = transposed
self.original_shape = original_shape
def __tensor_flatten__(self):
if self.compressed_tensor_cusparselt is not None:
return ['compressed_tensor_cusparselt'], (self.original_shape, self.transposed)
else:
return ['sparse_tensor_cutlass', 'meta_tensor_cutlass'], (self.original_shape, self.transposed)
@staticmethod
def __tensor_unflatten__(inner_tensors, meta):
original_shape, transposed = meta
if len(inner_tensors) == 2:
sparse_tensor_cutlass = inner_tensors['sparse_tensor_cutlass']
meta_tensor_cutlass = inner_tensors['meta_tensor_cutlass']
compressed_tensor_cusparselt = None
elif len(inner_tensors) == 1:
sparse_tensor_cutlass = None
meta_tensor_cutlass = None
compressed_tensor_cusparselt = inner_tensors['compressed_tensor_cusparselt']
else:
raise RuntimeError(f"Expected 1 or 2 inner tensors but got {len(inner_tensors)}")
return SparseSemiStructuredTensor(
None,
original_shape=original_shape,
compressed_tensor_cusparselt=compressed_tensor_cusparselt,
sparse_tensor_cutlass=sparse_tensor_cutlass,
meta_tensor_cutlass=meta_tensor_cutlass,
transposed=transposed,
)
def __repr__(self) -> str: # type: ignore[override]
"""Return string representation of SparseSemiStructuredTensor