mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adds support for SymInts in the FakeTensor cache. A couple notes: 1. When a SymInt is present in the input key for a FakeTensor operation we cache on the ShapeEnv instead of using the FakeTensorMode cache. This is necessary so we don't have to remember and check the guards. It reduces the cache hits but there's diminishing return on how much work we can do before the cache becomes more of a burden than a gain. 2. We need to be careful that when we cache an output SymInt that is a direct copy from the input that when we have a cache-hit we copy the SymNode from the input to the output. This is important because the fx-graph building code actually uses SymNode ids in the process of building the graph so constructing a same-content-but-different-id SymNode will fail. 3. In the cache key we store SymInts as a _PySymInputStub. These represent SymInt (and friends) but support `__hash__` and `__eq__` (which SymInt do not). 4. In the cache entry we store SymInts as a _SymIntOutputStub. Perf example: ``` python benchmarks/dynamo/timm_models.py --ci --accuracy --timing --explain --inductor --dynamic-shapes --dynamic-batch-only --device cuda --training --amp --total-partitions 2 --partition-id 0 --output /tmp/training_timm_models.csv --filter crossvit_9_240 ``` fake tensor cache before: ``` INFO: FakeTensor cache stats: INFO: cache_hits: 68137 INFO: cache_misses: 837 INFO: cache_bypasses: INFO: symbolic shape: 48224 INFO: CompositeImplicitAutograd: 917 INFO: non-fake tensor: 70 INFO: non-FakeTensor output: 62 INFO: non-builtin: 8 INFO: dynamic output shape: 1 ``` and after: ``` INFO: FakeTensor cache stats: INFO: cache_hits: 88187 INFO: cache_misses: 14233 INFO: cache_bypasses: INFO: CompositeImplicitAutograd: 1037 INFO: non-FakeTensor output: 602 INFO: non-fake tensor: 70 INFO: unsafe view: 36 INFO: non-builtin: 8 INFO: dynamic output shape: 1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/127596 Approved by: https://github.com/eellison ghstack 
dependencies: #131014, #129780
1665 lines
74 KiB
Python
1665 lines
74 KiB
Python
# mypy: allow-untyped-defs
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
|
|
import dataclasses
|
|
import warnings
|
|
import weakref
|
|
from dataclasses import dataclass
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
ClassVar,
|
|
ContextManager,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Tuple,
|
|
Type,
|
|
TYPE_CHECKING,
|
|
Union,
|
|
)
|
|
from typing_extensions import TypeAlias
|
|
|
|
import torch
|
|
from torch._C._autograd import CreationMeta
|
|
from torch._C._functorch import (
|
|
_add_batch_dim,
|
|
_unwrap_functional_tensor,
|
|
_wrap_functional_tensor,
|
|
get_unwrapped,
|
|
is_batchedtensor,
|
|
is_functorch_wrapped_tensor,
|
|
is_gradtrackingtensor,
|
|
is_legacy_batchedtensor,
|
|
maybe_get_bdim,
|
|
maybe_get_level,
|
|
peek_interpreter_stack,
|
|
)
|
|
from torch._logging import trace_structured
|
|
from torch.utils._mode_utils import no_dispatch
|
|
|
|
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
|
from torch.utils.weak import WeakIdKeyDictionary
|
|
|
|
if TYPE_CHECKING:
|
|
from torch._C._functorch import CInterpreter
|
|
from torch._guards import Source
|
|
|
|
# Import here to avoid cycle
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
|
|
# Import the following modules during type checking to enable code intelligence features,
|
|
# Do not import unconditionally, as they import sympy and importing sympy is very slow
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
|
|
|
|
# Alias of typing.List.
# NOTE(review): presumably kept as a public name for code that annotates lists
# of dimensions; confirm there are no importers before removing it.
DimList = List
|
|
|
|
|
|
def safe_is_leaf(t):
    """Return ``t.is_leaf``, treating a RuntimeError from the query as False.

    Querying ``is_leaf`` can raise (e.g. under inference mode); in that case
    the tensor is reported as a non-leaf instead of propagating the error.
    """
    try:
        leaf = t.is_leaf
    except RuntimeError:
        # inference mode can trigger this
        return False
    return leaf
|
|
|
|
|
|
def safe_grad(t):
    """Return ``t.grad`` while suppressing the non-leaf ``.grad`` access warning.

    Only warnings whose message starts with "The .grad attribute of a Tensor"
    are silenced, and only for the duration of this attribute read.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        grad = t.grad
    return grad
|
|
|
|
|
|
def assert_eq(a, b):
    """Raise AssertionError("<a> != <b>") unless ``a == b``.

    Written as the documented expansion of ``assert a == b, msg``, so it is
    likewise a no-op when Python runs with -O.
    """
    if __debug__:
        if not (a == b):
            raise AssertionError(f"{a} != {b}")
|
|
|
|
|
|
def assert_metadata_eq(
    assert_eq,
    m1: Union[MetaTensorDesc, torch.Tensor],
    m2: torch.Tensor,
    *,
    skip_symbolic=False,
    skip_leaf=False,
):
    """Recursively compare the metadata recorded for ``m1`` against tensor ``m2``.

    Args:
        assert_eq: comparison callback, invoked as
            ``assert_eq(value_from_m1, value_from_m2)`` for each property, in a
            fixed order.
        m1: a MetaTensorDesc, or a tensor (which is first described with a
            fresh MetaTensorDescriber).
        m2: the tensor to compare against.
        skip_symbolic: skip the shape, stride and storage_offset comparisons.
        skip_leaf: skip the is_leaf comparison.

    Gradients and view bases are compared recursively. Returns None.
    """
    if isinstance(m1, torch.Tensor):
        m1 = MetaTensorDescriber().describe_tensor(m1)

    def check(desc, real):
        assert_eq(desc.dtype, real.dtype)
        if not skip_symbolic:
            assert_eq(desc.shape, real.shape)
        assert_eq(desc.requires_grad, real.requires_grad)
        if not skip_leaf:
            assert_eq(desc.is_leaf, real.is_leaf)
        # grad_fn is not compared: MetaTensorDesc does not store it (it is
        # inferred from is_leaf).
        assert_eq(desc.is_sparse, real.is_sparse)
        assert_eq(desc.is_inference, real.is_inference())
        assert_eq(desc.is_conj, real.is_conj())
        assert_eq(desc.is_neg, real.is_neg())
        assert_eq(desc.grad is not None, safe_grad(real) is not None)
        if desc.grad is not None:
            check(desc.grad, safe_grad(real))
        if desc.is_sparse:
            assert_eq(desc.dense_dim, real.dense_dim())
            assert_eq(desc.sparse_dim, real.sparse_dim())
            assert_eq(desc.is_coalesced, real.is_coalesced())
            return
        if not skip_symbolic:
            assert_eq(desc.stride, real.stride())
            assert_eq(desc.storage_offset, real.storage_offset())
        assert_eq(desc.is_view, real._is_view())
        if desc.is_view:
            check(desc.base, real._base)
        # TODO: test if is resizable (no direct query for this atm)
        # TODO: audit AutogradMeta to see if it matches
        # TODO: test forward AD

    return check(m1, m2)
|
|
|
|
|
|
def is_sparse_coo(t):
    """Return True iff ``t`` is a torch.Tensor whose layout is sparse COO."""
    if not isinstance(t, torch.Tensor):
        return False
    return t.layout is torch.sparse_coo
|
|
|
|
|
|
def is_sparse_compressed_layout(layout):
    """Return True iff ``layout`` is one of the compressed sparse layouts (CSR, CSC, BSR, BSC)."""
    compressed_layouts = {
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    }
    return layout in compressed_layouts
|
|
|
|
|
|
def is_sparse_compressed(t):
    """Return True iff ``t`` is a torch.Tensor with a compressed sparse layout."""
    if not isinstance(t, torch.Tensor):
        return False
    return is_sparse_compressed_layout(t.layout)
|
|
|
|
|
|
def is_sparse_any(t):
    """Return True iff ``t`` is a sparse tensor of any kind (COO or compressed)."""
    if is_sparse_coo(t):
        return True
    return is_sparse_compressed(t)
|
|
|
|
|
|
# Don't use id() directly, because those can get reallocated over time.
# Instead, MetaTensorDescriber hands out sequential integer IDs for the
# tensors/storages it has seen (see its get_tensor_id / get_storage_id).
MetaStorageId: TypeAlias = int
MetaTensorId: TypeAlias = int


# Module-level counter giving each MetaTensorDescriber instance a distinct
# ``id``; read and incremented in MetaTensorDescriber.__init__.
DESCRIBER_NEXT_ID = 0
|
|
|
|
|
|
class MetaTensorDescriber:
    """
    Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc
    for it, which is enough information to reconstruct a meta tensor/fake tensor
    corresponding to a Tensor as faithfully as possible.

    This is a stateful conversion object because we keep track of the IDs
    of the tensors/storages passed to us, so we can consistently give
    the same ID when we see the same tensor/storage.

    IDs are small sequential integers handed out per describer instance
    (see get_tensor_id / get_storage_id), remembered in WeakIdKeyDictionary
    tables keyed on the tensor/storage objects. Each describer also takes its
    own ``id`` from the module-level DESCRIBER_NEXT_ID counter; that id is
    passed as ``describer_id`` when descs are serialized for tracing.

    Args:
        copy_data: if True, the produced descs also carry the original
            tensor/storage in their ``data`` field; otherwise ``data`` is None.
    """

    def __init__(self, *, copy_data=False):
        # Take a distinct id for this describer from the module-level counter.
        global DESCRIBER_NEXT_ID
        self.id = DESCRIBER_NEXT_ID
        DESCRIBER_NEXT_ID += 1
        # Next sequential ids to hand out for newly seen tensors / storages.
        self.next_tensor_id: MetaTensorId = 0
        self.next_storage_id: MetaStorageId = 0
        # Tensor -> int
        self.lookup_tensor = WeakIdKeyDictionary()
        # Storage -> int
        self.lookup_storage = WeakIdKeyDictionary()
        self.copy_data = copy_data
        # IDs already emitted through trace_structured, so each tensor/storage
        # is traced at most once per describer.
        self.traced_tensors = set()
        self.traced_storages = set()

    def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId:
        """Return the id for ``t``, assigning the next sequential id the first time ``t`` is seen."""
        if t not in self.lookup_tensor:
            self.lookup_tensor[t] = self.next_tensor_id
            self.next_tensor_id += 1
        return self.lookup_tensor[t]

    def get_storage_id(self, s: torch.UntypedStorage) -> MetaStorageId:
        """Return the id for ``s``, assigning the next sequential id the first time ``s`` is seen."""
        if s not in self.lookup_storage:
            self.lookup_storage[s] = self.next_storage_id
            self.next_storage_id += 1
        return self.lookup_storage[s]

    def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False) -> MetaStorageDesc:
        """Describe untyped storage ``s`` as a MetaStorageDesc.

        The desc records the storage's id and ``s.size()``; ``data`` is ``s``
        itself only when ``copy_data`` is set. When ``trace`` is True the
        desc is emitted (once per id) via trace_structured as a
        "describe_storage" event.
        """
        r = MetaStorageDesc(
            id=self.get_storage_id(s),
            size=s.size(),
            # NB: We don't do the copy yet; copy happens when we start
            # creating the new storages
            data=s if self.copy_data else None,
        )
        if trace and r.id not in self.traced_storages:
            trace_structured(
                "describe_storage",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_storages.add(r.id)
        return r

    def describe_tensor(
        self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
    ) -> MetaTensorDesc:
        """Describe tensor ``t`` as a MetaTensorDesc.

        Args:
            t: the tensor to describe.
            recurse: when False, the sparse-compressed component tensors
                (crow/col/ccol/row indices, values) and the view ``base`` are
                not described and are left as None. Those component tensors
                are themselves always described with ``recurse=False``; grad,
                unwrapped and subclass attrs are described with the default
                ``recurse=True`` regardless.
            trace: when True, emit each newly described tensor (and its
                storage and any recursively described tensors) once via
                trace_structured.

        NB: this is not side-effect free. Besides assigning ids, it may call
        ``torch._sync`` on functional tensors and ``_unwrap_functional_tensor``
        (flagged "has side effects!" below). The order of the recursive
        describe_tensor calls determines the order in which new tensor ids are
        handed out, so do not reorder them casually.

        Raises:
            RuntimeError: if ``t`` is a functional tensor (not on an xla/lazy
                device and not a functorch batched/grad-tracking tensor) that
                is also a view.
        """
        is_leaf = safe_is_leaf(t)
        is_view = t._is_view()
        # NOTE(review): t.is_sparse presumably covers only the COO layout,
        # which is why compressed layouts are checked separately below.
        is_sparse = t.is_sparse
        layout = t.layout
        is_nested = t.is_nested
        is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t)
        is_functorch_wrapped = is_functorch_wrapped_tensor(t)
        is_mkldnn = t.is_mkldnn
        is_batchedtensor_v = is_batchedtensor(t)
        is_legacy_batchedtensor_v = is_legacy_batchedtensor(t)
        is_gradtrackingtensor_v = is_gradtrackingtensor(t)
        # NOTE(review): not referenced anywhere else in this method.
        is_functorch_batched_or_grad = is_batchedtensor_v or is_gradtrackingtensor_v
        is_functional = torch._is_functional_tensor(t)

        storage = None
        # NB: For compatibility, I default this to zero, as sometimes people
        # still have stuffed zero into storage offset even though the tensor
        # doesn't meaningfully have an offset
        storage_offset = 0
        # The storage and the real storage offset are only queried for tensors
        # outside the categories listed in this condition.
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
            or is_mkldnn
            # TODO: TBH, functorch wrapped tensors probably should have
            # storage associated with them
            or is_functorch_wrapped
            or is_legacy_batchedtensor_v
        ):
            # NB: We actually don't use storage to do views, but might as well
            # put it in for accuracy
            storage = self.describe_storage(t.untyped_storage(), trace=trace)
            storage_offset = t.storage_offset()  # type: ignore[assignment]

        # Strides stay None for sparse (COO or compressed) layouts and for
        # nested tensors that are not traceable wrapper subclasses.
        stride = None
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
        ):
            # stride/storage_offset are called from is_functorch_wrapped,
            # view_from_base, empty_create_subclass,
            # sym_sizes_strides_storage_offset (empty_create)
            stride = t.stride()

        # NB: this technically should refer to functorch unwrapped tensor, but
        # I am (perhaps abusively) using it to store both the functorch and
        # non-functorch functional tensor
        unwrapped = None
        autograd_meta_from = None
        current_level = None
        if is_batchedtensor_v or is_gradtrackingtensor_v:
            unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace)
        # xla and lazy tensors present as functional tensors, but we want them
        # to be handled specially
        elif is_functional and t.device.type not in ("xla", "lazy"):
            if t._is_view():
                raise RuntimeError(
                    "Cannot safely fakify a view because this process drops the view information right now."
                )
            if not is_functorch_wrapped:
                # Non-functorch functional tensor: call torch._sync on it
                # (presumably flushing pending functionalization updates),
                # describe the inner tensor, and record t itself as
                # autograd_meta_from.
                torch._sync(t)
                unwrapped = self.describe_tensor(
                    torch._from_functional_tensor(t), trace=trace
                )
                autograd_meta_from = t
            else:
                reapply_views = torch._C._functionalization_reapply_views_tls()
                # NB: has side effects!
                unwrapped = self.describe_tensor(
                    _unwrap_functional_tensor(t, reapply_views), trace=trace
                )
                # TODO: It's pretty suspicious that functional tensors don't have
                # valid level and thus we just grab whatever the current level
                # is
                current_level = torch._C._functorch.current_level()

        maybe_functorch_stack = None
        if is_functorch_wrapped:
            # NOTE(review): entering and immediately leaving this context
            # manager captures the current functorch interpreter stack into
            # maybe_functorch_stack (presumably it yields the stack it clears
            # and restores it on exit) -- confirm in torch._functorch.pyfunctorch.
            with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack:
                pass

        attrs = None
        ctx = None
        type_v = None
        if is_traceable_wrapper_subclass_v:
            # Describe every inner tensor named by __tensor_flatten__, and keep
            # the flatten context plus the concrete subclass type.
            assert hasattr(t, "__tensor_flatten__")
            raw_attrs, ctx = t.__tensor_flatten__()
            attrs = {
                attr: self.describe_tensor(getattr(t, attr), trace=trace)
                for attr in raw_attrs
            }
            type_v = type(t)

        # TODO: Is it important to enable torch.inference_mode before querying
        # these values?
        # NB: keyword arguments are evaluated left to right, so t's id is
        # looked up / assigned here (after the unwrapped/attrs descriptions
        # above) and before the recursive describe_tensor calls for
        # crow_indices, col_indices, ccol_indices, row_indices, values, grad
        # and base below.
        r = MetaTensorDesc(
            id=self.get_tensor_id(t),
            storage=storage,
            is_inference=t.is_inference(),
            is_leaf=is_leaf,
            requires_grad=t.requires_grad,
            # NB: ndim should be OK too but there is a disaster at
            # python test/dynamo/test_subclasses.py -k test_user_overidden_property_unsupported
            # Actually, this means that we have a little bit of a problem
            # here, which is that there is some sensitivity to how exactly an
            # access is done if you have a __torch_function__ subclass. Maybe
            # should disable torch function before doing accesses?
            ndim=t.dim(),
            dtype=t.dtype,
            is_sparse=is_sparse,
            is_mkldnn=is_mkldnn,
            is_functorch_wrapped=is_functorch_wrapped,
            is_batchedtensor=is_batchedtensor_v,
            is_legacy_batchedtensor=is_legacy_batchedtensor_v,
            is_gradtrackingtensor=is_gradtrackingtensor_v,
            is_view=is_view,
            is_conj=t.is_conj(),
            is_neg=t.is_neg(),
            is_parameter=isinstance(t, torch.nn.Parameter),
            is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v,
            is_nested=is_nested,
            is_functional=is_functional,
            layout=layout,
            device=t.device,
            size=t.size(),
            stride=stride,
            storage_offset=storage_offset,
            dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
            sparse_dim=(
                t.sparse_dim() if t.is_sparse or is_sparse_compressed(t) else None
            ),
            dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None,
            is_coalesced=t.is_coalesced() if t.is_sparse else None,
            # TODO: I actually think recursing here is correct, but we have at
            # least an infinite cycle from base -> values -> base
            # https://github.com/pytorch/pytorch/issues/122089
            crow_indices=(
                self.describe_tensor(t.crow_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            col_indices=(
                self.describe_tensor(t.col_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            ccol_indices=(
                self.describe_tensor(t.ccol_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            row_indices=(
                self.describe_tensor(t.row_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            values=(
                self.describe_tensor(t.values(), recurse=False, trace=trace)
                if recurse and is_sparse_compressed(t)
                else None
            ),
            grad=(
                self.describe_tensor(safe_grad(t), trace=trace)
                if safe_grad(t) is not None
                else None
            ),
            creation_meta=(
                torch._C._autograd._get_creation_meta(t) if t._is_view() else None
            ),
            unwrapped=unwrapped,
            level=(
                maybe_get_level(t)
                if is_batchedtensor_v or is_gradtrackingtensor_v
                else None
            ),
            bdim=maybe_get_bdim(t) if is_batchedtensor_v else None,
            base=(
                self.describe_tensor(t._base, trace=trace)
                if recurse and t._is_view() and t._base is not None
                else None
            ),
            fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t),
            view_func=t._view_func_unsafe,
            attrs=attrs,
            ctx=ctx,
            type=type_v,
            # NB: even if functorch is enabled, don't actually save the
            # interpreter stack here unless we are actually functorch wrapped;
            # it's irrelevant for non-functorch stuff
            functorch_stack=maybe_functorch_stack,
            autograd_meta_from=autograd_meta_from,
            current_level=current_level,
            data=t if self.copy_data else None,
        )
        # Emit each described tensor at most once per describer when tracing.
        if trace and r.id not in self.traced_tensors:
            trace_structured(
                "describe_tensor",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_tensors.add(r.id)
        return r
|
|
|
|
|
|
@dataclass(frozen=True)
class MetaStorageDesc:
    """Immutable description of an UntypedStorage, produced by MetaTensorDescriber.

    ``size`` holds whatever ``UntypedStorage.size()`` returned; as_json renders
    any non-int size with repr().
    """

    id: MetaStorageId
    size: int
    # NB: this is only populated with copy_data True, it is not directly
    # serializable in JSON, you want to do something special here anyway
    data: Optional[torch.UntypedStorage]

    def as_json(self, describer_id):
        """Return a JSON-friendly dict with this storage's id, ``describer_id`` and its size."""
        size_value = self.size
        if not isinstance(size_value, int):
            size_value = repr(size_value)
        return {"id": self.id, "describer_id": describer_id, "size": size_value}
|
|
|
|
|
|
@dataclass(frozen=True)
class MetaTensorDesc:
    """Immutable description of a tensor, produced by MetaTensorDescriber.

    Fields up to ``grad`` are meant to be serializable through as_json (nested
    descs are emitted as their numeric ids). The fields named in
    ``_UNSERIALIZABLE`` are not; as_json only emits their repr(), or None for
    ``data`` / ``autograd_meta_from``.
    """

    id: MetaTensorId
    ndim: int
    dtype: torch.dtype
    device: torch.device

    # NB: Sometimes, size, stride and storage_offset contain SymInt, in which
    # case this is NOT serializable. That only happens when you're
    # re-fakeifying a fake tensor with an existing ShapeEnv... maybe we
    # can get rid of this use case entirely. Notably, even if we are
    # fakeifying a real tensor into a fake tensor with symbolic shapes, the
    # size here is NOT dynamic
    # NB: These also contain SymInt because wrap_meta_outputs_with_default_device_logic
    # goes through this codepath. But it really should not LOL.
    # NB: size could potentially be None as you can override it and make it
    # throw an error, but we don't currently have any subclasses that do this
    # except C++ nested tensor but we're going to have nested int to make this
    # defined on NJT
    size: Tuple[int, ...]
    dynamo_dynamic_indices: List[int]

    layout: torch.layout = torch.strided
    is_inference: bool = False
    is_leaf: bool = False
    requires_grad: bool = False
    is_sparse: bool = False
    is_mkldnn: bool = False
    is_functorch_wrapped: bool = False
    is_batchedtensor: bool = False
    is_legacy_batchedtensor: bool = False
    is_gradtrackingtensor: bool = False
    is_view: bool = False
    is_nested: bool = False
    is_traceable_wrapper_subclass: bool = False
    is_functional: bool = False
    is_conj: bool = False
    is_neg: bool = False
    is_parameter: bool = False
    stride: Optional[Tuple[int, ...]] = None
    storage_offset: int = 0
    # NB: We have a choice whether or not to store the id or a direct pointer
    # to the data structure. For ease of use, we store the data structure,
    # but this means that when we serialize, we have to swizzle these pointers
    # back into ids (so we have accurate aliasing relationships)
    storage: Optional[MetaStorageDesc] = None
    sparse_dim: Optional[int] = None  # is_sparse, is_sparse_compressed
    dense_dim: Optional[int] = None  # is_sparse, is_sparse_compressed
    is_coalesced: Optional[bool] = None  # is_sparse
    crow_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    col_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    ccol_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    row_indices: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    values: Optional[MetaTensorDesc] = None  # is_sparse_compressed
    unwrapped: Optional[MetaTensorDesc] = None  # is_functorch_wrapped
    bdim: Optional[int] = None  # is_functorch_wrapped
    base: Optional[MetaTensorDesc] = None  # is_view
    attrs: Optional[Dict[str, MetaTensorDesc]] = None  # is_traceable_wrapper_subclass
    creation_meta: Optional[CreationMeta] = None
    grad: Optional[MetaTensorDesc] = None

    # Everything below is NOT serializable, need some more work

    _UNSERIALIZABLE: ClassVar[List[str]] = [
        "ctx",
        "type",
        "fake_mode",
        "view_func",
        "level",
        "current_level",
        "functorch_stack",
        "autograd_meta_from",
        "data",
    ]

    ctx: Optional[object] = None  # is_traceable_wrapper_subclass
    type: Optional[Type] = None  # is_traceable_wrapper_subclass
    fake_mode: Optional[FakeTensorMode] = None
    view_func: Optional[
        Callable[
            [
                torch.Tensor,
                Callable[[int], int],
                Callable[[torch.Tensor], torch.Tensor],
            ],
            torch.Tensor,
        ]
    ] = None
    # level looks serializable, but actually it is meaningless without
    # the functorch_stack below
    level: Optional[int] = None  # is_functorch_wrapped
    current_level: Optional[int] = None
    functorch_stack: Optional[List[CInterpreter]] = None
    autograd_meta_from: Optional[torch.Tensor] = None

    # This is only populated on copy_data, and typically is not used at all,
    # except for some of our meta-ification paths that don't properly use
    # storage (pro-tip: you should use storage)
    data: Optional[torch.Tensor] = None

    # Faithfully serializing functorch tensors will not be too difficult.
    # We only need to consider grad/vmap interpreters, and their internal
    # state is only bools (mostly what the grad enabled/disabled state
    # should be in the lower layer). Beyond that, tensors just need to
    # precisely indicate which particular interpreter they correspond
    # to (we then replace level with a pointer to the interpreter stack.)
    # However, this use of functorch is very "non-lexical" so it's not
    # entirely clear how to make it all lexical again, so we haven't done
    # it for now.

    # NB: This will reference numeric IDs, and it is assumed that you've
    # already serialized everything this recursively references
    def as_json(self, describer_id):
        """Return a JSON-friendly dict of this desc's non-default fields.

        A field is omitted when its value ``is`` the dataclass default, and
        ``dynamo_dynamic_indices`` is omitted when empty. Nested descs become
        their numeric ``id``; ``describer_id`` is appended last.
        """
        unserializable = set(MetaTensorDesc._UNSERIALIZABLE)

        def encode(key, value):
            # Best-effort debugging serialization for unserializable fields
            # (feel free to add other special cases as appropriate).
            if key in ("data", "autograd_meta_from"):
                return None  # never repr these
            if key in unserializable or isinstance(
                value, (torch.device, torch.dtype, torch.layout, torch.SymInt)
            ):
                return repr(value)
            if isinstance(value, (tuple, list)):
                return [encode(key, item) for item in value]
            if isinstance(value, (MetaStorageDesc, MetaTensorDesc)):
                return value.id
            if isinstance(value, CreationMeta):
                return str(value)
            if key == "attrs" and isinstance(value, dict):
                return {name: desc.id for name, desc in value.items()}
            return value

        out = {}
        for f in dataclasses.fields(self):
            value = getattr(self, f.name)
            if value is f.default:
                continue
            if f.name == "dynamo_dynamic_indices" and not value:
                continue
            out[f.name] = encode(f.name, value)
        out["describer_id"] = describer_id
        return out

    @property
    def shape(self):
        """Alias for ``size``, mirroring ``torch.Tensor.shape``."""
        return self.size
|
|
|
|
|
|
# A more faithful reproduction would do a copy on the entire
|
|
# storage, but this needs to be done carefully because the
|
|
# underlying storage could have larger extent than is implied
|
|
# by size/stride. The real fix is to properly call
|
|
# meta_storage recursively here.
|
|
#
|
|
# These "safe" functions are intended to be used under no_dispatch() mode.
|
|
# The no_dispatch() here is intended to prevent ambient fake tensor mode from
|
|
# fakeifying the operation. But if we are given an honest to goodness
|
|
# FakeTensor as src, we MUST NOT run the copy/clone operation. A better way
|
|
# to do this would be to not use no_dispatch and instead just disable fake
|
|
# tensor mode only (allowing for subclass dispatch to occur)
|
|
def _safe_copy(dst, src):
|
|
if type(src) is not torch.Tensor:
|
|
return
|
|
dst.copy_(src)
|
|
|
|
|
|
def _safe_clone(src):
|
|
if type(src) is not torch.Tensor:
|
|
return None
|
|
return src.clone()
|
|
|
|
|
|
# This is a class for converting multiple tensors into meta tensors which
|
|
# share the same view/storage structure. The operation model is you allocate
|
|
# one of these, and then call it repeatedly on all the tensors you want to
|
|
# convert. It's important to use the same object for tensors you want to
|
|
# share storage because this is how we correlate shared storages to the same
|
|
# meta storages. This class will hold weak references to cached tensors
|
|
# and tensor storages.
|
|
class MetaConverter:
|
|
def __init__(self, *, copy_data: bool = False):
    # Memo tables keyed by describer-assigned ids; values are held weakly:
    #   storage_memo: MetaStorageId -> UntypedStorage
    #   tensor_memo:  MetaTensorId -> torch.Tensor (typically a meta tensor
    #                 or FakeTensor)
    self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
    self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
    # Counters consulted by successful().
    self.hit = self.miss = 0
    self.del_hook = None
    self.arg_cnt = 0
    # Ensures real_storage/real_tensor are populated on the resulting
    # metaified storage/tensor. The naming of this attribute is load
    # bearing: FakeTensor relies on real tensor being set to exactly this
    # value.
    self.copy_data = copy_data
    # Hands out the stable ids used as keys in the memo tables above.
    self.describer = MetaTensorDescriber(copy_data=copy_data)
|
|
|
|
def successful(self):
    """Return True when ``hit`` is positive and ``miss`` is zero.

    Presumably this means every conversion attempted so far succeeded.
    """
    if self.miss != 0:
        return False
    return self.hit > 0
|
|
|
|
def get_tensor_memo(self, t: MetaTensorDesc):
    """Return the tensor memoized under ``t.id``, or None if there is none."""
    return self.tensor_memo.get(t.id)
|
|
|
|
def set_tensor_memo(self, t: MetaTensorDesc, v):
    """Memoize ``v`` as the converted tensor for description ``t`` (keyed by ``t.id``)."""
    memo = self.tensor_memo
    memo[t.id] = v
|
|
|
|
def get_storage_memo(self, s: MetaStorageDesc):
    """Return the storage memoized under ``s.id``, or None if there is none."""
    return self.storage_memo.get(s.id)
|
|
|
|
def set_storage_memo(self, s: MetaStorageDesc, v):
    """Memoize ``v`` as the converted storage for description ``s`` (keyed by ``s.id``)."""
    memo = self.storage_memo
    memo[s.id] = v
|
|
|
|
def meta_storage(self, s: MetaStorageDesc, callback):
    """Return the meta UntypedStorage corresponding to storage description ``s``.

    Results are memoized by ``s.id`` (via get_storage_memo/set_storage_memo),
    so repeated requests for the same storage return the same object. On a
    miss, ``callback`` is called with a zero-argument factory that builds a
    uint8 meta tensor of ``s.size`` elements, and that tensor's untyped
    storage is used. Sizing from ``s.size`` means that if we are fakeifying a
    tensor with a secretly-zero-sized storage, the meta storage gets that
    same size too.

    When ``self.copy_data`` is set, a clone of the real storage ``s.data`` is
    attached to the result as ``real_storage``.

    Args:
        s: description of the storage to convert.
        callback: called with a zero-argument factory; must return a tensor
            whose ``untyped_storage()`` becomes the meta storage.

    Returns:
        The memoized or newly created meta storage.
    """
    # Look the memo up exactly once and keep a strong reference to the result.
    # storage_memo is a WeakValueDictionary, so a second lookup after a
    # successful "is None" check could find the entry already collected and
    # return None.
    r_s = self.get_storage_memo(s)
    if r_s is not None:
        return r_s

    r_s = callback(
        lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
    ).untyped_storage()
    if self.copy_data:
        # NB: no_dispatch is needed because internally storage copy is
        # implemented as Tensor operations
        with torch.no_grad(), no_dispatch():
            assert s.data is not None
            r_s.real_storage = s.data.clone()
    self.set_storage_memo(s, r_s)
    return r_s
|
|
|
|
# This function assumes that it's possible to do the conversion
|
|
# NB: name here is used in a conventional way by Dynamo; it corresponds
|
|
# precisely to the Source.name() of the tensor we're fakeifying and
|
|
# corresponds to a valid Python expression. When we construct sub-names
|
|
# as part of this process, we will maintain this invariant! (Even though
|
|
# other users of this may not need this property to be upheld.)
|
|
def meta_tensor(
|
|
self,
|
|
t: MetaTensorDesc,
|
|
shape_env: Optional[ShapeEnv] = None,
|
|
callback=lambda t: t(),
|
|
source: Optional[Source] = None,
|
|
symbolic_context: Optional[SymbolicContext] = None,
|
|
):
|
|
if source is None:
|
|
from torch._dynamo.source import ConstantSource
|
|
|
|
# TODO: make a dedicated UnknownSource for this?
|
|
source = ConstantSource(
|
|
f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
|
|
)
|
|
|
|
# This indicates you set no_dispatch() before calling into this
|
|
# function. This is an error: we may be creating fake tensors and
|
|
# will perform operations on them which need fake tensor mode to
|
|
# be active. You will segfault if you are in a no_dispatch() block.
|
|
assert not torch._C._dispatch_tls_local_exclude_set().has(
|
|
torch._C.DispatchKey.Python
|
|
)
|
|
arg_cnt = self.arg_cnt
|
|
self.arg_cnt += 1
|
|
|
|
# When we make as_strided calls, we end up generating a guard
|
|
# that the new as_strided tensor is in bounds for the old storage
|
|
# for the base (since as_strided calls can "bust" out of their
|
|
# bounding box.) This guard is unnecessary: if a user is able
|
|
# to provide us a tensor with the view base setup this way, we
|
|
# don't need to produce a guard, because the fact that they
|
|
# were able to produce the view base means its in bounds.
|
|
#
|
|
# Now, ordinarily, this guard would be harmless. However, the
|
|
# generated guard refers to variables bound on the base variable.
|
|
# At the moment, Dynamo doesn't actually guard on x._base, because
|
|
# according to Voz this results in a lot of spurious invalidations,
|
|
# and also if the user doesn't directly make use of _base, its
|
|
# pointless anyway (because programs should be parametric over
|
|
# whether or not the input tensor is a view or not--unless you're
|
|
# mutating the input, but that's a whole 'nother ballgame). So
|
|
# for expediency, we suppress these guards so we don't have to
|
|
# deal with this (yet, anyway.)
|
|
#
|
|
# NB: An old version of this code suppressed guards for ALL operations
|
|
# happening during meta conversion, not just as_strided calls.
|
|
# This is too aggressive: we do duck sizing and 0/1 simplification
|
|
# as we allocate variables, and we do need to register guards for
|
|
# these cases.
|
|
maybe_suppress: Callable[[], Any] = contextlib.nullcontext
|
|
if shape_env is not None:
|
|
maybe_suppress = shape_env.suppress_guards
|
|
|
|
def sym_sizes_strides_storage_offset(
|
|
t: MetaTensorDesc, src, symbolic_context=symbolic_context
|
|
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
|
|
assert t.stride is not None
|
|
if shape_env is not None:
|
|
fake_mode = t.fake_mode
|
|
if fake_mode is not None and fake_mode.shape_env is shape_env:
|
|
# Don't reallocate the sizes; the shape envs are the same,
|
|
# so reuse the old sizes/strides/etc
|
|
return (t.size, t.stride, t.storage_offset)
|
|
else:
|
|
# TODO: deduplicate this
|
|
t_size = tuple(
|
|
shape_env._maybe_specialize_sym_int_with_hint(sz)
|
|
for sz in t.size
|
|
)
|
|
t_stride = tuple(
|
|
shape_env._maybe_specialize_sym_int_with_hint(sd)
|
|
for sd in t.stride
|
|
)
|
|
t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint(
|
|
t.storage_offset
|
|
)
|
|
return shape_env._create_symbolic_sizes_strides_storage_offset(
|
|
t_size,
|
|
t_stride,
|
|
t_storage_offset,
|
|
[d in t.dynamo_dynamic_indices for d in range(t.ndim)],
|
|
src,
|
|
symbolic_context=symbolic_context,
|
|
)
|
|
else:
|
|
return (t.size, t.stride, t.storage_offset)
|
|
|
|
def empty_create(
|
|
inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context
|
|
):
|
|
(
|
|
inner_sizes,
|
|
inner_strides,
|
|
inner_storage_offset,
|
|
) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
|
|
return torch.empty_strided(
|
|
inner_sizes,
|
|
inner_strides,
|
|
dtype=inner_t.dtype,
|
|
device="meta",
|
|
)
|
|
|
|
# Creates a subclass instance with empty inner tensors according to the specified
|
|
# symbolic context.
|
|
def empty_create_subclass(
    t: MetaTensorDesc,
    outer_size,
    outer_stride,
    symbolic_context=symbolic_context,
    callback=callback,
    source=source,
):
    """Create a traceable-wrapper-subclass instance with empty inner tensors.

    The subclass described by ``t`` is rebuilt recursively:

    * every plain (non-subclass) desc, i.e. one with ``attrs is None``, becomes
      a fresh meta tensor via ``callback(lambda: empty_create(...))``. When
      ``self.copy_data`` is set, a real tensor copy of ``t.data`` is also
      attached as ``real_tensor``;
    * every subclass level is reassembled with ``t.type.__tensor_unflatten__``
      from its rebuilt inner tensors, ``t.ctx`` and the outer size / stride.

    Args:
        t: description of the outer subclass tensor (``t.attrs`` and
            ``t.type`` must be set).
        outer_size: size for the outer wrapper; ``None`` means ``t.size``.
        outer_stride: stride for the outer wrapper; ``None`` means ``t.stride``.
        symbolic_context: ``None`` or a ``SubclassSymbolicContext``. Its
            ``inner_contexts[attr]`` is used for each inner tensor.
        callback: wraps construction of each plain inner meta tensor.
        source: source of ``t``. Inner tensors get ``AttrSource(source, attr)``.

    Returns:
        The reconstructed subclass instance.

    Raises:
        AssertionError: if the unflattened result's shape / stride differ from
            ``outer_size`` / ``outer_stride``.
    """
    from torch._dynamo.source import AttrSource
    from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext

    assert t.attrs is not None
    assert t.type is not None
    # NB: t.ctx could be None if the subclass in question has no
    # meaningful context

    # Note: transform_subclass will use __tensor_unflatten__ to generate
    # a fresh subclass wrapper with outer sizes / strides according to the
    # outer symbolic context (passed in to this function). Inner size / stride
    # / storage offset symbols are allocated according to the appropriate inner
    # symbolic contexts, after which the checks in transform_subclass() will
    # relate them to the outer metadata as possible.
    #
    # Morally, the code here is the same as transform_subclass, but it is
    # written from scratch so that it walks the MetaTensorDesc (t.attrs /
    # t.type / t.ctx) instead of a live subclass instance.
    outer_size = outer_size if outer_size is not None else t.size
    outer_stride = outer_stride if outer_stride is not None else t.stride

    assert symbolic_context is None or isinstance(
        symbolic_context, SubclassSymbolicContext
    )

    def _empty_create_subclass(
        t, outer_size, outer_stride, symbolic_context, callback, source
    ):
        # We are hitting plain meta_desc tensor so actually
        # create a tensor here.
        if t.attrs is None:
            r = callback(
                lambda: empty_create(
                    t,
                    source,
                    symbolic_context,
                )
            )
            if self.copy_data:
                with torch.no_grad(), no_dispatch():
                    r.real_tensor = torch.empty_strided(
                        t.size,
                        t.stride,
                        dtype=t.dtype,
                        device=t.device,
                    )
                    assert t.data is not None
                    _safe_copy(r.real_tensor, t.data)
            return r

        # Subclass level: rebuild each inner tensor (recursively, with its own
        # inner symbolic context and source), then re-wrap with the outer
        # metadata.
        inner_tensors = {}
        for attr, meta_tensor_desc in t.attrs.items():
            current_context = None
            if symbolic_context is not None:
                current_context = symbolic_context.inner_contexts[attr]

            current_source = AttrSource(source, attr)
            new_empty_tensor = _empty_create_subclass(
                meta_tensor_desc,
                meta_tensor_desc.size,
                meta_tensor_desc.stride,
                current_context,
                callback,
                current_source,
            )

            inner_tensors[attr] = new_empty_tensor

        return t.type.__tensor_unflatten__(
            inner_tensors, t.ctx, outer_size, outer_stride
        )

    sub = _empty_create_subclass(
        t, outer_size, outer_stride, symbolic_context, callback, source
    )

    # NB: Purposefully guard here to simplify the inner / outer symbols.
    # Using sym_eq() for symbolic comparison can result in an expression that's too
    # difficult to guard on, so we use == here.
    assert sub.shape == outer_size, (
        f"Expected return value from {t.type}.__tensor_unflatten__() to have "
        f"shape equal to {outer_size}, but got: {sub.shape}"
    )
    assert sub.stride() == outer_stride, (
        f"Expected return value from {t.type}.__tensor_unflatten__() to have "
        f"stride equal to {outer_stride}, but got: {sub.stride()}"
    )

    return sub
|
|
|
|
# Returns an all-dynamic symbolic context used for metafying the given tensor with
|
|
# fully dynamic dims. This is useful when fake-ifying intermediate tensors in
|
|
# closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
|
|
# don't want to over-specialize during view replay.
|
|
def all_dynamic_symbolic_context(
    t: MetaTensorDesc, source, shape_env, callback
):
    """Build a symbolic context that marks every dim of ``t`` as DYNAMIC.

    Views get a ``view_base_context`` built the same way for ``t.base``
    (source ``<source>._base``). Traceable wrapper subclasses get a
    ``SubclassSymbolicContext`` whose ``inner_contexts`` are built recursively
    for every entry of ``t.attrs``. Everything else gets a
    ``StatelessSymbolicContext``. No dim carries a constraint.

    ``shape_env`` and ``callback`` are only threaded through the recursion.
    """
    from torch._dynamo.source import AttrSource
    from torch.fx.experimental.symbolic_shapes import (
        DimDynamic,
        StatelessSymbolicContext,
        SubclassSymbolicContext,
    )

    # Describe the view base first so the result can reference it.
    base_ctx: Optional[SymbolicContext] = None
    if t.is_view:
        assert t.base is not None
        base_ctx = all_dynamic_symbolic_context(
            t.base, AttrSource(source, "_base"), shape_env, callback
        )

    rank = t.ndim
    every_dim_dynamic = [DimDynamic.DYNAMIC] * rank
    no_constraints = [None] * rank

    if not t.is_traceable_wrapper_subclass:
        return StatelessSymbolicContext(
            dynamic_sizes=every_dim_dynamic,
            constraint_sizes=no_constraints,
            view_base_context=base_ctx,
        )

    assert t.attrs is not None
    per_attr: Dict[str, SymbolicContext] = {}
    for name, inner_desc in t.attrs.items():
        assert isinstance(name, str)
        per_attr[name] = all_dynamic_symbolic_context(
            inner_desc, AttrSource(source, name), shape_env, callback
        )

    return SubclassSymbolicContext(
        dynamic_sizes=every_dim_dynamic,
        constraint_sizes=no_constraints,
        inner_contexts=per_attr,  # type: ignore[arg-type]
        tensor_source=source,
        view_base_context=base_ctx,
    )
|
|
|
|
# Returns a fake-ified version of an input view tensor t, given an already fake-ified
|
|
# base. At a high level, we want two things:
|
|
# 1. fake_t should have the same view relationship to the given fake base as the
|
|
# input t has to its _base.
|
|
# 2. fake_t should have symbolic sizes / strides / storage offset according to the
|
|
# appropriate symbolic context (i.e. from the automatic dynamic algorithm).
|
|
#
|
|
# We currently take different strategies across view types:
|
|
# * For dense -> dense views, accomplish both (1) and (2) simultaneously via an
|
|
# as_strided() call on the fake-ified base, passing symbolic metadata.
|
|
# * For views involving subclasses, perform view replay using view funcs to
|
|
# achieve (1). It's necessary for (2) to swap out any closed-over state in
|
|
# the view funcs with symbolicized SymInts and fake-ified tensors. Doing this
|
|
# avoids specialization (and thus over-eager simplification of symbols) that
|
|
# could occur during view replay on the fake-ified base.
|
|
#
|
|
# Examples:
|
|
# * t.unsqueeze(-1) with dense t is a dense -> dense view. It can be modeled
|
|
# with an as_strided() call on the fake base passing symbolic metadata.
|
|
# * sub.select(dim=0, index=3) is a subclass -> subclass view. The index arg
|
|
# is made symbolic to avoid invalid specialization and view replay is then
|
|
# done to reconstruct the view.
|
|
# * _nested_from_jagged(values, offsets) is a dense -> subclass view
|
|
# that returns a subclass instance from a dense values tensor. The offsets
|
|
# tensor is closed over in the view func, as it can be considered view metadata.
|
|
# First, the offsets tensor is fake-ified according to the inner symbolic
|
|
# context and with the correct relationship to the outer size / stride metadata.
|
|
# Then view replay is done, swapping in the fake offsets so the view replay output
|
|
# is fully fake with no invalid specialization.
|
|
def view_from_base(
    base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env
):
    """Recreate the view described by ``t`` on top of the already-fake ``base``.

    * Dense -> dense: a single ``base.as_strided(...)`` call (under
      ``maybe_suppress()``) with the symbolic size / stride / storage offset
      computed for ``t``.
    * Anything involving a traceable wrapper subclass: ``t.view_func`` is
      replayed on ``base``. The visitors swap closed-over ints for fresh
      symbols (ephemeral source) and closed-over tensors for fake tensors.
      The replayed result is then checked against ``t``'s symbolic metadata
      with ``torch._check(sym_eq(...))``.
    """
    # Symbolic metadata for t, according to the outer symbolic context.
    out_size, out_stride, out_offset = sym_sizes_strides_storage_offset(t, source)

    dense_to_dense = not (
        t.is_traceable_wrapper_subclass or is_traceable_wrapper_subclass(base)
    )
    if dense_to_dense:
        # TODO: Change this logic to use view replay for consistency?
        # It's likely there is no view func available.
        with maybe_suppress():
            return base.as_strided(out_size, out_stride, out_offset)

    from torch._dynamo.source import EphemeralSource
    from torch.fx.experimental.symbolic_shapes import (
        StatelessSymbolicContext,
        sym_eq,
    )

    def symint_visitor_fn(s):
        # Replace a closed-over int with a fresh symbol unless nothing is dynamic.
        from torch.fx.experimental.symbolic_shapes import DimDynamic

        if shape_env is None:
            return s
        # A shape env alone does not mean dynamism (dynamo always creates
        # one), so a fully static context must be detected explicitly.
        if isinstance(symbolic_context, StatelessSymbolicContext) and all(
            dim is DimDynamic.STATIC for dim in symbolic_context.dynamic_sizes
        ):
            return s

        # NB: The symbol here is expected to be simplified out because we a priori
        # allocate inner and outer symbols according to the appropriate symbolic
        # contexts and prefer those over this symbol during symbol simplification
        # (via usage of EphemeralSource below). This -shouldn't- happen, but if
        # this symbol somehow leaks out beyond the view tensor's shape metadata, our
        # assumption of it being simplified out will fail and it may be guarded on,
        # which will hard error.
        ephemeral = EphemeralSource("symint_visitor_fn")
        fresh_symbol = shape_env.create_symbol(s, ephemeral, positive=None)
        return shape_env.create_symintnode(fresh_symbol, hint=s, source=ephemeral)

    # Maps describer ids of t's inner tensors to their fake counterparts.
    fake_inner_by_id = {}
    if t.is_traceable_wrapper_subclass:
        assert t.attrs is not None
        # NB: t.ctx could be None if the subclass in question has no
        # meaningful context
        assert t.type is not None

        # Fake-ify t naively first, only to obtain fake inner tensors that have
        # the correct relationship to the outer sizes / strides. That is hard to
        # do while visiting tensors one at a time during view replay. E.g. for a
        # Dense -> NJT view the offsets component is closed over and must be
        # fake-ified consistently with the output view.
        naive_fake = empty_create_subclass(
            t, outer_size=out_size, outer_stride=out_stride
        )
        inner_names, _ = naive_fake.__tensor_flatten__()
        for name in inner_names:
            fake_inner = getattr(naive_fake, name)
            fake_inner_by_id[t.attrs[name].id] = fake_inner

    def tensor_visitor_fn(
        visited_t: torch.Tensor,
        # These arguments are never passed, we just use them to close
        # over these relevant values
        shape_env=shape_env,
        callback=callback,
    ):
        # An undefined closed-over tensor (e.g. NJT's lengths) stays undefined.
        if visited_t is None:
            return None

        # NB: visited_t being a Tensor here is very naughty! Should
        # have already been described

        # Inner tensors of a subclass view were fake-ified above.
        already_fake = fake_inner_by_id.get(
            self.describer.get_tensor_id(visited_t), None
        )
        if already_fake is not None:
            return already_fake

        desc = self.describer.describe_tensor(visited_t)

        # Other closed-over tensor state is fake-ified as fully dynamic with an
        # ephemeral source, to avoid invalid specialization during view replay.
        # If ephemeral sources turn out not to prevent guards on these symbols,
        # guards may need explicit suppression (as for _base in the dense ->
        # dense case).
        ephemeral = EphemeralSource("tensor_visitor_fn")
        fully_dynamic = all_dynamic_symbolic_context(
            desc, ephemeral, shape_env, callback
        )
        return self.meta_tensor(
            desc,
            shape_env,
            callback,
            source=ephemeral,
            symbolic_context=fully_dynamic,
        )

    # Replay the view with symbolic ints / fake tensors swapped in.
    assert t.view_func is not None
    # NB: we do NOT suppress guards here, we need to remove ephemeral
    # sources
    replayed = t.view_func(base, symint_visitor_fn, tensor_visitor_fn)

    # Tie the replayed metadata to the outer symbolic context; this should
    # simplify away symbols created for closed-over view func SymInts.
    torch._check(sym_eq(replayed.size(), out_size))
    torch._check(sym_eq(replayed.stride(), out_stride))
    torch._check(sym_eq(replayed.storage_offset(), out_offset))
    return replayed
|
|
|
|
if self.get_tensor_memo(t) is None:
|
|
GRAD_TENSOR_SENTINEL_VALUE = -2
|
|
|
|
with torch.inference_mode(t.is_inference):
|
|
if t.is_sparse:
|
|
is_leaf = t.is_leaf
|
|
|
|
# The lambda function below is similar to
|
|
# `t.to(device='meta')` except the latter
|
|
# preserves nnz value
|
|
r = callback(
|
|
lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
|
|
t.sparse_dim,
|
|
t.dense_dim,
|
|
t.size,
|
|
dtype=t.dtype,
|
|
layout=torch.sparse_coo,
|
|
device="meta",
|
|
)
|
|
)
|
|
if self.copy_data:
|
|
# Pray that sparse clone doesn't lose information
|
|
assert t.data is not None
|
|
with torch.no_grad(), no_dispatch():
|
|
r.real_tensor = _safe_clone(t.data)
|
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
|
# Note [is_coalesced is dispatched]
|
|
# Strangely enough, is_coalesced() is a dispatched operator,
|
|
# which means that it will get caught by fake tensor mode.
|
|
# Ordinarily this would error, but there's some logic in
|
|
# fake tensor ensure this doesn't happen.
|
|
r._coalesced_(t.is_coalesced)
|
|
if t.requires_grad:
|
|
r.requires_grad = True
|
|
if t.requires_grad and not is_leaf:
|
|
# This should probably use DelayedError,
|
|
# but clone is fine for now for sparse tensors.
|
|
# (DelayedError does not work for sparse because it causes
|
|
# the Fake sparse tensor to "lose" its fakeness)
|
|
r = r.clone()
|
|
with torch.enable_grad():
|
|
r._coalesced_(t.is_coalesced)
|
|
elif is_sparse_compressed_layout(t.layout):
|
|
is_leaf = t.is_leaf
|
|
|
|
if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
|
assert t.sparse_dim is not None
|
|
assert t.dense_dim is not None
|
|
assert t.values is not None
|
|
batch_dim = t.ndim - t.sparse_dim - t.dense_dim
|
|
blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
|
|
else:
|
|
blocksize = ()
|
|
if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
|
assert t.crow_indices is not None
|
|
index_dtype = t.crow_indices.dtype
|
|
else:
|
|
assert t.ccol_indices is not None
|
|
index_dtype = t.ccol_indices.dtype
|
|
|
|
r = callback(
|
|
lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
|
|
0,
|
|
t.dense_dim,
|
|
t.shape,
|
|
blocksize,
|
|
index_dtype,
|
|
layout=t.layout,
|
|
dtype=t.dtype,
|
|
device="meta",
|
|
)
|
|
)
|
|
if self.copy_data:
|
|
# Pray sparse clone doesn't lose information
|
|
assert t.data is not None
|
|
with torch.no_grad(), no_dispatch():
|
|
r.real_tensor = _safe_clone(t.data)
|
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
|
if t.requires_grad:
|
|
r.requires_grad = True
|
|
if t.requires_grad and not is_leaf:
|
|
r = torch._C._functions.DelayedError(
|
|
"Internal error: Tried to backward() through example input",
|
|
1,
|
|
)(r)
|
|
elif t.is_nested and not t.is_traceable_wrapper_subclass:
|
|
# TODO: Handle this better in Dynamo?
|
|
# There are checks there now, but this can still be triggered by a dense
|
|
# tensor graph input that is a view of a strided NT.
|
|
from torch._dynamo.exc import unimplemented
|
|
|
|
unimplemented(
|
|
"strided nested tensors are not supported by meta conversion"
|
|
)
|
|
elif t.is_mkldnn:
|
|
is_leaf = t.is_leaf
|
|
sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
|
|
t, source
|
|
)
|
|
# TODO: This doesn't seem right, where's the MKLDNN'ness
|
|
# lol
|
|
r = callback(
|
|
lambda: torch.empty_strided(
|
|
sizes, strides, dtype=t.dtype, device="meta"
|
|
)
|
|
)
|
|
if self.copy_data:
|
|
with torch.no_grad(), no_dispatch():
|
|
assert t.size is not None
|
|
assert t.stride is not None
|
|
r.real_tensor = torch.empty_strided(
|
|
t.size, t.stride, dtype=t.dtype, device=t.device
|
|
)
|
|
assert t.data is not None
|
|
_safe_copy(r.real_tensor, t.data)
|
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
|
if t.requires_grad:
|
|
r.requires_grad = True
|
|
if t.requires_grad and not is_leaf:
|
|
r = torch._C._functions.DelayedError(
|
|
"Internal error: Tried to backward() through example input",
|
|
1,
|
|
)(r)
|
|
elif t.is_functorch_wrapped:
|
|
if t.is_view:
|
|
from torch._dynamo.exc import unimplemented
|
|
|
|
unimplemented(
|
|
"view functorch tensors are not supported by meta conversion"
|
|
)
|
|
|
|
# Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
|
|
# in a FakeTensor
|
|
def _to_fake_tensor(t: MetaTensorDesc):
    """Recursively rebuild a functorch-wrapped tensor description as a fake tensor.

    * BatchedTensor descs: fake-ify ``t.unwrapped``, then re-add the batch dim
      (``t.bdim``) at ``t.level`` with ``t.functorch_stack`` restored.
    * GradTrackingTensor descs: fake-ify ``t.unwrapped`` with functorch
      disabled, then re-wrap with ``_wrap_for_grad`` at ``t.level``. The
      unwrapped fake is returned as-is when the level equals
      ``GRAD_TENSOR_SENTINEL_VALUE``. ``requires_grad`` / fake autograd
      history are then mirrored from ``t``.
    * Functional descs: ``self.meta_tensor`` on ``t.unwrapped`` (reusing this
      call's source / symbolic context), then ``_wrap_functional_tensor`` at
      ``t.current_level``.
    * Anything else: an empty meta tensor from ``callback`` using ``t.size`` /
      ``t.stride`` directly, plus a real copy of ``t.data`` when
      ``self.copy_data`` is set.
    """
    # TODO: why aren't the recursive calls going to
    # meta_tensor
    if t.is_batchedtensor:
        assert t.unwrapped is not None
        assert t.level is not None
        assert t.bdim is not None
        ft = _to_fake_tensor(t.unwrapped)
        lvl = t.level
        bdim = t.bdim
        # You cannot create functorch tensors without
        # having the ambient functorch interpreter stack
        # available, as the level refers to things in the
        # stack
        with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
            t.functorch_stack
        ):
            r = _add_batch_dim(ft, bdim, lvl)
    elif t.is_gradtrackingtensor:
        assert t.unwrapped is not None
        assert t.level is not None
        disable_functorch = torch._C._DisableFuncTorch
        # Build the inner fake tensor with functorch disabled.
        with disable_functorch():
            ft = _to_fake_tensor(t.unwrapped)
        lvl = t.level
        if lvl == GRAD_TENSOR_SENTINEL_VALUE:
            # Sentinel level: no grad wrapper is recreated.
            r = ft
        else:
            with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
                t.functorch_stack
            ):
                r = torch._C._functorch._wrap_for_grad(ft, lvl)

        # Mirror requires_grad; for non-leaf inputs, attach a fake autograd
        # history that errors if anyone backprops through it.
        is_leaf = t.is_leaf
        if t.requires_grad and safe_is_leaf(r):
            r.requires_grad = True
        elif t.requires_grad and not is_leaf:
            r = torch._C._functions.DelayedError(  # type: ignore[assignment]
                "Internal error: Tried to backward() through example input",
                1,
            )(
                r  # type: ignore[arg-type]
            )
    elif t.is_functional:
        assert t.unwrapped is not None
        assert t.current_level is not None
        ft = self.meta_tensor(
            t.unwrapped,
            shape_env=shape_env,
            callback=callback,
            # NB: reuse these exactly, we treat the
            # functional tensor as "invisible".
            # TODO: Actually this all probably doesn't
            # work, take a closer look.
            source=source,
            symbolic_context=symbolic_context,
        )
        r = _wrap_functional_tensor(ft, t.current_level)
        # TODO: is_leaf/requires_grad?
    else:
        assert t.stride is not None

        # NOTE(review): uses the concrete t.size / t.stride rather than
        # sym_sizes_strides_storage_offset, so this leaf is not symbolicized
        # here -- confirm that is intended.
        sizes = t.size
        strides = t.stride
        r = callback(
            lambda: torch.empty_strided(
                sizes,
                strides,
                dtype=t.dtype,
                device="meta",
            )
        )
        if self.copy_data:
            with torch.no_grad(), no_dispatch():
                r.real_tensor = torch.empty_strided(  # type: ignore[attr-defined]
                    t.size,
                    t.stride,
                    dtype=t.dtype,
                    device=t.device,
                )
                assert t.data is not None
                _safe_copy(r.real_tensor, t.data)  # type: ignore[attr-defined]
    return r
|
|
|
|
r = _to_fake_tensor(t)
|
|
|
|
elif t.is_functional and t.device.type not in ["xla", "lazy"]:
|
|
assert t.unwrapped is not None
|
|
assert not t.is_functorch_wrapped # handled above
|
|
unwrapped = self.meta_tensor(
|
|
t.unwrapped,
|
|
shape_env=shape_env,
|
|
callback=callback,
|
|
source=source,
|
|
symbolic_context=symbolic_context,
|
|
)
|
|
r = torch._to_functional_tensor(unwrapped)
|
|
torch._mirror_autograd_meta_to(t.autograd_meta_from, r) # type: ignore[attr-defined]
|
|
|
|
elif t.is_view:
|
|
# Construct views in two steps: recursively meta-fy their
|
|
# base, and then create view(s) off that. NB: doing it
|
|
# directly from storage is WRONG because this won't cause
|
|
# version counters to get shared.
|
|
|
|
assert t.base is not None
|
|
|
|
base_symbolic_context = None
|
|
if shape_env and symbolic_context is not None:
|
|
from torch.fx.experimental.symbolic_shapes import (
|
|
StatelessSymbolicContext,
|
|
)
|
|
|
|
assert isinstance(symbolic_context, StatelessSymbolicContext)
|
|
# NB: This should generally be set when the input is a view,
|
|
# but the exception right now is for fake-ifying grads, which is
|
|
# a work in progress.
|
|
if symbolic_context.view_base_context is not None:
|
|
base_symbolic_context = symbolic_context.view_base_context
|
|
|
|
base = self.meta_tensor(
|
|
t.base,
|
|
shape_env,
|
|
callback,
|
|
source=torch._dynamo.source.AttrSource(source, "_base"),
|
|
symbolic_context=base_symbolic_context,
|
|
)
|
|
|
|
def is_c_of_r(complex_dtype, real_dtype):
    """Return True iff ``complex_dtype`` is complex and its real part dtype is ``real_dtype``."""
    if not utils.is_complex_dtype(complex_dtype):
        return False
    return utils.corresponding_real_dtype(complex_dtype) == real_dtype
|
|
|
|
# In some situations, MetaConverter may be called in a
|
|
# context where autograd is disabled. For the _is_view
|
|
# assert to pass, we have to setup the autograd view
|
|
# metadata anyway. Do this by reenabling the
|
|
# ADInplaceOrView key. This is kind of a hack.
|
|
old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
|
|
torch._C.DispatchKey.ADInplaceOrView
|
|
)
|
|
torch._C._dispatch_tls_set_dispatch_key_excluded(
|
|
torch._C.DispatchKey.ADInplaceOrView, False
|
|
)
|
|
try:
|
|
if base.dtype == t.dtype:
|
|
pass
|
|
elif is_c_of_r(base.dtype, t.dtype):
|
|
base = torch.view_as_real(base)
|
|
elif is_c_of_r(t.dtype, base.dtype):
|
|
base = torch.view_as_complex(base)
|
|
else:
|
|
# This is not guaranteed to succeed. If it fails, it
|
|
# means there is another dtype-converting view function
|
|
# that hasn't been handled here
|
|
base = base.view(t.dtype)
|
|
|
|
# This is very tricky. Naively, you might expect this
|
|
# to hold:
|
|
#
|
|
# if t.requires_grad and not safe_is_leaf(t)
|
|
# assert t._base.requires_grad
|
|
#
|
|
# But it's not true! As you can see in the following
|
|
# program:
|
|
#
|
|
# x = torch.zeros(4)
|
|
# y = x.view(1, 4)
|
|
# y.requires_grad = True
|
|
# z = y.view(1, 1, 4)
|
|
# assert z._base is x
|
|
#
|
|
# So we may have to do *two* views out of the base to
|
|
# recreate this situation.
|
|
if t.is_leaf:
|
|
# Leaf views that track view metadata are created by
|
|
# creating a view inside a no_grad block
|
|
with torch.no_grad():
|
|
r = view_from_base(base, t)
|
|
# As it's a leaf, we can directly assign requires_grad
|
|
r.requires_grad = t.requires_grad
|
|
else:
|
|
if t.base.requires_grad == t.requires_grad:
|
|
# Easy case, just run the view op
|
|
with torch.enable_grad():
|
|
r = view_from_base(base, t)
|
|
|
|
# NB: We don't actaully faithfully replicate
|
|
# autograd connectivity, but that doesn't matter
|
|
# today. See following for more info:
|
|
# https://gist.github.com/soulitzer/e03f015b314c3f5fcf80888c69390913
|
|
else:
|
|
# Obscure case. Create a leaf view and give it the
|
|
# correct requires_grad, then do the final view.
|
|
# NB: Can't have a non-leaf without requiring grad!
|
|
assert t.requires_grad
|
|
with torch.no_grad():
|
|
mid = base.view(base.shape)
|
|
mid.requires_grad = t.requires_grad
|
|
with torch.enable_grad():
|
|
r = view_from_base(mid, t)
|
|
# The CreationMeta influences whether or not inplace
|
|
# mutation is an error or not. So we need to make
|
|
# sure we properly propagate this as well.
|
|
assert t.creation_meta is not None
|
|
torch._C._autograd._set_creation_meta(r, t.creation_meta)
|
|
finally:
|
|
torch._C._dispatch_tls_set_dispatch_key_excluded(
|
|
torch._C.DispatchKey.ADInplaceOrView, old_exclude
|
|
)
|
|
|
|
else:
|
|
is_leaf = t.is_leaf
|
|
|
|
# Graph-Break for wrapped tensors
|
|
if (
|
|
not (t.is_batchedtensor or t.is_gradtrackingtensor)
|
|
and t.is_functorch_wrapped
|
|
) or t.is_legacy_batchedtensor:
|
|
return NotImplemented
|
|
|
|
(
|
|
sizes,
|
|
strides,
|
|
storage_offset,
|
|
) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
|
|
|
|
# If we have a subclass that desugars into dense tensors,
|
|
# perform our callback on each inner tensor.
|
|
if t.is_traceable_wrapper_subclass:
|
|
r = empty_create_subclass(
|
|
t, outer_size=sizes, outer_stride=strides
|
|
)
|
|
else:
|
|
r = callback(
|
|
lambda: torch.empty_strided(
|
|
sizes,
|
|
strides,
|
|
dtype=t.dtype,
|
|
device="meta",
|
|
)
|
|
)
|
|
if self.copy_data:
|
|
with torch.no_grad(), no_dispatch():
|
|
assert t.size is not None
|
|
assert t.stride is not None
|
|
r.real_tensor = torch.empty_strided(
|
|
t.size, t.stride, dtype=t.dtype, device=t.device
|
|
)
|
|
_safe_copy(r.real_tensor, t.data)
|
|
|
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
|
if t.requires_grad:
|
|
r.requires_grad = t.requires_grad
|
|
if not is_leaf:
|
|
# Fake up some autograd history.
|
|
# Note: we *used* to call .clone() here to mock up some autograd history.
|
|
# This is bad for subclasses.
|
|
# Consider the case where you have a wrapper subclass that is contiguous,
|
|
# but its inner tensor is noncontiguous().
|
|
# .clone() (or other ops) will have the side effect of changing
|
|
# the metadata of the inner tensor.
|
|
# So instead, we now have a dedicated fn to set autograd history,
|
|
# without inadvertently changing other metadata.
|
|
r = torch._C._functions.DelayedError(
|
|
"Internal error: Tried to backward() through example input",
|
|
1,
|
|
)(r)
|
|
|
|
s = t.storage
|
|
assert s is not None
|
|
if s.id not in self.storage_memo and (
|
|
r.is_nested
|
|
or (
|
|
r.stride() == strides
|
|
and r.storage_offset() == storage_offset
|
|
)
|
|
):
|
|
# You're normal and happy, install the fresh storage into the memo
|
|
self.set_storage_memo(s, r.untyped_storage())
|
|
if self.copy_data:
|
|
r.untyped_storage().real_storage = (
|
|
r.real_tensor.untyped_storage()
|
|
)
|
|
else:
|
|
# You're in crazy town; somehow you gave us a tensor
|
|
# that wasn't a view, but had nonzero storage offset,
|
|
# nontrivial strides (such that clone() couldn't
|
|
# preserve them), or already aliases with another
|
|
# tensor's storage. The most typical way to end
|
|
# up here is with set_. So use set_ to bludgeon this
|
|
# in.
|
|
r_s = self.meta_storage(s, callback=callback)
|
|
# NB: In principle, this should always work, but there
|
|
# is some subtle difference in the autograd metadata
|
|
# that means we will backprop the set_ call, even if
|
|
# r is declared as an input to grad.
|
|
# See https://github.com/pytorch/pytorch/issues/87956
|
|
# for the reproducer.
|
|
# NB: The in_kernel_invocation_manager here is necessary
|
|
# for fake tensor. If we run the set_ call with fake
|
|
# tensor on, r will improperly report that it is NOT a
|
|
# meta tensor but a cpu tensor, and then the set_ call
|
|
# will fail due to device mismatch. no_dispatch() is
|
|
# not enough, because the fake tensor will still claim
|
|
# to be a CPU tensor and you'll end up in the CPU
|
|
# kernel. Arguably this is a hack; a cleaner way to
|
|
# solve this is to have a FakeStorage concept which
|
|
# would report it's CPU device--no problem now! But
|
|
# this is difficult to do because we don't have storage
|
|
# subclasses. Relevant test is
|
|
# DynamicShapesFunctionTests::test_add_dynamic_shapes in
|
|
# test/dynamo/test_dynamic_shapes.py
|
|
maybe_fake_mgr: ContextManager[None] = contextlib.nullcontext()
|
|
from torch._subclasses.fake_tensor import (
|
|
in_kernel_invocation_manager,
|
|
maybe_get_fake_mode,
|
|
)
|
|
|
|
mb_fake_mode = maybe_get_fake_mode(r)
|
|
if mb_fake_mode is not None:
|
|
maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
|
|
with torch.no_grad(), maybe_suppress():
|
|
with maybe_fake_mgr:
|
|
r.set_(r_s, storage_offset, sizes, strides)
|
|
if self.copy_data:
|
|
with torch.no_grad(), no_dispatch():
|
|
r.real_tensor.set_(
|
|
r_s.real_storage,
|
|
t.storage_offset,
|
|
t.size,
|
|
t.stride,
|
|
)
|
|
|
|
if t.grad is not None:
|
|
from torch._dynamo.source import AttrSource
|
|
|
|
# TODO: Use a valid grad-specific symbolic context instead of recycling
|
|
# the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
|
|
r.grad = self.meta_tensor(
|
|
t.grad,
|
|
shape_env,
|
|
callback,
|
|
source=AttrSource(source, "grad"),
|
|
symbolic_context=symbolic_context,
|
|
)
|
|
torch._C._set_conj(r, t.is_conj)
|
|
torch._C._set_neg(r, t.is_neg)
|
|
# This can be skipped if necessary for performance reasons
|
|
skip_leaf = (
|
|
t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
|
|
)
|
|
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
|
|
# Thanks to storage resizing, it's possible to end up with a tensor
|
|
# that advertises a real size, but has a storage that actually has zero bytes.
|
|
# Need to reflect this in the generated FakeTensor.
|
|
if t.storage is not None and t.storage.size == 0:
|
|
r.untyped_storage().resize_(0)
|
|
|
|
if t.is_parameter:
|
|
r._is_param = True
|
|
|
|
self.set_tensor_memo(t, r)
|
|
|
|
return self.get_tensor_memo(t)
|
|
|
|
def __call__(
    self,
    t,
    shape_env=None,
    *,
    callback=lambda t: t(),
    source=None,
    symbolic_context=None,
    # Controls whether or not we should dump the tensor metadata to structured logs
    # when source is not None. Because we refakify after Dynamo is done,
    # we don't want to dump info again from AOTAutograd, it is redundant.
    trace=True,
):
    """Meta-fy ``t`` (or pass it through).

    Returns:
        * ``t`` unchanged if it is not tensor-like at all. This counts as
          neither a hit nor a miss.
        * ``NotImplemented`` (counted in ``self.miss``) for tensor-like
          non-``torch.Tensor`` objects, and for lazy tensors, quantized tensors
          and views whose ``_base`` is sparse.
        * Otherwise (counted in ``self.hit``) the result of
          ``self.meta_tensor``. That result is ``NotImplemented`` when
          ``meta_tensor`` declines the tensor.

    The tensor is described first; when ``trace`` is set and ``source`` is
    given, the description is logged to structured logs. Meta-fication runs
    with functionalization suspended and any functorch interpreter stack
    temporarily cleared.
    """
    # TODO: zero tensors? We appear to have eliminated them by
    # excluding complex for now

    # Filter out cases we don't support
    # TODO: This can probably be simplified quite a bit
    if isinstance(t, torch.Tensor):
        if (
            # Lazy tensors are not supported. Note that XLA is
            # implemented on top of lazy tensor, not excluded here; we
            # have some special handling for it; this is for XLA Dynamo
            # integration
            t.device.type == "lazy"
            or
            # Quantization is not supported
            t.is_quantized
            or
            # Views out of sparse tensors not currently supported (plain
            # sparse is supported though)
            (t._is_view() and t._base is not None and t._base.is_sparse)
        ):
            self.miss += 1
            return NotImplemented
        self.hit += 1
    elif torch.overrides.is_tensor_like(t):
        self.miss += 1
        return NotImplemented
    else:
        # non-Tensor types don't count as hit or miss
        return t

    if source is None:
        trace = False

    # Describe the tensor. NB: do NOT disable ambient modes, we may need
    # to query them when figuring out what to put in here
    t_desc = self.describer.describe_tensor(t, trace=trace)

    if trace:
        trace_structured(
            "describe_source",
            metadata_fn=lambda: {
                "describer_id": self.describer.id,
                "id": t_desc.id,
                "source": source.name(),
            },
        )

    # Do the meta-fication. Here, we disable all the ambient modes, to
    # better simulate what would be like to re-fakeify from a fresh
    # process
    with contextlib.ExitStack() as exit_stack:
        exit_stack.enter_context(torch._dispatch.python.suspend_functionalization())
        st = peek_interpreter_stack()
        if st is not None:
            exit_stack.enter_context(
                torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
            )

        r = self.meta_tensor(
            t_desc,
            shape_env=shape_env,
            callback=callback,
            source=source,
            symbolic_context=symbolic_context,
        )

    # meta_tensor may decline a tensor by returning NotImplemented; propagate
    # it rather than trying to set an attribute on the NotImplemented
    # singleton, which would raise AttributeError.
    if r is NotImplemented:
        return r

    if type(t) is torch.nn.Parameter:
        # NB: Cannot directly use Parameter constructor
        # because that would force a detach, not desirable
        r._is_param = True

    # TODO: return the description for later
    return r
|
|
|
|
|
|
import torch._prims_common as utils
|