[sparse] Fix semi-structured sparse shape mismatch bug (#110420)

Summary:

Currently, PyTorch incorrectly calculates the size of the returned
matrix when we pass a non-contiguous batched (>2d) input to the
semi-structured sparse subclass.

This is most common in MLP layers, where we have two linear layers back to back.

This will lead to an error like the following:
```
RuntimeError: shape '[20, 64, 64, 3072]' is invalid for input of size
62914560
```
Here the size of the sparse matmul result is off because we infer the
output shape from a stale tensor shape.
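A quick sanity check on the numbers makes the mismatch concrete (both values are taken from the error message above):

```python
import math

# Element count the reshape expects, from the inferred (wrong) output shape:
expected = math.prod([20, 64, 64, 3072])
actual = 62914560  # element count reported in the error message
print(expected, actual, expected == actual)  # 251658240 62914560 False
```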

This happens because of a bug where we did not update the subclass
tensor shape when doing a transpose.
For semi-structured sparsity, transposing is a no-op: we only flip a
boolean flag on the subclass, but we forgot to also swap the stored
tensor shape.
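A rough sketch of the failure mode, using a hypothetical `ToySubclass` rather than the real `SparseSemiStructuredTensor`: if transpose only flips the flag, the stored shape goes stale and any downstream shape inference reads the wrong dimensions.

```python
# Hypothetical stand-in for the subclass; transpose is a logical no-op
# that flips a flag, so the stored shape must be swapped by hand.
class ToySubclass:
    def __init__(self, shape, transposed=False):
        self.shape = tuple(shape)
        self.transposed = transposed

    def t_buggy(self):
        # The bug: flag flipped, shape left untouched.
        return ToySubclass(self.shape, transposed=not self.transposed)

    def t_fixed(self):
        # The fix: also swap the two dimensions.
        return ToySubclass((self.shape[1], self.shape[0]),
                           transposed=not self.transposed)

w = ToySubclass((64, 3072))
print(w.t_buggy().shape)  # (64, 3072) -- stale shape
print(w.t_fixed().shape)  # (3072, 64)
```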

Note that this error goes away in inference mode, since we avoid
decomposing the aten.linear op and handle shape folding ourselves,
which changes the execution path.

Alternatively, setting TORCH_FLATTEN_LINEAR_3D=True also avoids this
error, by flattening the batched linear input before the matmul.
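The flattening workaround amounts to collapsing the batch dimensions before the matmul and restoring them afterward. A minimal sketch of that idea (the tensor names and sizes here are illustrative, not PyTorch internals):

```python
import torch

x = torch.randn(20, 64, 3072)      # batched (>2d) linear input
weight = torch.randn(64, 3072)     # linear weight, (out_features, in_features)

flat = x.reshape(-1, x.shape[-1])  # collapse batch dims -> [1280, 3072]
out = (flat @ weight.t()).reshape(*x.shape[:-1], weight.shape[0])
print(out.shape)  # torch.Size([20, 64, 64])
```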

Test Plan:
```
python test/test_sparse_semi_structured.py -k test_mlp
```

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110420
Approved by: https://github.com/alexsamardzic, https://github.com/cpuhrsch
Author: Jesse Cai
Date: 2023-10-09 12:53:22 -07:00
Committed by: PyTorch MergeBot
Commit: f10aab03c4 (parent 468a73f0e3)
2 changed files with 36 additions and 3 deletions
```diff
@@ -52,7 +52,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
     """
     _FUSE_TRANSPOSE = False
-    _FORCE_CUTLASS = False
+    _FORCE_CUTLASS = True
     _WARNING_SHOWN = False

     @staticmethod
@@ -268,7 +268,8 @@ class SparseSemiStructuredTensor(torch.Tensor):
         if func is torch.ops.aten.t.default:
             return SparseSemiStructuredTensor(
                 args[0].original_tensor,
-                original_shape=args[0].shape,
+                # transpose shape
+                original_shape=torch.Size([args[0].shape[1], args[0].shape[0]]),
                 compressed_tensor_cusparselt=args[0].compressed_tensor_cusparselt,
                 sparse_tensor_cutlass=args[0].sparse_tensor_cutlass,
                 meta_tensor_cutlass=args[0].meta_tensor_cutlass,
@@ -438,4 +439,6 @@ def to_sparse_semi_structured(
         [-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0',
        dtype=torch.int16))
     """
-    return SparseSemiStructuredTensor(original_tensor, original_shape=original_tensor.shape, transposed=transposed)
+    return SparseSemiStructuredTensor(
+        original_tensor, original_shape=original_tensor.shape, transposed=transposed
+    )
```