Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)

commit ce2903080c (parent c06499981d), committed by PyTorch MergeBot

Add sparse compressed fake tensor support (#120920)

As in the title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120920
Approved by: https://github.com/ezyang
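For context, a minimal sketch (not part of the commit) of what the change enables: a sparse compressed tensor can now be converted to a FakeTensor, with its metadata mirrored onto the meta device and nnz reported as 0, as the new TestSparseMeta.test_fake test below verifies. The entry point is the same one the test uses; the input tensor here is invented for illustration.

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

t = torch.tensor([[1.0, 0.0], [0.0, 2.0]]).to_sparse_csr()   # illustrative 2x2 CSR tensor

fake_mode = FakeTensorMode()
f = FakeTensor.from_tensor(t, fake_mode)

print(type(f).__name__)        # FakeTensor
print(f.layout, f.shape)       # torch.sparse_csr torch.Size([2, 2])
print(f.crow_indices().shape)  # torch.Size([3]) -- compressed indices keep their shape
print(f.col_indices().shape)   # torch.Size([0]) -- fake sparse tensors report nnz == 0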
@@ -703,7 +703,7 @@ class TestEmbeddingNNDeviceType(NNTestCase):
             torch._embedding_bag_forward_only
         )
         for i, f in enumerate(funcs):
-            err_type = ValueError if i == 0 else RuntimeError
+            err_type = (ValueError, RuntimeError) if i == 0 else RuntimeError

             weight = torch.full((2, 6,), 0, dtype=torch.float64, device=device)
             indices = torch.full((2, 0, 0, 6, 6,), 2, dtype=torch.int64, device=device)
@@ -8,7 +8,7 @@ from enum import Enum
 from torch.overrides import resolve_name
 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils import _pytree as pytree
-from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq
+from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
 import torch.utils._python_dispatch
 from torch._dispatch.python import enable_python_dispatcher
 from torch._ops import OpOverload, OpOverloadPacket
@@ -490,7 +490,9 @@ def verbose_print(e):
         return self.s

     def go(t):
-        if isinstance(t, torch.Tensor):
+        if is_sparse_any(t):
+            return t
+        elif isinstance(t, torch.Tensor):
             return Lit(f"{t} stride={t.stride()}")
         else:
             return t
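The is_sparse_any(t) check above keeps verbose_print from asking a sparse tensor for its strides (sparse layouts have none). The helper itself is not shown in this diff; a rough, illustrative stand-in for what such a predicate checks:

import torch

_SPARSE_LAYOUTS = {
    torch.sparse_coo,                    # COO
    torch.sparse_csr, torch.sparse_csc,  # compressed row / column
    torch.sparse_bsr, torch.sparse_bsc,  # blocked variants
}

def sketch_is_sparse_any(t):
    # Stand-in for torch._subclasses.meta_utils.is_sparse_any; the real helper may differ.
    return isinstance(t, torch.Tensor) and t.layout in _SPARSE_LAYOUTS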
@@ -4379,6 +4379,56 @@ class TestSparseMeta(TestCase):
                 self.maxDiff = orig_maxDiff
                 raise

+    @all_sparse_layouts('layout', include_strided=False)
+    @parametrize("dtype", [torch.float64])
+    def test_fake(self, dtype, layout):
+        from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
+        fake_mode = FakeTensorMode()
+        index_dtype = torch.int64
+        device = 'cpu'
+        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
+            f = FakeTensor.from_tensor(t, fake_mode)
+            self.assertIsInstance(f, FakeTensor)
+            self.assertEqual(f.layout, layout)
+            self.assertEqual(f.shape, t.shape)
+            self.assertEqual(f.device, t.device)
+            if layout is torch.sparse_coo:
+                nnz = 0
+                indices = f._indices()
+                self.assertEqual(indices.dtype, index_dtype)
+                self.assertEqual(indices.device, t.device)
+                self.assertEqual(indices.shape, (*t._indices().shape[:-1], nnz))
+                values = f._values()
+                self.assertEqual(values.dtype, dtype)
+                self.assertEqual(values.device, t.device)
+                self.assertEqual(values.shape, (nnz, *t._values().shape[1:]))
+            else:
+                nnz = 0
+                if layout in {torch.sparse_csr, torch.sparse_bsr}:
+                    f_compressed_indices, f_plain_indices = f.crow_indices(), f.col_indices()
+                    compressed_indices, plain_indices = t.crow_indices(), t.col_indices()
+                else:
+                    f_compressed_indices, f_plain_indices = f.ccol_indices(), f.row_indices()
+                    compressed_indices, plain_indices = t.ccol_indices(), t.row_indices()
+                f_values = f.values()
+                values = t.values()
+                batch_dims = len(compressed_indices.shape) - 1
+                self.assertEqual(f_compressed_indices.layout, compressed_indices.layout)
+                self.assertEqual(f_compressed_indices.shape, compressed_indices.shape)
+                self.assertEqual(f_compressed_indices.dtype, compressed_indices.dtype)
+                self.assertEqual(f_compressed_indices.device, compressed_indices.device)
+
+                self.assertEqual(f_plain_indices.layout, plain_indices.layout)
+                self.assertEqual(f_plain_indices.shape, (*plain_indices.shape[:-1], nnz))
+                self.assertEqual(f_plain_indices.dtype, plain_indices.dtype)
+                self.assertEqual(f_plain_indices.device, plain_indices.device)
+
+                batch_dim = plain_indices.ndim - 1
+                self.assertEqual(f_values.layout, values.layout)
+                self.assertEqual(f_values.shape, (*values.shape[:batch_dim], nnz, *values.shape[batch_dim + 1:]))
+                self.assertEqual(f_values.dtype, values.dtype)
+                self.assertEqual(f_values.device, values.device)
+

 class _SparseDataset(torch.utils.data.Dataset):
     # An utility class used in TestSparseAny.test_dataloader method.
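The assertions above encode the fake-tensor convention for sparse compressed layouts: compressed indices keep their full shape, while the trailing nnz dimension of the plain indices and of the values becomes 0. A small worked example of that bookkeeping, with shapes invented for illustration:

# Hypothetical real-tensor metadata: batch size (2,), a 3x4 CSR matrix per batch entry, nnz = 5.
compressed_indices_shape = (2, 4)   # (batch, nrows + 1): kept as-is on the fake tensor
plain_indices_shape = (2, 5)        # (batch, nnz)
values_shape = (2, 5)               # (batch, nnz) -- plain CSR has no block or dense dims

nnz = 0                             # fake/meta sparse tensors report zero stored elements
batch_dim = len(plain_indices_shape) - 1

fake_plain_indices_shape = (*plain_indices_shape[:-1], nnz)
fake_values_shape = (*values_shape[:batch_dim], nnz, *values_shape[batch_dim + 1:])
print(fake_plain_indices_shape, fake_values_shape)   # (2, 0) (2, 0)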
@@ -27,6 +27,7 @@ from torch._guards import GuardSource, TracingContext
 from torch._ops import HigherOrderOperator
 from torch._streambase import _EventBase, _StreamBase
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
+from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx.experimental._backward_state import BackwardState
 from torch.fx.experimental.symbolic_shapes import (
     _constrain_range_for_size,
@@ -1059,6 +1060,11 @@ class VariableBuilder:
             ):
                 unimplemented("torch.compile does not support strided NestedTensor")

+            if is_sparse_any(value):
+                unimplemented(
+                    f"torch.compile does not support sparse Tensor with {value.layout} layout"
+                )
+
             tensor_variable = wrap_fx_proxy(
                 tx=self.tx,
                 proxy=tensor_proxy,
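A sketch of the user-visible effect of the new check (behavior inferred from the unimplemented(...) call above, not from separate documentation): sparse tensors that reach torch.compile as inputs cause a graph break rather than being traced, and with fullgraph=True that is expected to surface as an error carrying the message above.

import torch

@torch.compile(fullgraph=True)
def f(x):
    return x.sum()

x = torch.eye(3).to_sparse_csr()
try:
    f(x)
except Exception as e:
    # On builds that include this change, the error is expected to mention
    # "torch.compile does not support sparse Tensor with torch.sparse_csr layout".
    print(type(e).__name__, e)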
@@ -25,6 +25,7 @@ import torch._numpy as tnp
 import torch.fx
 import torch.random
 from torch._dynamo import compiled_autograd
+from torch._subclasses.meta_utils import is_sparse_any

 from torch.fx.experimental.symbolic_shapes import (
     guard_scalar,
@@ -149,7 +150,11 @@ class TensorVariable(VariableTracker):
             "is_sparse": value.is_sparse,
             "class_type": type(value),
         }
-        if not has_free_symbols(value):
+        if is_sparse_any(value) and not has_free_symbols(value):
+            props["size"] = tuple(
+                [int(s) if is_symbolic(s) else s for s in value.size()]
+            )
+        elif not has_free_symbols(value):
             # this is a fully static shape, and the keys on props here inform specialization.
             # We have to cast to int here, because these might get accessed as ConstantVariable, which has
             # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
@@ -23,6 +23,7 @@ from torch._subclasses.meta_utils import (
     assert_eq,
     assert_metadata_eq,
     is_sparse_any,
+    is_sparse_compressed,
     MetaConverter,
 )
 from torch._utils import render_call
@@ -1057,6 +1058,8 @@ class FakeTensorMode(TorchDispatchMode):
                 raise _BypassDispatchCache("constant attribute")
             if arg.is_sparse:
                 raise _BypassDispatchCache("sparse tensor")
+            if is_sparse_compressed(arg):
+                raise _BypassDispatchCache("sparse compressed tensor")
             result.append(extract_tensor_metadata(arg))
         elif isinstance(arg, torch.Tensor):
             raise _BypassDispatchCache("non-fake tensor")
@@ -1099,6 +1102,9 @@ class FakeTensorMode(TorchDispatchMode):
         if output.is_sparse:
             raise _BypassDispatchCache("sparse output")

+        if is_sparse_compressed(output):
+            raise _BypassDispatchCache("sparse compressed output")
+
         # Can an in-place op really reference a kwarg? If so, then we need
         # to extend the implementation to handle it.
         for kval in kwargs.values():
@@ -152,7 +152,7 @@ class MetaConverter:
         # hold a weak ref to self, otherwise it will be kept alive
         # by the del_ten closure
         self_weak_ref = weakref.ref(self)
-        if t.is_sparse or t.is_mkldnn or is_functorch_wrapped_tensor(t):
+        if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t):
             weak_st = None
         else:
             weak_st = StorageWeakRef(t._typed_storage())
@@ -298,6 +298,10 @@ class MetaConverter:
             with torch.inference_mode(t.is_inference()):
                 if t.is_sparse:
                     is_leaf = safe_is_leaf(t)
+
+                    # The lambda function below is similar to
+                    # `t.to(device='meta')` except the latter
+                    # preserves nnz value
                     r = callback(
                         lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
                             t.sparse_dim(),
@@ -321,6 +325,64 @@ class MetaConverter:
                         with torch.enable_grad():
                             r = r.clone()
                     r._coalesced_(t.is_coalesced())
+                elif is_sparse_compressed(t):
+                    is_leaf = safe_is_leaf(t)
+
+                    def mk_meta():
+                        nnz = 0
+                        batch_dim = t.ndim - t.sparse_dim() - t.dense_dim()
+                        batch_size = t.shape[:batch_dim]
+                        if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
+                            index_dtype = t.crow_indices().dtype
+                            compressed_indices = torch.empty(
+                                t.crow_indices().shape, device="meta", dtype=index_dtype
+                            )
+                            plain_indices = torch.empty(
+                                (*t.col_indices().shape[:-1], nnz),
+                                device="meta",
+                                dtype=index_dtype,
+                            )
+                        else:
+                            index_dtype = t.ccol_indices().dtype
+                            compressed_indices = torch.empty(
+                                t.ccol_indices().shape, device="meta", dtype=index_dtype
+                            )
+                            plain_indices = torch.empty(
+                                (*t.row_indices().shape[:-1], nnz),
+                                device="meta",
+                                dtype=index_dtype,
+                            )
+                        values_shape = t.values().shape
+                        values = torch.empty(
+                            (
+                                *values_shape[:batch_dim],
+                                nnz,
+                                *values_shape[batch_dim + 1 :],
+                            ),
+                            dtype=t.dtype,
+                            device="meta",
+                        )
+                        return torch.ops.aten.sparse_compressed_tensor(
+                            compressed_indices,
+                            plain_indices,
+                            values,
+                            t.shape,
+                            layout=t.layout,
+                            dtype=t.dtype,
+                            device="meta",
+                        )
+
+                    # `mk_meta()` is similar to `t.to(device='meta')`
+                    # except `to('meta')` preserves nnz value while
+                    # `mk_meta` result has nnz == 0.
+                    r = callback(mk_meta)
+
+                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
+                    if t.requires_grad:
+                        r.requires_grad = True
+                    if t.requires_grad and not is_leaf:
+                        with torch.enable_grad():
+                            r = r.clone()
                 elif t.is_nested and not is_traceable_wrapper_subclass(t):
                     # TODO: Handle this better in Dynamo?
                     # There are checks there now, but this can still be triggered by a dense
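The requires_grad/clone handling at the end of the new branch mirrors what the sparse COO path already does: it reproduces "requires grad but is not a leaf" on the meta result when the real tensor was a non-leaf. An illustrative snippet of why clone-under-enable_grad achieves that (dense tensor used for brevity; not code from the PR):

import torch

r = torch.empty(2, 2, device="meta")
r.requires_grad = True             # a leaf that requires grad
print(r.is_leaf)                   # True

with torch.enable_grad():
    r = r.clone()                  # clone() records an autograd edge
print(r.requires_grad, r.is_leaf)  # True False -- now a non-leaf, matching the real tensor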
@@ -679,8 +741,6 @@ class MetaConverter:
         if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t):
             if t.device.type != "xla" and any(
                 [
-                    t.is_sparse_csr,
-                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                     t.is_quantized,
                     t._is_view() and t._base is not None and t._base.is_sparse,
                     torch._is_functional_tensor(t),
@@ -280,6 +280,8 @@ def _iterate_exprs(val: Union[SymInt, torch.Tensor]) -> Iterable[sympy.Basic]:
         yield val
     elif isinstance(val, (int, float, bool)):
         pass
+    elif is_sparse_any(val):
+        yield from _iterate_exprs(val.size())
    elif isinstance(val, torch.Tensor):
        yield from _iterate_exprs(val.size())
        yield from _iterate_exprs(val.stride())
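Only size() is iterated for sparse tensors because sparse layouts carry no strides; asking for them raises. A quick illustration (not from the PR):

import torch

t = torch.eye(3).to_sparse_csr()
print(t.size())          # torch.Size([3, 3])
try:
    t.stride()
except RuntimeError as e:
    print("sparse layouts have no strides:", e)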
@@ -11793,6 +11793,12 @@ op_db: List[OpInfo] = [
                DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
                # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
                DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+               # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
            )),
    OpInfo('sparse.mm',
           dtypes=floating_types_and(torch.bfloat16),
|
|||||||
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
|
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
|
||||||
# ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
|
# ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
|
||||||
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
|
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
|
||||||
|
# NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend
|
||||||
|
DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
|
||||||
|
DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
|
||||||
|
DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
|
||||||
)),
|
)),
|
||||||
UnaryUfuncInfo('i0',
|
UnaryUfuncInfo('i0',
|
||||||
ref=np_unary_ufunc_integer_promotion_wrapper(
|
ref=np_unary_ufunc_integer_promotion_wrapper(
|
||||||
|