Files
pytorch/torch/sparse/_semi_structured_ops.py
Maggie Moss d795fb225a [RFC] Add pyrefly to lintrunner (#165179)
This will add pyrefly to lint runner as a warning only - and allow us to collect feedback about the tool before switching to pyrefly as the main type checker.

References the steps outlined here: : https://github.com/pytorch/pytorch/issues/163283:

test plan:
`lintrunner init`
`lintrunner`
confirm when pyrefly errors are present results look like: https://gist.github.com/maggiemoss/e6cb2d015dd1ded560ae1329098cf33f

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165179
Approved by: https://github.com/ezyang
2025-10-16 20:07:09 +00:00

199 lines
6.3 KiB
Python

# mypy: allow-untyped-defs
import contextlib
import torch
__all__ = [
"fallback_dispatcher",
"semi_sparse_values",
"semi_sparse_indices",
"semi_sparse_t",
"semi_sparse_view",
"semi_sparse_detach",
"semi_sparse_mm",
"semi_sparse_addmm",
"semi_sparse_linear",
"semi_sparse_scaled_mm",
]
@contextlib.contextmanager
def no_dispatch():
guard = torch._C._DisableTorchDispatch()
try:
yield
finally:
del guard
def fallback_dispatcher(func, types, args, kwargs):
with no_dispatch():
return func(*args)
def semi_sparse_values(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 1
A = args[0]
assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
assert A.packed is not None
if A.meta is None:
m, k = A.shape
num_kept_elements = m * k // 2
return A.packed.ravel()[:num_kept_elements:].view(m, -1)
else:
return A.packed.detach()
def semi_sparse_indices(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 1
A = args[0]
assert isinstance(A, torch.sparse.SparseSemiStructuredTensor)
assert A.packed is not None
if A.meta is None:
m, k = A.shape
num_kept_elements = m * k // 2
metadata = A.packed.ravel()[num_kept_elements:].view(m, -1)
return metadata.view(torch.int32 if A.dtype == torch.int32 else torch.int16)
else:
return A.meta
def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 1
self = args[0]
assert isinstance(self, torch.sparse.SparseSemiStructuredTensor)
assert len(self.shape) == 2
# Because we cannot go from the compressed representation back to the dense representation currently,
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
# pyrefly: ignore # no-matching-overload
return self.__class__(
torch.Size([self.shape[-1], self.shape[0]]),
packed=self.packed_t,
meta=self.meta_t,
packed_t=self.packed,
meta_t=self.meta,
compressed_swizzled_bitmask=(
self.compressed_swizzled_bitmask.transpose(0, 1)
if self.compressed_swizzled_bitmask is not None
else None
),
fuse_transpose_cusparselt=args[0].fuse_transpose_cusparselt,
alg_id_cusparselt=args[0].alg_id_cusparselt,
)
def semi_sparse_view(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 2
self, shape = args
if tuple(shape) != self.shape:
raise NotImplementedError(
f"`view` is not implemented for SparseSemiStructuredTensor, except for the dummy case (shape={shape})"
)
return self
def semi_sparse_detach(func, types, args, kwargs) -> torch.Tensor:
assert len(args) == 1
self = args[0]
return self.__class__(
shape=self.shape,
packed=self.packed,
meta=self.meta,
packed_t=self.packed_t,
meta_t=self.meta_t,
compressed_swizzled_bitmask=self.compressed_swizzled_bitmask,
fuse_transpose_cusparselt=self.fuse_transpose_cusparselt,
alg_id_cusparselt=self.alg_id_cusparselt,
requires_grad=False,
)
def semi_sparse_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 2
A, B = args
if A.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
"`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
)
if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
row, col = B.shape
B_padded = A._pad_dense_input(B)
res = A._mm(B_padded)
return res[:, :col]
else:
B_t = B.t()
assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
row, col = A.shape
A_padded = B._pad_dense_input(A)
res = B_t._mm(A_padded.t()).t()
return res[:row, :]
def semi_sparse_addmm(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) == 3
bias, A, B = args
if A.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
"`SparseSemiStructuredTensor` matmul: Broadcasting is not implemented"
)
if bias.ndim != 1:
raise NotImplementedError(
f"`SparseSemiStructuredTensor` matmul: only bias dim=1 supported. Shape={bias.shape}"
)
if isinstance(A, torch.sparse.SparseSemiStructuredTensor):
raise NotImplementedError(
"`SparseSemiStructuredTensor` matmul: only operand B of `addmm` can be sparse"
)
B_t = B.t()
assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor)
row, _col = A.shape
A_padded = B_t._pad_dense_input(A)
result = B_t._mm(A_padded.t(), bias=bias).t()
return result[:row, :]
def semi_sparse_linear(func, types, args=(), kwargs=None) -> torch.Tensor:
assert len(args) in [2, 3]
A, B = args[:2]
bias = args[2] if len(args) == 3 else None
shape = A.shape
A_2d = A.view(-1, shape[-1])
if bias is None:
res = A_2d @ B.t()
else:
res = semi_sparse_addmm(
func=None,
types=None,
args=[bias, A_2d, B.t()],
)
return res.view(*shape[:-1], -1)
def semi_sparse_scaled_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
# pull all args, excluding use_fast_accum flag if set.
A, B, A_scale, B_scale, bias, scale_result, out_dtype = args[:7]
assert A.dtype == torch.float8_e4m3fn
assert B.dtype == torch.float8_e4m3fn
# only cuSPARSELt supports float8_e4m3fn currently
assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT)
assert A.packed is not None
# Currently we only support per-tensor scaling, with float32 scales
assert A_scale.numel() == 1 and B_scale.numel() == 1
assert A_scale.dtype == torch.float32 and B_scale.dtype == torch.float32
# cuSPARSELt lacks the A and B operand scaling support, so instead we use alpha to scale the result.
# Note that this limits us to per-tensor scalig only.
sparse_result = torch._cslt_sparse_mm(
A.packed,
B,
alpha=A_scale * B_scale,
out_dtype=out_dtype,
)
return sparse_result