# mypy: allow-untyped-defs
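"""Op overrides for :class:`torch.sparse.SparseSemiStructuredTensor`.

Each handler below follows the ``__torch_dispatch__``-style signature
``(func, types, args, kwargs)``; the tensor subclass routes the corresponding
aten ops to these functions when they are called on a semi-structured (2:4)
sparse tensor.
"""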
import contextlib

import torch


__all__ = [
    "fallback_dispatcher",
    "semi_sparse_values",
    "semi_sparse_indices",
    "semi_sparse_t",
    "semi_sparse_view",
    "semi_sparse_detach",
    "semi_sparse_mm",
    "semi_sparse_addmm",
    "semi_sparse_linear",
]


@contextlib.contextmanager
def no_dispatch():
    # Temporarily disable __torch_dispatch__ interception so calls made inside
    # this context hit the default implementations.
    guard = torch._C._DisableTorchDispatch()
    try:
        yield
    finally:
        del guard


def fallback_dispatcher(func, types, args, kwargs):
    # Re-run the op outside subclass dispatch; note that only positional args
    # are forwarded.
    with no_dispatch():
        return func(*args)


def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        # No separate metadata tensor: `packed` stores the kept values followed
        # by the metadata, so the first m * k // 2 elements are the values.
        m, k = A.shape
        num_kept_elements = m * k // 2
        return A.packed[:num_kept_elements:].view(m, -1)
    else:
        return A.packed.detach()


def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    A = args[0]
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
    assert A.packed is not None
    if A.meta is None:
        # The metadata lives in `packed` after the kept values; reinterpret it
        # with the dtype it is encoded in.
        m, k = A.shape
        num_kept_elements = m * k // 2
        metadata = A.packed[num_kept_elements:].view(m, -1)
        return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16)
    else:
        return A.meta


def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    assert isinstance(self, torch.sparse.SparseSemiStructuredTensor)
    assert len(self.shape) == 2
    # Because we cannot go from the compressed representation back to the dense representation currently,
    # we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
    # is the first or second argument, we expect an even / odd number of calls to transpose respectively.
    return self.__class__(
        torch.Size([self.shape[-1], self.shape[0]]),
        packed=self.packed_t,
        meta=self.meta_t,
        packed_t=self.packed,
        meta_t=self.meta,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask.transpose(0, 1)
        if self.compressed_swizzled_bitmask is not None
        else None,
        fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt,
        alg_id_cusparselt=args[0].alg_id_cusparselt,
    )
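

# Illustrative sketch for the bookkeeping above (hypothetical tensor ``S``, not
# executed here): ``semi_sparse_t`` swaps the packed/packed_t and meta/meta_t
# fields rather than repacking the data, so
#
#     S_t = S.t()     # S_t.packed is S.packed_t, shape is reversed
#     S_tt = S_t.t()  # fields are swapped back; layout matches S again
#
# which is why an even or odd number of transpose calls is expected, depending
# on which matmul operand the sparse tensor ends up as.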


def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    self, shape = args
    if tuple(shape) != self.shape:
        raise NotImplementedError(
            f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
        )
    return self


def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
    assert len(args) == 1
    self = args[0]
    # Rebuild the wrapper around the same buffers with requires_grad=False.
    return self.__class__(
        shape=self.shape,
        packed=self.packed,
        meta=self.meta,
        packed_t=self.packed_t,
        meta_t=self.meta_t,
        compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
        requires_grad=False,
    )


def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 2
    A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        # sparse @ dense: pad the dense operand to a supported shape, then crop
        # the padded columns from the result.
        row, col = B.shape
        B_padded = A._pad_dense_input(B)
        res = A._mm(B_padded)
        return res[:, :col]
    else:
        # dense @ sparse: compute (B.t() @ A.t()).t() so the sparse operand is
        # on the left of the underlying kernel, then crop the padded rows.
        B_t = B.t()
        assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
        row, col = A.shape
        A_padded = B._pad_dense_input(A)
        res = B_t._mm(A_padded.t()).t()
        return res[:row, :]


def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) == 3
    bias, A, B = args
    if A.ndim != 2 or B.ndim != 2:
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
        )
    if bias.ndim != 1:
        raise NotImplementedError(
            f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
        )
    if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
        raise NotImplementedError(
            "`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
        )
    # Same transpose trick as the dense @ sparse case of `semi_sparse_mm`, with
    # the bias fused into the sparse kernel.
    B_t = B.t()
    assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
    row, col = A.shape
    A_padded = B_t._pad_dense_input(A)
    result = B_t._mm(A_padded.t(), bias=bias).t()
    return result[:row, :]


def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
    assert len(args) in [2, 3]
    A, B = args[:2]
    bias = args[2] if len(args) == 3 else None

    # Flatten any leading batch dimensions so the 2-D mm / addmm paths above
    # can be reused, then restore them on the way out.
    shape = A.shape
    A_2d = A.view(-1, shape[-1])

    if bias is None:
        res = A_2d @ B.t()
    else:
        res = semi_sparse_addmm(
            func=None,
            types=None,
            args=[bias, A_2d, B.t()],
        )

    return res.view(*shape[:-1], -1)
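

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: the handlers above are
    # reached indirectly by converting a dense tensor with
    # torch.sparse.to_sparse_semi_structured and calling ordinary ops on the
    # result. The shapes and fp16 dtype below are illustrative assumptions;
    # running this requires a CUDA GPU with 2:4 semi-structured sparsity
    # support.
    if torch.cuda.is_available():
        from torch.sparse import to_sparse_semi_structured

        dense = torch.Tensor([[0, 0, 1, 1]]).tile((128, 32)).half().cuda()
        sparse = to_sparse_semi_structured(dense)  # SparseSemiStructuredTensor
        rhs = torch.rand(128, 128, dtype=torch.float16, device="cuda")
        out = torch.mm(sparse, rhs)  # dispatches to semi_sparse_mm above
        print(out.shape)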