typing fake_tensor.py (#128041)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128041
Approved by: https://github.com/eellison
ghstack dependencies: #129182
This commit is contained in:
Aaron Orenstein
2024-07-12 08:19:16 -07:00
committed by PyTorch MergeBot
parent 1ad0f38a37
commit 567482973d
14 changed files with 397 additions and 223 deletions

View File

@ -1204,7 +1204,7 @@ def gen_pyi(
],
"set_": [
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
"offset: _int, size: _size, stride: _size) -> Tensor: ...",
"offset: _int, size: _symsize, stride: _symsize) -> Tensor: ...",
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
],
"split": [

View File

@ -56,6 +56,7 @@ from torch.types import (
_qscheme,
_size,
_str,
_symsize,
)
from torch.utils._python_dispatch import TorchDispatchMode
@ -1661,6 +1662,18 @@ class _SetExcludeDispatchKeyGuard:
def __enter__(self): ...
def __exit__(self, exc_type, exc_value, traceback): ...
# Defined in torch/csrc/utils/schema_info.h
class _SchemaInfo:
def __init__(self, schema: _int) -> None: ...
@overload
def is_mutable(self) -> _bool: ...
@overload
def is_mutable(self, name: str) -> _bool: ...
def has_argument(self, name: str) -> _bool: ...
# Defined in torch/csrc/utils/init.cpp
class BenchmarkConfig:
num_calling_threads: _int

View File

@ -36,6 +36,9 @@ from typing import (
)
from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
if TYPE_CHECKING:
from .types import IntLikeType
# multipy/deploy is setting this import before importing torch, this is the most
# reliable way we have to detect if we're running within deploy.
@ -471,6 +474,9 @@ class SymInt:
def __add__(self, other) -> "SymInt":
raise TypeError("type stub not overridden")
def __mod__(self, other: "IntLikeType") -> "SymInt":
raise TypeError("type stub not overridden")
def __mul__(self, other) -> "SymInt":
raise TypeError("type stub not overridden")
@ -504,6 +510,9 @@ class SymInt:
def __neg__(self):
raise TypeError("type stub not overridden")
def __sub__(self, other: "IntLikeType") -> "SymInt":
raise TypeError("type stub not overridden")
def __repr__(self):
return self.node._graph_repr()

View File

@ -165,6 +165,7 @@ def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
assert is_traceable_wrapper_subclass(t)
attrs, ctx = t.__tensor_flatten__()
assert isinstance(t, torch.Tensor)
for attr in attrs:
inner = getattr(t, attr)
if inner.dim() == t.dim():

View File

@ -83,6 +83,7 @@ def fakify(
constraint_sizes=[None] * n_dims,
)
t_id = id(t)
assert mode.shape_env is not None
if t_id in t_constraints:
for i, constraint in t_constraints[t_id].items():
symbolic_context.constraint_sizes[i] = constraint.constraint_range
@ -256,6 +257,7 @@ def produce_guards_and_solve_constraints(
_disable_forced_specializations: if True, avoids forced specializations
"""
shape_env = fake_mode.shape_env
assert shape_env is not None
assert shape_env.tracked_fakes is not None
placeholders = [tf.fake for tf in shape_env.tracked_fakes]
@ -322,6 +324,7 @@ def make_constraints(
"""
shape_env = fake_mode.shape_env
assert shape_env is not None
inline_constraints = gm.meta.get("inline_constraints", [])
range_constraints = {
symbol: inline_constraints[symbol] for symbol in inline_constraints

View File

@ -12,7 +12,7 @@ import pprint
from contextlib import nullcontext
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
import torch
import torch.utils.dlpack
@ -1450,7 +1450,7 @@ Expected metadata: {str(expected_tangent_metadata)}
Runtime metadata: {str(runtime_tangent_metadata)}
shape: {str(x.shape)}
shape: {str(cast(torch.Tensor, x).shape)}
To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__.
"""
)
@ -1830,14 +1830,16 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
)
assert CompiledFunction.metadata.traced_tangent_metas is not None
all_args = [
AOTDispatchAutograd.coerce_runtime_tangent(
t,
CompiledFunction.metadata.traced_tangent_metas[
i - tangents_start_idx
],
(
AOTDispatchAutograd.coerce_runtime_tangent(
t,
CompiledFunction.metadata.traced_tangent_metas[
i - tangents_start_idx
],
)
if tangents_start_idx <= i < tangents_end_idx
else t
)
if tangents_start_idx <= i < tangents_end_idx
else t
for i, t in enumerate(all_args)
]
all_args = unwrap_tensor_subclasses(
@ -1849,9 +1851,11 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
# Make the tangents contiguous. Note that we must do this after subclass desugaring
# because inputs to inductor have to be contiguous
all_args = [
AOTDispatchAutograd._force_contiguous(t)
if (tangents_start_idx <= i < tangents_end_idx)
else t
(
AOTDispatchAutograd._force_contiguous(t)
if (tangents_start_idx <= i < tangents_end_idx)
else t
)
for i, t in enumerate(all_args)
]

View File

@ -5,6 +5,7 @@ AOTAutograd's responsibility is to trace through all pytorch capabilities that l
and this includes tensor subclasses that implement __torch_dispatch__.
"""
import typing
from typing import Any, List, Optional, Tuple, Union
import torch.utils._pytree as pytree
@ -115,7 +116,7 @@ def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
xs_inner = []
for x in xs:
if is_traceable_wrapper_subclass(x):
xs_inner.extend(get_plain_tensors(x))
xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
else:
xs_inner.append(x)
return xs_inner

View File

@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
functionalize_rng_ops = False
# can be useful for debugging if we are incorrectly creating meta fake tensors
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"
# Enables optional asserts in hotpath code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
@ -24,7 +24,7 @@ fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
# but it is on by default for aot_eager.
debug_assert = False
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False)
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"
# Today, if you are in a situation where there is "false aliasing"
# (e.g. you have a bunch of model parameters that all alias the same underlying buffer),

File diff suppressed because it is too large Load Diff

View File

@ -140,7 +140,8 @@ def _move_states_to_device(
raise AssertionError(
f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
)
if is_traceable_wrapper_subclass(tensor):
tensor_ = tensor
if is_traceable_wrapper_subclass(tensor_):
with torch.no_grad(): # avoid autograd increasing C++ refcount by 1
tensor_on_device = nn.Parameter(tensor.to(device))
torch.utils.swap_tensors(tensor, tensor_on_device)

View File

@ -1700,6 +1700,7 @@ def _export_for_training(
# The unbacked symint symbols are updated in aot_export
# so we serialize them here instead of inside dynamo.
assert fake_mode.shape_env is not None
gm.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()
@ -1884,6 +1885,7 @@ def _export(
# The unbacked symint symbols are updated in aot_export
# so we serialize them here instead of inside dynamo.
assert fake_mode.shape_env is not None
gm.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()

View File

@ -1649,6 +1649,7 @@ class _MakefxTracer:
return self.fake_tensor_mode.from_tensor(x, source=source)
# NB: don't match on bools
elif type(x) is int and self.tracing_mode == "symbolic":
assert self.fake_tensor_mode.shape_env is not None, "shape_env should be set if tracing with 'symbolic'"
return self.fake_tensor_mode.shape_env.create_symintnode(
self.fake_tensor_mode.shape_env.create_symbol(x, source, positive=None),
hint=x,

View File

@ -17,6 +17,7 @@ from builtins import ( # noqa: F401
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import torch
from torch import SymInt
if TYPE_CHECKING:
@ -40,6 +41,7 @@ _device = torch.device
_qscheme = torch.qscheme
_layout = torch.layout
_size = Union[torch.Size, List[builtins.int], Tuple[builtins.int, ...]]
_symsize = Union[torch.Size, Sequence[Union[_int, SymInt]]]
_dispatchkey = Union[builtins.str, torch._C.DispatchKey]
# int or SymInt

View File

@ -3,7 +3,7 @@ import contextlib
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Sequence, Tuple, overload
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload
from typing_extensions import TypeGuard
import torch