type annotations for meta_utils (#140203)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140203
Approved by: https://github.com/ezyang
Aaron Orenstein
2024-11-12 21:03:55 -08:00
committed by PyTorch MergeBot
parent c25999bdc0
commit 82597d07aa
4 changed files with 309 additions and 160 deletions

torch/_subclasses/meta_utils.py

@@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import contextlib
import dataclasses
import typing
import warnings
import weakref
from dataclasses import dataclass
@@ -12,14 +12,18 @@ from typing import (
ClassVar,
ContextManager,
Dict,
Generic,
List,
NewType,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import TypeAlias, TypeGuard
from typing_extensions import TypeGuard
import torch
from torch._C._autograd import CreationMeta
@@ -47,16 +51,17 @@ if TYPE_CHECKING:
from torch._guards import Source
# Import here to avoid cycle
from torch._subclasses.fake_tensor import FakeTensorMode
# Import the following modules during type checking to enable code intelligence features,
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
DimList = List
_TensorLikeT = TypeVar("_TensorLikeT", "MetaTensorDesc", torch.Tensor)
_T = TypeVar("_T")
_TensorT = TypeVar("_TensorT", bound=torch.Tensor)
def safe_is_leaf(t):
def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
try:
return t.is_leaf
except RuntimeError:
@@ -64,28 +69,37 @@ def safe_is_leaf(t):
return False
def safe_grad(t):
def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
return t.grad
def assert_eq(a, b):
def _expect_safe_grad(t: _TensorLikeT) -> _TensorLikeT:
grad = safe_grad(t)
assert grad is not None
return grad
def assert_eq(a: _T, b: _T) -> None:
assert a == b, f"{a} != {b}"
def assert_metadata_eq(
assert_eq,
assert_eq: Callable[[object, object], None],
m1: Union[MetaTensorDesc, torch.Tensor],
m2: torch.Tensor,
*,
skip_symbolic=False,
skip_leaf=False,
):
if isinstance(m1, torch.Tensor):
m1 = MetaTensorDescriber().describe_tensor(m1)
skip_symbolic: bool = False,
skip_leaf: bool = False,
) -> None:
m1 = (
MetaTensorDescriber().describe_tensor(m1)
if isinstance(m1, torch.Tensor)
else m1
)
def go(m1, m2):
def go(m1: MetaTensorDesc, m2: torch.Tensor) -> None:
assert_eq(m1.dtype, m2.dtype)
if not skip_symbolic:
assert_eq(m1.shape, m2.shape)
@@ -100,7 +114,7 @@ def assert_metadata_eq(
assert_eq(m1.is_neg, m2.is_neg())
assert_eq(m1.grad is not None, safe_grad(m2) is not None)
if m1.grad is not None:
go(m1.grad, safe_grad(m2))
go(m1.grad, _expect_safe_grad(m2))
# TODO: move "assert_eq(m1.layout, m2.layout)" out of sparse
# branches (but not ready for prime time yet)...
if m1.is_sparse:
@@ -118,6 +132,8 @@ def assert_metadata_eq(
assert_eq(m1.storage_offset, m2.storage_offset())
assert_eq(m1.is_view, m2._is_view())
if m1.is_view:
assert m1.base is not None
assert m2._base is not None
go(m1.base, m2._base)
# TODO: test if is resizable (no direct query for this atm)
# TODO: audit AutogradMeta to see if it matches
@@ -126,11 +142,12 @@ def assert_metadata_eq(
return go(m1, m2)
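
An illustrative usage sketch (untested, not from the diff): comparing a plain tensor against its own description should pass, since describe_tensor captures exactly the metadata that the inner go() compares.

import torch

t = torch.randn(3)
# t is described via MetaTensorDescriber().describe_tensor(t) internally and
# then compared field-by-field against t itself, so this should not raise.
assert_metadata_eq(assert_eq, t, t)
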
def is_sparse_coo(t):
# TypeGuard (not TypeIs): False does not imply !torch.Tensor
def is_sparse_coo(t: object) -> TypeGuard[torch.Tensor]:
return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo
def is_sparse_compressed_layout(layout):
def is_sparse_compressed_layout(layout: torch.layout) -> bool:
return layout in {
torch.sparse_csr,
torch.sparse_csc,
@@ -139,20 +156,38 @@ def is_sparse_compressed_layout(layout):
}
def is_sparse_compressed(t):
# TypeGuard (not TypeIs): False does not imply !torch.Tensor
def is_sparse_compressed(t: object) -> TypeGuard[torch.Tensor]:
return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout)
# TypeGuard (not TypeIs): False does not imply !torch.Tensor
def is_sparse_any(t: object) -> TypeGuard[torch.Tensor]:
return is_sparse_coo(t) or is_sparse_compressed(t)
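
The "TypeGuard (not TypeIs)" comments are worth unpacking: TypeGuard only narrows the positive branch, whereas TypeIs (PEP 742, available from typing_extensions) also narrows the negative branch. That stronger contract would be unsound here, because a dense torch.Tensor makes these predicates return False while still being a Tensor. A minimal sketch of the difference, with hypothetical predicate names:

from typing_extensions import TypeGuard, TypeIs

import torch

def looks_sparse(t: object) -> TypeGuard[torch.Tensor]:
    # On the False branch the checker learns nothing about t; it may still
    # be a (dense) torch.Tensor, which is exactly the situation above.
    return isinstance(t, torch.Tensor) and t.is_sparse

def is_tensor(t: object) -> TypeIs[torch.Tensor]:
    # TypeIs narrows both branches, so it is only sound when the predicate
    # is exactly "t is an instance of torch.Tensor".
    return isinstance(t, torch.Tensor)
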
def _checked_cast(ty: Type[_T], obj: object) -> _T:
assert isinstance(obj, ty), f"expected {ty} but got {type(obj)}"
return obj
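
Unlike typing.cast, which is a purely static annotation and does no runtime work, _checked_cast actually verifies the type before returning. A small usage sketch of the helper defined just above (values are illustrative):

import typing

x: object = 3

s = typing.cast(str, x)      # no runtime check: "succeeds" even though x is an int
i = _checked_cast(int, x)    # asserts isinstance(x, int), then returns x typed as int
# _checked_cast(str, x)      # would raise AssertionError: expected <class 'str'> but got <class 'int'>
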
def _get_real_storage(base: torch.UntypedStorage) -> torch.UntypedStorage:
return base.real_storage # type: ignore[attr-defined]
def _set_real_storage(
base: torch.UntypedStorage, real_storage: torch.UntypedStorage
) -> None:
base.real_storage = real_storage # type: ignore[attr-defined]
# Don't use id() directly, because those can get reallocated over time.
MetaStorageId: TypeAlias = int
MetaTensorId: TypeAlias = int
MetaStorageId = NewType("MetaStorageId", int)
MetaTensorId = NewType("MetaTensorId", int)
DESCRIBER_NEXT_ID = 0
_DescriberId = NewType("_DescriberId", int)
DESCRIBER_NEXT_ID = _DescriberId(0)
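
Replacing the TypeAlias declarations with NewType makes MetaStorageId, MetaTensorId and _DescriberId distinct types to the checker (a tensor id can no longer be passed where a storage id is expected), while remaining plain ints at runtime. One consequence, visible in the counter updates below, is that arithmetic on a NewType yields a bare int, so results must be re-wrapped. An illustrative sketch with a hypothetical id type:

from typing import NewType

ExampleId = NewType("ExampleId", int)

i = ExampleId(0)        # runtime value is just the int 0
i = ExampleId(i + 1)    # i + 1 is a plain int, so re-wrap to keep the NewType
# i += 1 would assign an int to an ExampleId variable and be rejected by the checker
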
class MetaTensorDescriber:
@@ -166,33 +201,35 @@ class MetaTensorDescriber:
the same ID when we see the same tensor/storage.
"""
def __init__(self, *, copy_data=False):
def __init__(self, *, copy_data: bool = False) -> None:
global DESCRIBER_NEXT_ID
self.id = DESCRIBER_NEXT_ID
DESCRIBER_NEXT_ID += 1
self.next_tensor_id: MetaTensorId = 0
self.next_storage_id: MetaStorageId = 0
DESCRIBER_NEXT_ID = _DescriberId(DESCRIBER_NEXT_ID + 1)
self.next_tensor_id: MetaTensorId = MetaTensorId(0)
self.next_storage_id: MetaStorageId = MetaStorageId(0)
# Tensor -> int
self.lookup_tensor = WeakIdKeyDictionary()
# Storage -> int
self.lookup_storage = WeakIdKeyDictionary()
self.copy_data = copy_data
self.traced_tensors = set()
self.traced_storages = set()
self.traced_tensors: Set[int] = set()
self.traced_storages: Set[int] = set()
def get_tensor_id(self, t: torch.Tensor):
def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId:
if t not in self.lookup_tensor:
self.lookup_tensor[t] = self.next_tensor_id
self.next_tensor_id += 1
self.next_tensor_id = MetaTensorId(self.next_tensor_id + 1)
return self.lookup_tensor[t]
def get_storage_id(self, s: torch.UntypedStorage):
def get_storage_id(self, s: torch.UntypedStorage) -> MetaStorageId:
if s not in self.lookup_storage:
self.lookup_storage[s] = self.next_storage_id
self.next_storage_id += 1
self.next_storage_id = MetaStorageId(self.next_storage_id + 1)
return self.lookup_storage[s]
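
Because lookup_tensor and lookup_storage are WeakIdKeyDictionary instances, ids are assigned per object identity rather than per value, and entries vanish once the tensor or storage is garbage collected. A hypothetical check of that behaviour:

import torch

d = MetaTensorDescriber()
a = torch.zeros(2)
b = torch.zeros(2)                                # equal values, distinct object
assert d.get_tensor_id(a) == d.get_tensor_id(a)   # memoized: stable per object
assert d.get_tensor_id(a) != d.get_tensor_id(b)   # identity-keyed: new object, new id
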
def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False):
def describe_storage(
self, s: torch.UntypedStorage, *, trace: bool = False
) -> MetaStorageDesc:
r = MetaStorageDesc(
id=self.get_storage_id(s),
size=s.size(),
@@ -210,7 +247,7 @@ class MetaTensorDescriber:
def describe_tensor(
self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
):
) -> MetaTensorDesc:
is_leaf = safe_is_leaf(t)
is_view = t._is_view()
is_sparse = t.is_sparse
@@ -381,8 +418,8 @@ class MetaTensorDescriber:
else None
),
grad=(
self.describe_tensor(safe_grad(t), trace=trace)
if safe_grad(t) is not None
self.describe_tensor(grad, trace=trace)
if (grad := safe_grad(t)) is not None
else None
),
creation_meta=(
@@ -430,7 +467,7 @@ class MetaStorageDesc:
# serializable in JSON, you want to do something special here anyway
data: Optional[torch.UntypedStorage]
def as_json(self, describer_id):
def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
return {
"id": self.id,
"describer_id": describer_id,
@@ -439,7 +476,7 @@ class MetaStorageDesc:
@dataclass(frozen=True)
class MetaTensorDesc:
class MetaTensorDesc(Generic[_TensorT]):
id: MetaTensorId
ndim: int
dtype: torch.dtype
@@ -520,15 +557,15 @@ class MetaTensorDesc:
ctx: Optional[object] = None # is_traceable_wrapper_subclass
type: Optional[Type] = None # is_traceable_wrapper_subclass
fake_mode: Optional[FakeTensorMode] = None
fake_mode: Optional[torch._subclasses.fake_tensor.FakeTensorMode] = None
view_func: Optional[
Callable[
[
torch.Tensor,
Callable[[int], int],
Callable[[torch.Tensor], torch.Tensor],
Callable[[torch.Tensor], _TensorT],
],
torch.Tensor,
_TensorT,
]
] = None
# level looks serializable, but actually it is meaningless without
@@ -555,8 +592,8 @@ class MetaTensorDesc:
# NB: This will reference numeric IDs, and it is assumed that you've
# already serialized everything this recursively references
def as_json(self, describer_id):
def json(k, v):
def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
def json(k: str, v: object) -> object:
# Some best-effort debugging serialization for unserializable
# fields (feel free to add other special cases as appropriate)
if k in ["data", "autograd_meta_from"]:
@@ -592,7 +629,7 @@ class MetaTensorDesc:
return r
@property
def shape(self):
def shape(self) -> Tuple[int, ...]:
return self.size
@@ -608,13 +645,13 @@ class MetaTensorDesc:
# FakeTensor as src, we MUST NOT run the copy/clone operation. A better way
# to do this would be to not use no_dispatch and instead just disable fake
# tensor mode only (allowing for subclass dispatch to occur)
def _safe_copy(dst, src):
def _safe_copy(dst: torch.Tensor, src: Optional[torch.Tensor]) -> None:
if type(src) is not torch.Tensor:
return
dst.copy_(src)
def _safe_clone(src):
def _safe_clone(src: torch.Tensor) -> Optional[torch.Tensor]:
if type(src) is not torch.Tensor:
return None
return src.clone()
@@ -627,13 +664,17 @@ def _safe_clone(src):
# share storage because this is how we correlate shared storages to the same
# meta storages. This class will hold weak references to cached tensors
# and tensor storages.
class MetaConverter:
def __init__(self, *, copy_data: bool = False):
class MetaConverter(Generic[_TensorT]):
def __init__(self, *, copy_data: bool = False) -> None:
# Maps MetaStorageId to UntypedStorage
self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self.storage_memo: weakref.WeakValueDictionary[
MetaStorageId, torch.UntypedStorage
] = weakref.WeakValueDictionary()
# Maps MetaTensorId to torch.Tensor (typically a meta tensor or
# FakeTensor)
self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self.tensor_memo: weakref.WeakValueDictionary[
MetaTensorId, _TensorT
] = weakref.WeakValueDictionary()
self.hit = 0
self.miss = 0
self.del_hook = None
@@ -645,25 +686,34 @@ class MetaConverter:
self.copy_data = copy_data
self.describer = MetaTensorDescriber(copy_data=copy_data)
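
The memo tables are now written as parameterized weakref.WeakValueDictionary annotations; with `from __future__ import annotations` in effect (added at the top of this diff) the subscript is never evaluated at runtime, and on Python 3.9+ WeakValueDictionary is subscriptable anyway. The caching behaviour is unchanged: an entry survives only as long as the converted object does. A self-contained sketch of that pattern, with illustrative names:

from __future__ import annotations

import weakref

class Converted:
    pass

memo: weakref.WeakValueDictionary[int, Converted] = weakref.WeakValueDictionary()

obj = Converted()
memo[1] = obj
assert memo.get(1) is obj

del obj                       # drop the only strong reference...
assert memo.get(1) is None    # ...and the weak entry disappears (immediately under CPython)
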
def successful(self):
def successful(self) -> bool:
return self.hit > 0 and self.miss == 0
def get_tensor_memo(self, t: MetaTensorDesc):
def get_tensor_memo(self, t: MetaTensorDesc) -> Optional[torch.Tensor]:
return self.tensor_memo.get(t.id, None)
def set_tensor_memo(self, t: MetaTensorDesc, v):
def _checked_get_tensor_memo(self, t: MetaTensorDesc) -> _TensorT:
r = self.tensor_memo.get(t.id, None)
assert r is not None
return r
def set_tensor_memo(self, t: MetaTensorDesc, v: _TensorT) -> None:
self.tensor_memo[t.id] = v
def get_storage_memo(self, s: MetaStorageDesc):
def get_storage_memo(self, s: MetaStorageDesc) -> Optional[torch.UntypedStorage]:
return self.storage_memo.get(s.id, None)
def set_storage_memo(self, s: MetaStorageDesc, v):
def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None:
self.storage_memo[s.id] = v
def meta_storage(self, s: MetaStorageDesc, callback):
def meta_storage(
self,
s: MetaStorageDesc,
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
) -> torch.UntypedStorage:
# If we are fakeifying a tensor that has a secretly-zero-sized storage,
# Need to make sure to resize the meta storage too.
if self.get_storage_memo(s) is None:
if (memo := self.get_storage_memo(s)) is None:
r_s = callback(
lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
).untyped_storage()
@@ -672,11 +722,29 @@ class MetaConverter:
# implemented as Tensor operations
with torch.no_grad(), no_dispatch():
assert s.data is not None
r_s.real_storage = s.data.clone()
_set_real_storage(r_s, s.data.clone())
self.set_storage_memo(s, r_s)
return r_s
else:
return self.get_storage_memo(s)
return memo
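
The `(memo := self.get_storage_memo(s)) is None` rewrite (like `grad := safe_grad(t)` earlier) uses an assignment expression so the memo is fetched once and the checker can narrow the Optional result per branch, instead of calling the getter a second time on the hit path. The general shape of the pattern, with hypothetical names:

from typing import Dict, Optional

_memo: Dict[int, str] = {}

def lookup(key: int) -> Optional[str]:
    return _memo.get(key)

def get_or_create(key: int) -> str:
    if (value := lookup(key)) is None:    # single lookup, result bound to value
        value = f"entry-{key}"
        _memo[key] = value
    return value                          # narrowed to str on both paths
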
@classmethod
def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT:
# TODO: how to check _TensorT?
return typing.cast(_TensorT, t)
@classmethod
def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT:
return cls._checked_cast_tensor_t(t())
@classmethod
def _backward_error(cls, t: _TensorT) -> _TensorT:
errfn = torch._C._functions.DelayedError(
"Internal error: Tried to backward() through example input",
1,
)
err = errfn(t)
return typing.cast(_TensorT, err)
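
The `TODO: how to check _TensorT?` points at a genuine limitation: a TypeVar is erased at runtime, so there is nothing to isinstance-check beyond its bound, and typing.cast performs no check at all. A minimal sketch of the issue, using a hypothetical Converter class:

import typing

import torch

TensorT = typing.TypeVar("TensorT", bound=torch.Tensor)

class Converter(typing.Generic[TensorT]):
    def _cast(self, t: torch.Tensor) -> TensorT:
        # isinstance(t, TensorT) raises TypeError: a TypeVar is not a class, and
        # the concrete TensorT chosen by the caller (e.g. a FakeTensor subclass)
        # is not available at runtime; only the bound, torch.Tensor, could be checked.
        return typing.cast(TensorT, t)
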
# This function assumes that it's possible to do the conversion
# NB: name here is used in a conventional way by Dynamo; it corresponds
@@ -687,11 +755,11 @@ class MetaConverter:
def meta_tensor(
self,
t: MetaTensorDesc,
shape_env: Optional[ShapeEnv] = None,
callback=lambda t: t(),
source: Optional[Source] = None,
symbolic_context: Optional[SymbolicContext] = None,
):
shape_env: Optional[ShapeEnv],
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
source: Optional[Source],
symbolic_context: Optional[SymbolicContext],
) -> _TensorT:
if source is None:
from torch._dynamo.source import ConstantSource
@@ -739,7 +807,11 @@ class MetaConverter:
maybe_suppress = shape_env.suppress_guards
def sym_sizes_strides_storage_offset(
t: MetaTensorDesc, src, symbolic_context=symbolic_context
t: MetaTensorDesc,
src: torch._guards.Source,
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
assert t.stride is not None
if shape_env is not None:
@@ -773,8 +845,12 @@ class MetaConverter:
return (t.size, t.stride, t.storage_offset)
def empty_create(
inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context
):
inner_t: MetaTensorDesc,
inner_src: torch._guards.Source,
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
) -> torch.Tensor:
(
inner_sizes,
inner_strides,
@@ -791,12 +867,13 @@ class MetaConverter:
# symbolic context.
def empty_create_subclass(
t: MetaTensorDesc,
outer_size,
outer_stride,
symbolic_context=symbolic_context,
callback=callback,
source=source,
):
outer_size: Tuple[int, ...],
outer_stride: Tuple[int, ...],
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
source: Optional[torch._guards.Source] = source,
) -> _TensorT:
from torch._dynamo.source import AttrSource
from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext
@@ -822,24 +899,38 @@ class MetaConverter:
)
def _empty_create_subclass(
t, outer_size, outer_stride, symbolic_context, callback, source
):
t: MetaTensorDesc,
outer_size: Optional[Tuple[int, ...]],
outer_stride: Optional[Tuple[int, ...]],
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
],
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
source: torch._guards.Source,
) -> _TensorT:
# We are hitting plain meta_desc tensor so actually
# create a tensor here.
if t.attrs is None:
return self.meta_tensor(
t,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
shape_env,
callback,
source,
symbolic_context,
)
inner_tensors = {}
for attr, meta_tensor_desc in t.attrs.items():
current_context = None
if symbolic_context is not None:
current_context = symbolic_context.inner_contexts[attr]
assert isinstance(symbolic_context, SubclassSymbolicContext)
if (
current_context_ := symbolic_context.inner_contexts[attr]
) is not None:
current_context = _checked_cast(
torch.fx.experimental.symbolic_shapes.SymbolicContext,
current_context_,
)
current_source = AttrSource(source, attr)
new_empty_tensor = _empty_create_subclass(
@@ -852,10 +943,12 @@ class MetaConverter:
)
inner_tensors[attr] = new_empty_tensor
assert t.type is not None
return t.type.__tensor_unflatten__(
inner_tensors, t.ctx, outer_size, outer_stride
)
assert source is not None
sub = _empty_create_subclass(
t, outer_size, outer_stride, symbolic_context, callback, source
)
@@ -879,8 +972,11 @@ class MetaConverter:
# closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
# don't want to over-specialize during view replay.
def all_dynamic_symbolic_context(
t: MetaTensorDesc, source, shape_env, callback
):
t: MetaTensorDesc,
source: torch._guards.Source,
shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
from torch._dynamo.source import AttrSource
from torch.fx.experimental.symbolic_shapes import (
DimDynamic,
@@ -888,18 +984,22 @@ class MetaConverter:
SubclassSymbolicContext,
)
view_base_context: Optional[SymbolicContext] = None
view_base_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = None
if t.is_view:
assert t.base is not None
view_base_context = all_dynamic_symbolic_context(
t.base, AttrSource(source, "_base"), shape_env, callback
)
t_symbolic_context: SymbolicContext
t_symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext
t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
if t.is_traceable_wrapper_subclass:
assert t.attrs is not None
inner_contexts: Dict[str, SymbolicContext] = {}
inner_contexts: Dict[
str, torch.fx.experimental.symbolic_shapes.SymbolicContext
] = {}
for attr, inner in t.attrs.items():
assert isinstance(attr, str)
inner_contexts[attr] = all_dynamic_symbolic_context(
@@ -951,8 +1051,12 @@ class MetaConverter:
# Then view replay is done, swapping in the fake offsets so the view replay output
# is fully fake with no invalid specialization.
def view_from_base(
base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env
):
base: _TensorT,
t: MetaTensorDesc,
shape_env: Optional[
torch.fx.experimental.symbolic_shapes.ShapeEnv
] = shape_env,
) -> _TensorT:
# fake-ify t's metadata according to the outer symbolic context
(sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
t, source
@@ -965,7 +1069,9 @@ class MetaConverter:
# TODO: Change this logic to use view replay for consistency?
# It's likely there is no view func available.
with maybe_suppress():
return base.as_strided(sizes, strides, storage_offset)
return self._checked_cast_tensor_t(
base.as_strided(sizes, strides, storage_offset)
)
from torch._dynamo.source import EphemeralSource
from torch.fx.experimental.symbolic_shapes import (
@@ -973,7 +1079,7 @@ class MetaConverter:
sym_eq,
)
def symint_visitor_fn(s):
def symint_visitor_fn(s: int) -> int:
nonlocal symbolic_context
from torch.fx.experimental.symbolic_shapes import DimDynamic
@@ -1017,10 +1123,10 @@ class MetaConverter:
# want a view of values with the offsets closed over. As the offsets component
# is needed to describe the output view, it's important that it's fakeified
# correctly.
fake_t = empty_create_subclass(
fake_t: _TensorT = empty_create_subclass(
t, outer_size=sizes, outer_stride=strides
)
attrs, _ = fake_t.__tensor_flatten__()
attrs, _ = fake_t.__tensor_flatten__() # type: ignore[attr-defined]
for attr in attrs:
real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr)
@@ -1028,9 +1134,11 @@ class MetaConverter:
visited_t: torch.Tensor,
# These arguments are never passed, we just use them to close
# over these relevant values
shape_env=shape_env,
callback=callback,
):
shape_env: Optional[
torch.fx.experimental.symbolic_shapes.ShapeEnv
] = shape_env,
callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback, # type: ignore[assignment]
) -> torch.Tensor:
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
if visited_t is None:
return None
@@ -1057,8 +1165,8 @@ class MetaConverter:
visited_desc,
shape_env,
callback,
source=temp_source,
symbolic_context=all_dynamic_symbolic_context(
temp_source,
all_dynamic_symbolic_context(
visited_desc, temp_source, shape_env, callback
),
)
@@ -1102,6 +1210,9 @@ class MetaConverter:
# Pray that sparse clone doesn't lose information
assert t.data is not None
with torch.no_grad(), no_dispatch():
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
r.real_tensor = _safe_clone(t.data)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
# Note [is_coalesced is dispatched]
@@ -1109,7 +1220,7 @@ class MetaConverter:
# which means that it will get caught by fake tensor mode.
# Ordinarily this would error, but there's some logic in
# fake tensor ensure this doesn't happen.
r._coalesced_(t.is_coalesced)
r._coalesced_(bool(t.is_coalesced))
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
@@ -1117,9 +1228,9 @@ class MetaConverter:
# but clone is fine for now for sparse tensors.
# (DelayedError does not work for sparse because it causes
# the Fake sparse tensor to "lose" its fakeness)
r = r.clone()
r = self._checked_cast_tensor_t(r.clone())
with torch.enable_grad():
r._coalesced_(t.is_coalesced)
r._coalesced_(bool(t.is_coalesced))
elif is_sparse_compressed_layout(t.layout):
is_leaf = t.is_leaf
@@ -1154,15 +1265,15 @@ class MetaConverter:
# Pray sparse clone doesn't lose information
assert t.data is not None
with torch.no_grad(), no_dispatch():
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
r.real_tensor = _safe_clone(t.data)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
r = torch._C._functions.DelayedError(
"Internal error: Tried to backward() through example input",
1,
)(r)
r = self._backward_error(r)
elif t.is_nested and not t.is_traceable_wrapper_subclass:
# TODO: Handle this better in Dynamo?
# There are checks there now, but this can still be triggered by a dense
@@ -1174,9 +1285,11 @@ class MetaConverter:
)
elif t.is_mkldnn:
is_leaf = t.is_leaf
sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
t, source
)
(
sizes,
strides,
_storage_offset,
) = sym_sizes_strides_storage_offset(t, source)
# TODO: This doesn't seem right, where's the MKLDNN'ness
# lol
r = callback(
@@ -1188,6 +1301,9 @@ class MetaConverter:
with torch.no_grad(), no_dispatch():
assert t.size is not None
assert t.stride is not None
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
r.real_tensor = torch.empty_strided(
t.size, t.stride, dtype=t.dtype, device=t.device
)
@@ -1197,10 +1313,7 @@ class MetaConverter:
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
r = torch._C._functions.DelayedError(
"Internal error: Tried to backward() through example input",
1,
)(r)
r = self._backward_error(r)
elif t.is_functorch_wrapped:
if t.is_view:
from torch._dynamo.exc import unimplemented
@@ -1211,9 +1324,10 @@ class MetaConverter:
# Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
# in a FakeTensor
def _to_fake_tensor(t: MetaTensorDesc):
def _to_fake_tensor(t: MetaTensorDesc) -> _TensorT:
# TODO: why aren't the recursive calls going to
# meta_tensor
r: _TensorT
if t.is_batchedtensor:
assert t.unwrapped is not None
assert t.level is not None
@@ -1228,7 +1342,9 @@ class MetaConverter:
with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
t.functorch_stack
):
r = _add_batch_dim(ft, bdim, lvl)
r = self._checked_cast_tensor_t(
_add_batch_dim(ft, bdim, lvl)
)
elif t.is_gradtrackingtensor:
assert t.unwrapped is not None
assert t.level is not None
@@ -1242,33 +1358,32 @@ class MetaConverter:
with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
t.functorch_stack
):
r = torch._C._functorch._wrap_for_grad(ft, lvl)
r = self._checked_cast_tensor_t(
torch._C._functorch._wrap_for_grad(ft, lvl),
)
is_leaf = t.is_leaf
if t.requires_grad and safe_is_leaf(r):
r.requires_grad = True
elif t.requires_grad and not is_leaf:
r = torch._C._functions.DelayedError( # type: ignore[assignment]
"Internal error: Tried to backward() through example input",
1,
)(
r # type: ignore[arg-type]
)
r = self._backward_error(r)
elif t.is_functional:
assert t.unwrapped is not None
assert t.current_level is not None
ft = self.meta_tensor(
t.unwrapped,
shape_env=shape_env,
callback=callback,
shape_env,
callback,
# NB: reuse these exactly, we treat the
# functional tensor as "invisible".
# TODO: Actually this all probably doesn't
# work, take a closer look.
source=source,
symbolic_context=symbolic_context,
source,
symbolic_context,
)
r = self._checked_cast_tensor_t(
_wrap_functional_tensor(ft, t.current_level),
)
r = _wrap_functional_tensor(ft, t.current_level)
# TODO: is_leaf/requires_grad?
else:
assert t.stride is not None
@@ -1302,12 +1417,14 @@ class MetaConverter:
assert not t.is_functorch_wrapped # handled above
unwrapped = self.meta_tensor(
t.unwrapped,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
shape_env,
callback,
source,
symbolic_context,
)
r = self._checked_cast_tensor_t(
torch._to_functional_tensor(unwrapped)
)
r = torch._to_functional_tensor(unwrapped)
torch._mirror_autograd_meta_to(t.autograd_meta_from, r) # type: ignore[attr-defined]
elif t.is_view:
@@ -1335,11 +1452,13 @@ class MetaConverter:
t.base,
shape_env,
callback,
source=torch._dynamo.source.AttrSource(source, "_base"),
symbolic_context=base_symbolic_context,
torch._dynamo.source.AttrSource(source, "_base"),
base_symbolic_context,
)
def is_c_of_r(complex_dtype, real_dtype):
def is_c_of_r(
complex_dtype: torch.dtype, real_dtype: torch.dtype
) -> bool:
return (
utils.is_complex_dtype(complex_dtype)
and utils.corresponding_real_dtype(complex_dtype)
@@ -1361,14 +1480,16 @@ class MetaConverter:
if base.dtype == t.dtype:
pass
elif is_c_of_r(base.dtype, t.dtype):
base = torch.view_as_real(base)
base = self._checked_cast_tensor_t(torch.view_as_real(base))
elif is_c_of_r(t.dtype, base.dtype):
base = torch.view_as_complex(base)
base = self._checked_cast_tensor_t(
torch.view_as_complex(base)
)
else:
# This is not guaranteed to succeed. If it fails, it
# means there is another dtype-converting view function
# that hasn't been handled here
base = base.view(t.dtype)
base = self._checked_cast_tensor_t(base.view(t.dtype))
# This is very tricky. Naively, you might expect this
# to hold:
@@ -1410,7 +1531,9 @@ class MetaConverter:
# NB: Can't have a non-leaf without requiring grad!
assert t.requires_grad
with torch.no_grad():
mid = base.view(base.shape)
mid = self._checked_cast_tensor_t(
base.view(base.shape)
)
mid.requires_grad = t.requires_grad
with torch.enable_grad():
r = view_from_base(mid, t)
@@ -1459,6 +1582,9 @@ class MetaConverter:
with torch.no_grad(), no_dispatch():
assert t.size is not None
assert t.stride is not None
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
r.real_tensor = torch.empty_strided(
t.size, t.stride, dtype=t.dtype, device=t.device
)
@@ -1477,10 +1603,7 @@ class MetaConverter:
# the metadata of the inner tensor.
# So instead, we now have a dedicated fn to set autograd history,
# without inadvertently changing other metadata.
r = torch._C._functions.DelayedError(
"Internal error: Tried to backward() through example input",
1,
)(r)
r = self._backward_error(r)
s = t.storage
assert s is not None
@@ -1494,8 +1617,12 @@ class MetaConverter:
# You're normal and happy, install the fresh storage into the memo
self.set_storage_memo(s, r.untyped_storage())
if self.copy_data:
r.untyped_storage().real_storage = (
r.real_tensor.untyped_storage()
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
assert r.real_tensor is not None
_set_real_storage(
r.untyped_storage(), r.real_tensor.untyped_storage()
)
else:
# You're in crazy town; somehow you gave us a tensor
@@ -1540,8 +1667,13 @@ class MetaConverter:
r.set_(r_s, storage_offset, sizes, strides)
if self.copy_data:
with torch.no_grad(), no_dispatch():
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
assert r.real_tensor is not None
assert t.stride is not None
r.real_tensor.set_(
r_s.real_storage,
_get_real_storage(r_s),
t.storage_offset,
t.size,
t.stride,
@@ -1556,8 +1688,8 @@ class MetaConverter:
t.grad,
shape_env,
callback,
source=AttrSource(source, "grad"),
symbolic_context=symbolic_context,
AttrSource(source, "grad"),
symbolic_context,
)
torch._C._set_conj(r, t.is_conj)
torch._C._set_neg(r, t.is_neg)
@@ -1577,27 +1709,33 @@ class MetaConverter:
# See Note: [Creating symbolic nested int]
if t.nested_int is not None:
assert isinstance(r, torch._subclasses.fake_tensor.FakeTensor)
r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
nt_tensor_id=t.nested_int
)
self.set_tensor_memo(t, r)
return self.get_tensor_memo(t)
return self._checked_get_tensor_memo(t)
def __call__(
self,
t,
shape_env=None,
t: torch.Tensor,
shape_env: Optional[ShapeEnv] = None,
*,
callback=lambda t: t(),
source=None,
symbolic_context=None,
callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None,
source: Optional[Source] = None,
symbolic_context: Optional[SymbolicContext] = None,
# Controls whether or not we should dump the tensor metadata to structured logs
# when source is not None. Because we refakify after Dynamo is done,
# we don't want to dump info again from AOTAutograd, it is redundant.
trace=True,
):
trace: bool = True,
) -> _TensorT:
callback_: Callable[[Callable[[], torch.Tensor]], _TensorT]
if callback is None:
callback_ = self._identity_callable
else:
callback_ = callback
# TODO: zero tensors? We appear to have eliminated them by
# excluding complex for now
@@ -1637,6 +1775,7 @@ class MetaConverter:
t_desc = self.describer.describe_tensor(t, trace=trace)
if trace:
assert source is not None
trace_structured(
"describe_source",
metadata_fn=lambda: {
@@ -1659,10 +1798,10 @@ class MetaConverter:
r = self.meta_tensor(
t_desc,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
shape_env,
callback_,
source,
symbolic_context,
)
if type(t) is torch.nn.Parameter: