type annotations for meta_utils (#140203)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140203 Approved by: https://github.com/ezyang
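For orientation, here is a condensed sketch of the two typing patterns the change leans on most heavily: NewType wrappers for the integer ID aliases and TypeGuard narrowing for the sparse-layout helpers. The names come from the diff below, but the snippet is a simplified illustration, not the module itself.

    from typing import NewType
    from typing_extensions import TypeGuard

    import torch

    # Distinct ID types: a MetaTensorId is not interchangeable with a
    # MetaStorageId for the type checker, even though both are ints at runtime.
    MetaStorageId = NewType("MetaStorageId", int)
    MetaTensorId = NewType("MetaTensorId", int)

    next_tensor_id = MetaTensorId(0)
    # Arithmetic on a NewType yields a plain int, so the result is re-wrapped.
    next_tensor_id = MetaTensorId(next_tensor_id + 1)

    # TypeGuard (not TypeIs): a True result narrows t to torch.Tensor for the
    # type checker; a False result proves nothing about t.
    def is_sparse_coo(t: object) -> TypeGuard[torch.Tensor]:
        return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo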
committed by PyTorch MergeBot
parent c25999bdc0
commit 82597d07aa
@@ -1,8 +1,8 @@
# mypy: allow-untyped-defs
from __future__ import annotations

import contextlib
import dataclasses
import typing
import warnings
import weakref
from dataclasses import dataclass
@@ -12,14 +12,18 @@ from typing import (
ClassVar,
ContextManager,
Dict,
Generic,
List,
NewType,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import TypeAlias, TypeGuard
from typing_extensions import TypeGuard

import torch
from torch._C._autograd import CreationMeta
@@ -47,16 +51,17 @@ if TYPE_CHECKING:
from torch._guards import Source

# Import here to avoid cycle
from torch._subclasses.fake_tensor import FakeTensorMode

# Import the following modules during type checking to enable code intelligence features,
# Do not import unconditionally, as they import sympy and importing sympy is very slow
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext

DimList = List
_TensorLikeT = TypeVar("_TensorLikeT", "MetaTensorDesc", torch.Tensor)
_T = TypeVar("_T")
_TensorT = TypeVar("_TensorT", bound=torch.Tensor)

def safe_is_leaf(t):
def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
try:
return t.is_leaf
except RuntimeError:
@@ -64,28 +69,37 @@ def safe_is_leaf(t):
return False

def safe_grad(t):
def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
return t.grad

def assert_eq(a, b):
def _expect_safe_grad(t: _TensorLikeT) -> _TensorLikeT:
grad = safe_grad(t)
assert grad is not None
return grad

def assert_eq(a: _T, b: _T) -> None:
assert a == b, f"{a} != {b}"

def assert_metadata_eq(
assert_eq,
assert_eq: Callable[[object, object], None],
m1: Union[MetaTensorDesc, torch.Tensor],
m2: torch.Tensor,
*,
skip_symbolic=False,
skip_leaf=False,
):
if isinstance(m1, torch.Tensor):
m1 = MetaTensorDescriber().describe_tensor(m1)
skip_symbolic: bool = False,
skip_leaf: bool = False,
) -> None:
m1 = (
MetaTensorDescriber().describe_tensor(m1)
if isinstance(m1, torch.Tensor)
else m1
)

def go(m1, m2):
def go(m1: MetaTensorDesc, m2: torch.Tensor) -> None:
assert_eq(m1.dtype, m2.dtype)
if not skip_symbolic:
assert_eq(m1.shape, m2.shape)
@@ -100,7 +114,7 @@ def assert_metadata_eq(
assert_eq(m1.is_neg, m2.is_neg())
assert_eq(m1.grad is not None, safe_grad(m2) is not None)
if m1.grad is not None:
go(m1.grad, safe_grad(m2))
go(m1.grad, _expect_safe_grad(m2))
# TODO: move "assert_eq(m1.layout, m2.layout)" out of sparse
# branches (but not ready for prime time yet)...
if m1.is_sparse:
@@ -118,6 +132,8 @@ def assert_metadata_eq(
assert_eq(m1.storage_offset, m2.storage_offset())
assert_eq(m1.is_view, m2._is_view())
if m1.is_view:
assert m1.base is not None
assert m2._base is not None
go(m1.base, m2._base)
# TODO: test if is resizable (no direct query for this atm)
# TODO: audit AutogradMeta to see if it matches
@@ -126,11 +142,12 @@ def assert_metadata_eq(
return go(m1, m2)

def is_sparse_coo(t):
# TypeGuard (not TypeIs): False does not imply !torch.Tensor
def is_sparse_coo(t: object) -> TypeGuard[torch.Tensor]:
return isinstance(t, torch.Tensor) and t.layout is torch.sparse_coo

def is_sparse_compressed_layout(layout):
def is_sparse_compressed_layout(layout: torch.layout) -> bool:
return layout in {
torch.sparse_csr,
torch.sparse_csc,
@@ -139,20 +156,38 @@ def is_sparse_compressed_layout(layout):
}

def is_sparse_compressed(t):
# TypeGuard (not TypeIs): False does not imply !torch.Tensor
def is_sparse_compressed(t: object) -> TypeGuard[torch.Tensor]:
return isinstance(t, torch.Tensor) and is_sparse_compressed_layout(t.layout)

# TypeGuard (not TypeIs): False does not imply !torch.Tensor
def is_sparse_any(t: object) -> TypeGuard[torch.Tensor]:
return is_sparse_coo(t) or is_sparse_compressed(t)

def _checked_cast(ty: Type[_T], obj: object) -> _T:
assert isinstance(obj, ty), f"expected {ty} but got {type(obj)}"
return obj

def _get_real_storage(base: torch.UntypedStorage) -> torch.UntypedStorage:
return base.real_storage # type: ignore[attr-defined]

def _set_real_storage(
base: torch.UntypedStorage, real_storage: torch.UntypedStorage
) -> None:
base.real_storage = real_storage # type: ignore[attr-defined]

# Don't use id() directly, because those can get reallocated over time.
MetaStorageId: TypeAlias = int
MetaTensorId: TypeAlias = int
MetaStorageId = NewType("MetaStorageId", int)
MetaTensorId = NewType("MetaTensorId", int)

DESCRIBER_NEXT_ID = 0
_DescriberId = NewType("_DescriberId", int)
DESCRIBER_NEXT_ID = _DescriberId(0)

class MetaTensorDescriber:
@@ -166,33 +201,35 @@ class MetaTensorDescriber:
the same ID when we see the same tensor/storage.
"""

def __init__(self, *, copy_data=False):
def __init__(self, *, copy_data: bool = False) -> None:
global DESCRIBER_NEXT_ID
self.id = DESCRIBER_NEXT_ID
DESCRIBER_NEXT_ID += 1
self.next_tensor_id: MetaTensorId = 0
self.next_storage_id: MetaStorageId = 0
DESCRIBER_NEXT_ID = _DescriberId(DESCRIBER_NEXT_ID + 1)
self.next_tensor_id: MetaTensorId = MetaTensorId(0)
self.next_storage_id: MetaStorageId = MetaStorageId(0)
# Tensor -> int
self.lookup_tensor = WeakIdKeyDictionary()
# Storage -> int
self.lookup_storage = WeakIdKeyDictionary()
self.copy_data = copy_data
self.traced_tensors = set()
self.traced_storages = set()
self.traced_tensors: Set[int] = set()
self.traced_storages: Set[int] = set()

def get_tensor_id(self, t: torch.Tensor):
def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId:
if t not in self.lookup_tensor:
self.lookup_tensor[t] = self.next_tensor_id
self.next_tensor_id += 1
self.next_tensor_id = MetaTensorId(self.next_tensor_id + 1)
return self.lookup_tensor[t]

def get_storage_id(self, s: torch.UntypedStorage):
def get_storage_id(self, s: torch.UntypedStorage) -> MetaStorageId:
if s not in self.lookup_storage:
self.lookup_storage[s] = self.next_storage_id
self.next_storage_id += 1
self.next_storage_id = MetaStorageId(self.next_storage_id + 1)
return self.lookup_storage[s]

def describe_storage(self, s: torch.UntypedStorage, *, trace: bool = False):
def describe_storage(
self, s: torch.UntypedStorage, *, trace: bool = False
) -> MetaStorageDesc:
r = MetaStorageDesc(
id=self.get_storage_id(s),
size=s.size(),
@@ -210,7 +247,7 @@ class MetaTensorDescriber:

def describe_tensor(
self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
):
) -> MetaTensorDesc:
is_leaf = safe_is_leaf(t)
is_view = t._is_view()
is_sparse = t.is_sparse
@@ -381,8 +418,8 @@ class MetaTensorDescriber:
else None
),
grad=(
self.describe_tensor(safe_grad(t), trace=trace)
if safe_grad(t) is not None
self.describe_tensor(grad, trace=trace)
if (grad := safe_grad(t)) is not None
else None
),
creation_meta=(
@@ -430,7 +467,7 @@ class MetaStorageDesc:
# serializable in JSON, you want to do something special here anyway
data: Optional[torch.UntypedStorage]

def as_json(self, describer_id):
def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
return {
"id": self.id,
"describer_id": describer_id,
@@ -439,7 +476,7 @@ class MetaStorageDesc:

@dataclass(frozen=True)
class MetaTensorDesc:
class MetaTensorDesc(Generic[_TensorT]):
id: MetaTensorId
ndim: int
dtype: torch.dtype
@@ -520,15 +557,15 @@ class MetaTensorDesc:

ctx: Optional[object] = None # is_traceable_wrapper_subclass
type: Optional[Type] = None # is_traceable_wrapper_subclass
fake_mode: Optional[FakeTensorMode] = None
fake_mode: Optional[torch._subclasses.fake_tensor.FakeTensorMode] = None
view_func: Optional[
Callable[
[
torch.Tensor,
Callable[[int], int],
Callable[[torch.Tensor], torch.Tensor],
Callable[[torch.Tensor], _TensorT],
],
torch.Tensor,
_TensorT,
]
] = None
# level looks serializable, but actually it is meaningless without
@@ -555,8 +592,8 @@ class MetaTensorDesc:

# NB: This will reference numeric IDs, and it is assumed that you've
# already serialized everything this recursively references
def as_json(self, describer_id):
def json(k, v):
def as_json(self, describer_id: _DescriberId) -> Dict[str, object]:
def json(k: str, v: object) -> object:
# Some best-effort debugging serialization for unserializable
# fields (feel free to add other special cases as appropriate)
if k in ["data", "autograd_meta_from"]:
@@ -592,7 +629,7 @@ class MetaTensorDesc:
return r

@property
def shape(self):
def shape(self) -> Tuple[int, ...]:
return self.size

@@ -608,13 +645,13 @@ class MetaTensorDesc:
# FakeTensor as src, we MUST NOT run the copy/clone operation. A better way
# to do this would be to not use no_dispatch and instead just disable fake
# tensor mode only (allowing for subclass dispatch to occur)
def _safe_copy(dst, src):
def _safe_copy(dst: torch.Tensor, src: Optional[torch.Tensor]) -> None:
if type(src) is not torch.Tensor:
return
dst.copy_(src)

def _safe_clone(src):
def _safe_clone(src: torch.Tensor) -> Optional[torch.Tensor]:
if type(src) is not torch.Tensor:
return None
return src.clone()
@@ -627,13 +664,17 @@ def _safe_clone(src):
# share storage because this is how we correlate shared storages to the same
# meta storages. This class will hold weak references to cached tensors
# and tensor storages.
class MetaConverter:
def __init__(self, *, copy_data: bool = False):
class MetaConverter(Generic[_TensorT]):
def __init__(self, *, copy_data: bool = False) -> None:
# Maps MetaStorageId to UntypedStorage
self.storage_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self.storage_memo: weakref.WeakValueDictionary[
MetaStorageId, torch.UntypedStorage
] = weakref.WeakValueDictionary()
# Maps MetaTensorId to torch.Tensor (typically a meta tensor or
# FakeTensor)
self.tensor_memo: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
self.tensor_memo: weakref.WeakValueDictionary[
MetaTensorId, _TensorT
] = weakref.WeakValueDictionary()
self.hit = 0
self.miss = 0
self.del_hook = None
@@ -645,25 +686,34 @@ class MetaConverter:
self.copy_data = copy_data
self.describer = MetaTensorDescriber(copy_data=copy_data)

def successful(self):
def successful(self) -> bool:
return self.hit > 0 and self.miss == 0

def get_tensor_memo(self, t: MetaTensorDesc):
def get_tensor_memo(self, t: MetaTensorDesc) -> Optional[torch.Tensor]:
return self.tensor_memo.get(t.id, None)

def set_tensor_memo(self, t: MetaTensorDesc, v):
def _checked_get_tensor_memo(self, t: MetaTensorDesc) -> _TensorT:
r = self.tensor_memo.get(t.id, None)
assert r is not None
return r

def set_tensor_memo(self, t: MetaTensorDesc, v: _TensorT) -> None:
self.tensor_memo[t.id] = v

def get_storage_memo(self, s: MetaStorageDesc):
def get_storage_memo(self, s: MetaStorageDesc) -> Optional[torch.UntypedStorage]:
return self.storage_memo.get(s.id, None)

def set_storage_memo(self, s: MetaStorageDesc, v):
def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None:
self.storage_memo[s.id] = v

def meta_storage(self, s: MetaStorageDesc, callback):
def meta_storage(
self,
s: MetaStorageDesc,
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
) -> torch.UntypedStorage:
# If we are fakeifying a tensor that has a secretly-zero-sized storage,
# Need to make sure to resize the meta storage too.
if self.get_storage_memo(s) is None:
if (memo := self.get_storage_memo(s)) is None:
r_s = callback(
lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
).untyped_storage()
@@ -672,11 +722,29 @@ class MetaConverter:
# implemented as Tensor operations
with torch.no_grad(), no_dispatch():
assert s.data is not None
r_s.real_storage = s.data.clone()
_set_real_storage(r_s, s.data.clone())
self.set_storage_memo(s, r_s)
return r_s
else:
return self.get_storage_memo(s)
return memo

@classmethod
def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT:
# TODO: how to check _TensorT?
return typing.cast(_TensorT, t)

@classmethod
def _identity_callable(cls, t: Callable[[], torch.Tensor]) -> _TensorT:
return cls._checked_cast_tensor_t(t())

@classmethod
def _backward_error(cls, t: _TensorT) -> _TensorT:
errfn = torch._C._functions.DelayedError(
"Internal error: Tried to backward() through example input",
1,
)
err = errfn(t)
return typing.cast(_TensorT, err)

# This function assumes that it's possible to do the conversion
# NB: name here is used in a conventional way by Dynamo; it corresponds
@@ -687,11 +755,11 @@ class MetaConverter:
def meta_tensor(
self,
t: MetaTensorDesc,
shape_env: Optional[ShapeEnv] = None,
callback=lambda t: t(),
source: Optional[Source] = None,
symbolic_context: Optional[SymbolicContext] = None,
):
shape_env: Optional[ShapeEnv],
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
source: Optional[Source],
symbolic_context: Optional[SymbolicContext],
) -> _TensorT:
if source is None:
from torch._dynamo.source import ConstantSource
@@ -739,7 +807,11 @@ class MetaConverter:
maybe_suppress = shape_env.suppress_guards

def sym_sizes_strides_storage_offset(
t: MetaTensorDesc, src, symbolic_context=symbolic_context
t: MetaTensorDesc,
src: torch._guards.Source,
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
assert t.stride is not None
if shape_env is not None:
@@ -773,8 +845,12 @@ class MetaConverter:
return (t.size, t.stride, t.storage_offset)

def empty_create(
inner_t: MetaTensorDesc, inner_src, symbolic_context=symbolic_context
):
inner_t: MetaTensorDesc,
inner_src: torch._guards.Source,
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
) -> torch.Tensor:
(
inner_sizes,
inner_strides,
@@ -791,12 +867,13 @@ class MetaConverter:
# symbolic context.
def empty_create_subclass(
t: MetaTensorDesc,
outer_size,
outer_stride,
symbolic_context=symbolic_context,
callback=callback,
source=source,
):
outer_size: Tuple[int, ...],
outer_stride: Tuple[int, ...],
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = symbolic_context,
source: Optional[torch._guards.Source] = source,
) -> _TensorT:
from torch._dynamo.source import AttrSource
from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext

@@ -822,24 +899,38 @@ class MetaConverter:
)

def _empty_create_subclass(
t, outer_size, outer_stride, symbolic_context, callback, source
):
t: MetaTensorDesc,
outer_size: Optional[Tuple[int, ...]],
outer_stride: Optional[Tuple[int, ...]],
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
],
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
source: torch._guards.Source,
) -> _TensorT:
# We are hitting plain meta_desc tensor so actually
# create a tensor here.
if t.attrs is None:
return self.meta_tensor(
t,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
shape_env,
callback,
source,
symbolic_context,
)

inner_tensors = {}
for attr, meta_tensor_desc in t.attrs.items():
current_context = None
if symbolic_context is not None:
current_context = symbolic_context.inner_contexts[attr]
assert isinstance(symbolic_context, SubclassSymbolicContext)
if (
current_context_ := symbolic_context.inner_contexts[attr]
) is not None:
current_context = _checked_cast(
torch.fx.experimental.symbolic_shapes.SymbolicContext,
current_context_,
)

current_source = AttrSource(source, attr)
new_empty_tensor = _empty_create_subclass(
@@ -852,10 +943,12 @@ class MetaConverter:
)
inner_tensors[attr] = new_empty_tensor

assert t.type is not None
return t.type.__tensor_unflatten__(
inner_tensors, t.ctx, outer_size, outer_stride
)

assert source is not None
sub = _empty_create_subclass(
t, outer_size, outer_stride, symbolic_context, callback, source
)
@@ -879,8 +972,11 @@ class MetaConverter:
# closed-over ViewFunc state, as we don't have symbolic contexts for them, but we
# don't want to over-specialize during view replay.
def all_dynamic_symbolic_context(
t: MetaTensorDesc, source, shape_env, callback
):
t: MetaTensorDesc,
source: torch._guards.Source,
shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
from torch._dynamo.source import AttrSource
from torch.fx.experimental.symbolic_shapes import (
DimDynamic,
@@ -888,18 +984,22 @@ class MetaConverter:
SubclassSymbolicContext,
)

view_base_context: Optional[SymbolicContext] = None
view_base_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
] = None
if t.is_view:
assert t.base is not None
view_base_context = all_dynamic_symbolic_context(
t.base, AttrSource(source, "_base"), shape_env, callback
)

t_symbolic_context: SymbolicContext
t_symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext
t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
if t.is_traceable_wrapper_subclass:
assert t.attrs is not None
inner_contexts: Dict[str, SymbolicContext] = {}
inner_contexts: Dict[
str, torch.fx.experimental.symbolic_shapes.SymbolicContext
] = {}
for attr, inner in t.attrs.items():
assert isinstance(attr, str)
inner_contexts[attr] = all_dynamic_symbolic_context(
@@ -951,8 +1051,12 @@ class MetaConverter:
# Then view replay is done, swapping in the fake offsets so the view replay output
# is fully fake with no invalid specialization.
def view_from_base(
base: torch.Tensor, t: MetaTensorDesc, source=source, shape_env=shape_env
):
base: _TensorT,
t: MetaTensorDesc,
shape_env: Optional[
torch.fx.experimental.symbolic_shapes.ShapeEnv
] = shape_env,
) -> _TensorT:
# fake-ify t's metadata according to the outer symbolic context
(sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
t, source
@@ -965,7 +1069,9 @@ class MetaConverter:
# TODO: Change this logic to use view replay for consistency?
# It's likely there is no view func available.
with maybe_suppress():
return base.as_strided(sizes, strides, storage_offset)
return self._checked_cast_tensor_t(
base.as_strided(sizes, strides, storage_offset)
)

from torch._dynamo.source import EphemeralSource
from torch.fx.experimental.symbolic_shapes import (
@@ -973,7 +1079,7 @@ class MetaConverter:
sym_eq,
)

def symint_visitor_fn(s):
def symint_visitor_fn(s: int) -> int:
nonlocal symbolic_context
from torch.fx.experimental.symbolic_shapes import DimDynamic
@@ -1017,10 +1123,10 @@ class MetaConverter:
# want a view of values with the offsets closed over. As the offsets component
# is needed to describe the output view, it's important that it's fakeified
# correctly.
fake_t = empty_create_subclass(
fake_t: _TensorT = empty_create_subclass(
t, outer_size=sizes, outer_stride=strides
)
attrs, _ = fake_t.__tensor_flatten__()
attrs, _ = fake_t.__tensor_flatten__() # type: ignore[attr-defined]
for attr in attrs:
real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr)

@@ -1028,9 +1134,11 @@ class MetaConverter:
visited_t: torch.Tensor,
# These arguments are never passed, we just use them to close
# over these relevant values
shape_env=shape_env,
callback=callback,
):
shape_env: Optional[
torch.fx.experimental.symbolic_shapes.ShapeEnv
] = shape_env,
callback: Callable[[Callable[[], torch.Tensor]], _TensorT] = callback, # type: ignore[assignment]
) -> torch.Tensor:
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
if visited_t is None:
return None
@@ -1057,8 +1165,8 @@ class MetaConverter:
visited_desc,
shape_env,
callback,
source=temp_source,
symbolic_context=all_dynamic_symbolic_context(
temp_source,
all_dynamic_symbolic_context(
visited_desc, temp_source, shape_env, callback
),
)
@@ -1102,6 +1210,9 @@ class MetaConverter:
# Pray that sparse clone doesn't lose information
assert t.data is not None
with torch.no_grad(), no_dispatch():
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
r.real_tensor = _safe_clone(t.data)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
# Note [is_coalesced is dispatched]
@@ -1109,7 +1220,7 @@ class MetaConverter:
# which means that it will get caught by fake tensor mode.
# Ordinarily this would error, but there's some logic in
# fake tensor ensure this doesn't happen.
r._coalesced_(t.is_coalesced)
r._coalesced_(bool(t.is_coalesced))
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
@@ -1117,9 +1228,9 @@ class MetaConverter:
# but clone is fine for now for sparse tensors.
# (DelayedError does not work for sparse because it causes
# the Fake sparse tensor to "lose" its fakeness)
r = r.clone()
r = self._checked_cast_tensor_t(r.clone())
with torch.enable_grad():
r._coalesced_(t.is_coalesced)
r._coalesced_(bool(t.is_coalesced))
elif is_sparse_compressed_layout(t.layout):
is_leaf = t.is_leaf

@@ -1154,15 +1265,15 @@ class MetaConverter:
# Pray sparse clone doesn't lose information
assert t.data is not None
with torch.no_grad(), no_dispatch():
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
r.real_tensor = _safe_clone(t.data)
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
r = torch._C._functions.DelayedError(
"Internal error: Tried to backward() through example input",
1,
)(r)
r = self._backward_error(r)
elif t.is_nested and not t.is_traceable_wrapper_subclass:
# TODO: Handle this better in Dynamo?
# There are checks there now, but this can still be triggered by a dense
@@ -1174,9 +1285,11 @@ class MetaConverter:
)
elif t.is_mkldnn:
is_leaf = t.is_leaf
sizes, strides, _storage_offset = sym_sizes_strides_storage_offset(
t, source
)
(
sizes,
strides,
_storage_offset,
) = sym_sizes_strides_storage_offset(t, source)
# TODO: This doesn't seem right, where's the MKLDNN'ness
# lol
r = callback(
@@ -1188,6 +1301,9 @@ class MetaConverter:
with torch.no_grad(), no_dispatch():
assert t.size is not None
assert t.stride is not None
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
r.real_tensor = torch.empty_strided(
t.size, t.stride, dtype=t.dtype, device=t.device
)
@@ -1197,10 +1313,7 @@ class MetaConverter:
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
r = torch._C._functions.DelayedError(
"Internal error: Tried to backward() through example input",
1,
)(r)
r = self._backward_error(r)
elif t.is_functorch_wrapped:
if t.is_view:
from torch._dynamo.exc import unimplemented
@@ -1211,9 +1324,10 @@ class MetaConverter:

# Wraps a functorch tensor class (BatchedTensor, GradTrackingTensor)
# in a FakeTensor
def _to_fake_tensor(t: MetaTensorDesc):
def _to_fake_tensor(t: MetaTensorDesc) -> _TensorT:
# TODO: why aren't the recursive calls going to
# meta_tensor
r: _TensorT
if t.is_batchedtensor:
assert t.unwrapped is not None
assert t.level is not None
@@ -1228,7 +1342,9 @@ class MetaConverter:
with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
t.functorch_stack
):
r = _add_batch_dim(ft, bdim, lvl)
r = self._checked_cast_tensor_t(
_add_batch_dim(ft, bdim, lvl)
)
elif t.is_gradtrackingtensor:
assert t.unwrapped is not None
assert t.level is not None
@@ -1242,33 +1358,32 @@ class MetaConverter:
with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
t.functorch_stack
):
r = torch._C._functorch._wrap_for_grad(ft, lvl)
r = self._checked_cast_tensor_t(
torch._C._functorch._wrap_for_grad(ft, lvl),
)

is_leaf = t.is_leaf
if t.requires_grad and safe_is_leaf(r):
r.requires_grad = True
elif t.requires_grad and not is_leaf:
r = torch._C._functions.DelayedError( # type: ignore[assignment]
"Internal error: Tried to backward() through example input",
1,
)(
r # type: ignore[arg-type]
)
r = self._backward_error(r)
elif t.is_functional:
assert t.unwrapped is not None
assert t.current_level is not None
ft = self.meta_tensor(
t.unwrapped,
shape_env=shape_env,
callback=callback,
shape_env,
callback,
# NB: reuse these exactly, we treat the
# functional tensor as "invisible".
# TODO: Actually this all probably doesn't
# work, take a closer look.
source=source,
symbolic_context=symbolic_context,
source,
symbolic_context,
)
r = self._checked_cast_tensor_t(
_wrap_functional_tensor(ft, t.current_level),
)
r = _wrap_functional_tensor(ft, t.current_level)
# TODO: is_leaf/requires_grad?
else:
assert t.stride is not None
@@ -1302,12 +1417,14 @@ class MetaConverter:
assert not t.is_functorch_wrapped # handled above
unwrapped = self.meta_tensor(
t.unwrapped,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
shape_env,
callback,
source,
symbolic_context,
)
r = self._checked_cast_tensor_t(
torch._to_functional_tensor(unwrapped)
)
r = torch._to_functional_tensor(unwrapped)
torch._mirror_autograd_meta_to(t.autograd_meta_from, r) # type: ignore[attr-defined]

elif t.is_view:
@@ -1335,11 +1452,13 @@ class MetaConverter:
t.base,
shape_env,
callback,
source=torch._dynamo.source.AttrSource(source, "_base"),
symbolic_context=base_symbolic_context,
torch._dynamo.source.AttrSource(source, "_base"),
base_symbolic_context,
)

def is_c_of_r(complex_dtype, real_dtype):
def is_c_of_r(
complex_dtype: torch.dtype, real_dtype: torch.dtype
) -> bool:
return (
utils.is_complex_dtype(complex_dtype)
and utils.corresponding_real_dtype(complex_dtype)
@@ -1361,14 +1480,16 @@ class MetaConverter:
if base.dtype == t.dtype:
pass
elif is_c_of_r(base.dtype, t.dtype):
base = torch.view_as_real(base)
base = self._checked_cast_tensor_t(torch.view_as_real(base))
elif is_c_of_r(t.dtype, base.dtype):
base = torch.view_as_complex(base)
base = self._checked_cast_tensor_t(
torch.view_as_complex(base)
)
else:
# This is not guaranteed to succeed. If it fails, it
# means there is another dtype-converting view function
# that hasn't been handled here
base = base.view(t.dtype)
base = self._checked_cast_tensor_t(base.view(t.dtype))

# This is very tricky. Naively, you might expect this
# to hold:
@@ -1410,7 +1531,9 @@ class MetaConverter:
# NB: Can't have a non-leaf without requiring grad!
assert t.requires_grad
with torch.no_grad():
mid = base.view(base.shape)
mid = self._checked_cast_tensor_t(
base.view(base.shape)
)
mid.requires_grad = t.requires_grad
with torch.enable_grad():
r = view_from_base(mid, t)
@@ -1459,6 +1582,9 @@ class MetaConverter:
with torch.no_grad(), no_dispatch():
assert t.size is not None
assert t.stride is not None
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
r.real_tensor = torch.empty_strided(
t.size, t.stride, dtype=t.dtype, device=t.device
)
@@ -1477,10 +1603,7 @@ class MetaConverter:
# the metadata of the inner tensor.
# So instead, we now have a dedicated fn to set autograd history,
# without inadvertently changing other metadata.
r = torch._C._functions.DelayedError(
"Internal error: Tried to backward() through example input",
1,
)(r)
r = self._backward_error(r)

s = t.storage
assert s is not None
@@ -1494,8 +1617,12 @@ class MetaConverter:
# You're normal and happy, install the fresh storage into the memo
self.set_storage_memo(s, r.untyped_storage())
if self.copy_data:
r.untyped_storage().real_storage = (
r.real_tensor.untyped_storage()
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
assert r.real_tensor is not None
_set_real_storage(
r.untyped_storage(), r.real_tensor.untyped_storage()
)
else:
# You're in crazy town; somehow you gave us a tensor
@@ -1540,8 +1667,13 @@ class MetaConverter:
r.set_(r_s, storage_offset, sizes, strides)
if self.copy_data:
with torch.no_grad(), no_dispatch():
assert isinstance(
r, torch._subclasses.fake_tensor.FakeTensor
)
assert r.real_tensor is not None
assert t.stride is not None
r.real_tensor.set_(
r_s.real_storage,
_get_real_storage(r_s),
t.storage_offset,
t.size,
t.stride,
@@ -1556,8 +1688,8 @@ class MetaConverter:
t.grad,
shape_env,
callback,
source=AttrSource(source, "grad"),
symbolic_context=symbolic_context,
AttrSource(source, "grad"),
symbolic_context,
)
torch._C._set_conj(r, t.is_conj)
torch._C._set_neg(r, t.is_neg)
@@ -1577,27 +1709,33 @@ class MetaConverter:

# See Note: [Creating symbolic nested int]
if t.nested_int is not None:
assert isinstance(r, torch._subclasses.fake_tensor.FakeTensor)
r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
nt_tensor_id=t.nested_int
)

self.set_tensor_memo(t, r)

return self.get_tensor_memo(t)
return self._checked_get_tensor_memo(t)

def __call__(
self,
t,
shape_env=None,
t: torch.Tensor,
shape_env: Optional[ShapeEnv] = None,
*,
callback=lambda t: t(),
source=None,
symbolic_context=None,
callback: Optional[Callable[[Callable[[], torch.Tensor]], _TensorT]] = None,
source: Optional[Source] = None,
symbolic_context: Optional[SymbolicContext] = None,
# Controls whether or not we should dump the tensor metadata to structured logs
# when source is not None. Because we refakify after Dynamo is done,
# we don't want to dump info again from AOTAutograd, it is redundant.
trace=True,
):
trace: bool = True,
) -> _TensorT:
callback_: Callable[[Callable[[], torch.Tensor]], _TensorT]
if callback is None:
callback_ = self._identity_callable
else:
callback_ = callback
# TODO: zero tensors? We appear to have eliminated them by
# excluding complex for now

@@ -1637,6 +1775,7 @@ class MetaConverter:
t_desc = self.describer.describe_tensor(t, trace=trace)

if trace:
assert source is not None
trace_structured(
"describe_source",
metadata_fn=lambda: {
@@ -1659,10 +1798,10 @@ class MetaConverter:

r = self.meta_tensor(
t_desc,
shape_env=shape_env,
callback=callback,
source=source,
symbolic_context=symbolic_context,
shape_env,
callback_,
source,
symbolic_context,
)

if type(t) is torch.nn.Parameter: