Add sparse compressed fake tensor support (#120920)

As in the title: extend fake tensor support to tensors with sparse compressed layouts (CSR, CSC, BSR, and BSC).
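For context, a minimal sketch of the behavior this enables, mirroring the new TestSparseMeta.test_fake test below (the example tensor is illustrative): a sparse compressed tensor can now be converted into a FakeTensor that keeps its layout, shape, dtype, and device, while its plain indices and values become meta tensors with nnz == 0.

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Illustrative input: a 2x2 CSR tensor; CSC/BSR/BSC layouts are handled analogously.
t = torch.tensor([[1.0, 0.0], [0.0, 2.0]]).to_sparse_csr()

fake_mode = FakeTensorMode()
f = FakeTensor.from_tensor(t, fake_mode)  # fakify the sparse compressed tensor

assert isinstance(f, FakeTensor)
assert f.layout == torch.sparse_csr and f.shape == t.shape
# Unlike t.to(device="meta"), which preserves nnz, the fake tensor's plain
# indices and values have nnz == 0.
assert f.col_indices().shape[-1] == 0
assert f.values().shape[0] == 0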

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120920
Approved by: https://github.com/ezyang
Author: Pearu Peterson
Date: 2024-03-04 11:32:05 +02:00
Committed by: PyTorch MergeBot
Parent: c06499981d
Commit: ce2903080c

11 changed files with 148 additions and 7 deletions

View File

@@ -703,7 +703,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
             torch._embedding_bag_forward_only
         )
         for i, f in enumerate(funcs):
-            err_type = ValueError if i == 0 else RuntimeError
+            err_type = (ValueError, RuntimeError) if i == 0 else RuntimeError

             weight = torch.full((2, 6,), 0, dtype=torch.float64, device=device)
             indices = torch.full((2, 0, 0, 6, 6,), 2, dtype=torch.int64, device=device)

View File

@@ -8,7 +8,7 @@ from enum import Enum
 from torch.overrides import resolve_name
 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils import _pytree as pytree
-from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq
+from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
 import torch.utils._python_dispatch
 from torch._dispatch.python import enable_python_dispatcher
 from torch._ops import OpOverload, OpOverloadPacket
@@ -490,7 +490,9 @@ def verbose_print(e):
             return self.s

     def go(t):
-        if isinstance(t, torch.Tensor):
+        if is_sparse_any(t):
+            return t
+        elif isinstance(t, torch.Tensor):
             return Lit(f"{t} stride={t.stride()}")
         else:
             return t

View File

@@ -4379,6 +4379,56 @@ class TestSparseMeta(TestCase):
            self.maxDiff = orig_maxDiff
            raise

+    @all_sparse_layouts('layout', include_strided=False)
+    @parametrize("dtype", [torch.float64])
+    def test_fake(self, dtype, layout):
+        from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
+        fake_mode = FakeTensorMode()
+        index_dtype = torch.int64
+        device = 'cpu'
+        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
+            f = FakeTensor.from_tensor(t, fake_mode)
+            self.assertIsInstance(f, FakeTensor)
+            self.assertEqual(f.layout, layout)
+            self.assertEqual(f.shape, t.shape)
+            self.assertEqual(f.device, t.device)
+            if layout is torch.sparse_coo:
+                nnz = 0
+                indices = f._indices()
+                self.assertEqual(indices.dtype, index_dtype)
+                self.assertEqual(indices.device, t.device)
+                self.assertEqual(indices.shape, (*t._indices().shape[:-1], nnz))
+                values = f._values()
+                self.assertEqual(values.dtype, dtype)
+                self.assertEqual(values.device, t.device)
+                self.assertEqual(values.shape, (nnz, *t._values().shape[1:]))
+            else:
+                nnz = 0
+                if layout in {torch.sparse_csr, torch.sparse_bsr}:
+                    f_compressed_indices, f_plain_indices = f.crow_indices(), f.col_indices()
+                    compressed_indices, plain_indices = t.crow_indices(), t.col_indices()
+                else:
+                    f_compressed_indices, f_plain_indices = f.ccol_indices(), f.row_indices()
+                    compressed_indices, plain_indices = t.ccol_indices(), t.row_indices()
+                f_values = f.values()
+                values = t.values()
+                batch_dims = len(compressed_indices.shape) - 1
+                self.assertEqual(f_compressed_indices.layout, compressed_indices.layout)
+                self.assertEqual(f_compressed_indices.shape, compressed_indices.shape)
+                self.assertEqual(f_compressed_indices.dtype, compressed_indices.dtype)
+                self.assertEqual(f_compressed_indices.device, compressed_indices.device)
+
+                self.assertEqual(f_plain_indices.layout, plain_indices.layout)
+                self.assertEqual(f_plain_indices.shape, (*plain_indices.shape[:-1], nnz))
+                self.assertEqual(f_plain_indices.dtype, plain_indices.dtype)
+                self.assertEqual(f_plain_indices.device, plain_indices.device)
+
+                batch_dim = plain_indices.ndim - 1
+                self.assertEqual(f_values.layout, values.layout)
+                self.assertEqual(f_values.shape, (*values.shape[:batch_dim], nnz, *values.shape[batch_dim + 1:]))
+                self.assertEqual(f_values.dtype, values.dtype)
+                self.assertEqual(f_values.device, values.device)
+
 class _SparseDataset(torch.utils.data.Dataset):
     # An utility class used in TestSparseAny.test_dataloader method.

View File

@@ -27,6 +27,7 @@ from torch._guards import GuardSource, TracingContext
 from torch._ops import HigherOrderOperator
 from torch._streambase import _EventBase, _StreamBase
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
+from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx.experimental._backward_state import BackwardState
 from torch.fx.experimental.symbolic_shapes import (
     _constrain_range_for_size,
@@ -1059,6 +1060,11 @@ class VariableBuilder:
        ):
            unimplemented("torch.compile does not support strided NestedTensor")

+        if is_sparse_any(value):
+            unimplemented(
+                f"torch.compile does not support sparse Tensor with {value.layout} layout"
+            )
+
        tensor_variable = wrap_fx_proxy(
            tx=self.tx,
            proxy=tensor_proxy,

View File

@@ -25,6 +25,7 @@ import torch._numpy as tnp
 import torch.fx
 import torch.random
 from torch._dynamo import compiled_autograd
+from torch._subclasses.meta_utils import is_sparse_any

 from torch.fx.experimental.symbolic_shapes import (
     guard_scalar,
@@ -149,7 +150,11 @@ class TensorVariable(VariableTracker):
            "is_sparse": value.is_sparse,
            "class_type": type(value),
        }
-        if not has_free_symbols(value):
+        if is_sparse_any(value) and not has_free_symbols(value):
+            props["size"] = tuple(
+                [int(s) if is_symbolic(s) else s for s in value.size()]
+            )
+        elif not has_free_symbols(value):
            # this is a fully static shape, and the keys on props here inform specialization.
            # We have to cast to int here, because these might get accessed as ConstantVariable, which has
            # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant

View File

@@ -23,6 +23,7 @@ from torch._subclasses.meta_utils import (
     assert_eq,
     assert_metadata_eq,
     is_sparse_any,
+    is_sparse_compressed,
     MetaConverter,
 )
 from torch._utils import render_call
@@ -1057,6 +1058,8 @@ class FakeTensorMode(TorchDispatchMode):
                    raise _BypassDispatchCache("constant attribute")
                if arg.is_sparse:
                    raise _BypassDispatchCache("sparse tensor")
+                if is_sparse_compressed(arg):
+                    raise _BypassDispatchCache("sparse compressed tensor")
                result.append(extract_tensor_metadata(arg))
            elif isinstance(arg, torch.Tensor):
                raise _BypassDispatchCache("non-fake tensor")
@@ -1099,6 +1102,9 @@ class FakeTensorMode(TorchDispatchMode):
        if output.is_sparse:
            raise _BypassDispatchCache("sparse output")

+        if is_sparse_compressed(output):
+            raise _BypassDispatchCache("sparse compressed output")
+
        # Can an in-place op really reference a kwarg? If so, then we need
        # to extend the implementation to handle it.
        for kval in kwargs.values():

View File

@@ -152,7 +152,7 @@ class MetaConverter:
        # hold a weak ref to self, otherwise it will be kept alive
        # by the del_ten closure
        self_weak_ref = weakref.ref(self)
-        if t.is_sparse or t.is_mkldnn or is_functorch_wrapped_tensor(t):
+        if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t):
            weak_st = None
        else:
            weak_st = StorageWeakRef(t._typed_storage())
@@ -298,6 +298,10 @@ class MetaConverter:
            with torch.inference_mode(t.is_inference()):
                if t.is_sparse:
                    is_leaf = safe_is_leaf(t)
+
+                    # The lambda function below is similar to
+                    # `t.to(device='meta')` except the latter
+                    # preserves nnz value
                    r = callback(
                        lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
                            t.sparse_dim(),
@@ -321,6 +325,64 @@ class MetaConverter:
                        with torch.enable_grad():
                            r = r.clone()
                    r._coalesced_(t.is_coalesced())
+                elif is_sparse_compressed(t):
+                    is_leaf = safe_is_leaf(t)
+
+                    def mk_meta():
+                        nnz = 0
+                        batch_dim = t.ndim - t.sparse_dim() - t.dense_dim()
+                        batch_size = t.shape[:batch_dim]
+                        if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
+                            index_dtype = t.crow_indices().dtype
+                            compressed_indices = torch.empty(
+                                t.crow_indices().shape, device="meta", dtype=index_dtype
+                            )
+                            plain_indices = torch.empty(
+                                (*t.col_indices().shape[:-1], nnz),
+                                device="meta",
+                                dtype=index_dtype,
+                            )
+                        else:
+                            index_dtype = t.ccol_indices().dtype
+                            compressed_indices = torch.empty(
+                                t.ccol_indices().shape, device="meta", dtype=index_dtype
+                            )
+                            plain_indices = torch.empty(
+                                (*t.row_indices().shape[:-1], nnz),
+                                device="meta",
+                                dtype=index_dtype,
+                            )
+                        values_shape = t.values().shape
+                        values = torch.empty(
+                            (
+                                *values_shape[:batch_dim],
+                                nnz,
+                                *values_shape[batch_dim + 1 :],
+                            ),
+                            dtype=t.dtype,
+                            device="meta",
+                        )
+                        return torch.ops.aten.sparse_compressed_tensor(
+                            compressed_indices,
+                            plain_indices,
+                            values,
+                            t.shape,
+                            layout=t.layout,
+                            dtype=t.dtype,
+                            device="meta",
+                        )
+
+                    # `mk_meta()` is similar to `t.to(device='meta')`
+                    # except `to('meta')` preserves nnz value while
+                    # `mk_meta` result has nnz == 0.
+                    r = callback(mk_meta)
+
+                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
+                    if t.requires_grad:
+                        r.requires_grad = True
+                    if t.requires_grad and not is_leaf:
+                        with torch.enable_grad():
+                            r = r.clone()
                elif t.is_nested and not is_traceable_wrapper_subclass(t):
                    # TODO: Handle this better in Dynamo?
                    # There are checks there now, but this can still be triggered by a dense
@@ -679,8 +741,6 @@ class MetaConverter:
        if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t):
            if t.device.type != "xla" and any(
                [
-                    t.is_sparse_csr,
-                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                    t.is_quantized,
                    t._is_view() and t._base is not None and t._base.is_sparse,
                    torch._is_functional_tensor(t),

View File

@@ -280,6 +280,8 @@ def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
        yield val
    elif isinstance(val, (int, float, bool)):
        pass
+    elif is_sparse_any(val):
+        yield from _iterate_exprs(val.size())
    elif isinstance(val, torch.Tensor):
        yield from _iterate_exprs(val.size())
        yield from _iterate_exprs(val.stride())

View File

@@ -11793,6 +11793,12 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
                # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
                DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+               # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
            )),
    OpInfo('sparse.mm',
           dtypes=floating_types_and(torch.bfloat16),
@@ -11836,6 +11842,10 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
                # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
                DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
+               # NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
            )),
    UnaryUfuncInfo('i0',
                   ref=np_unary_ufunc_integer_promotion_wrapper(