mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
typing fake_tensor.py (#128041)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128041 Approved by: https://github.com/eellison ghstack dependencies: #129182
This commit is contained in:
committed by
PyTorch MergeBot
parent
1ad0f38a37
commit
567482973d
@ -1204,7 +1204,7 @@ def gen_pyi(
|
||||
],
|
||||
"set_": [
|
||||
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
|
||||
"offset: _int, size: _size, stride: _size) -> Tensor: ...",
|
||||
"offset: _int, size: _symsize, stride: _symsize) -> Tensor: ...",
|
||||
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
|
||||
],
|
||||
"split": [
|
||||
|
@ -56,6 +56,7 @@ from torch.types import (
|
||||
_qscheme,
|
||||
_size,
|
||||
_str,
|
||||
_symsize,
|
||||
)
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
|
||||
@ -1661,6 +1662,18 @@ class _SetExcludeDispatchKeyGuard:
|
||||
def __enter__(self): ...
|
||||
def __exit__(self, exc_type, exc_value, traceback): ...
|
||||
|
||||
# Defined in torch/csrc/utils/schema_info.h
|
||||
|
||||
class _SchemaInfo:
|
||||
def __init__(self, schema: _int) -> None: ...
|
||||
|
||||
@overload
|
||||
def is_mutable(self) -> _bool: ...
|
||||
@overload
|
||||
def is_mutable(self, name: str) -> _bool: ...
|
||||
|
||||
def has_argument(self, name: str) -> _bool: ...
|
||||
|
||||
# Defined in torch/csrc/utils/init.cpp
|
||||
class BenchmarkConfig:
|
||||
num_calling_threads: _int
|
||||
|
@ -36,6 +36,9 @@ from typing import (
|
||||
)
|
||||
from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .types import IntLikeType
|
||||
|
||||
|
||||
# multipy/deploy is setting this import before importing torch, this is the most
|
||||
# reliable way we have to detect if we're running within deploy.
|
||||
@ -471,6 +474,9 @@ class SymInt:
|
||||
def __add__(self, other) -> "SymInt":
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
def __mod__(self, other: "IntLikeType") -> "SymInt":
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
def __mul__(self, other) -> "SymInt":
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
@ -504,6 +510,9 @@ class SymInt:
|
||||
def __neg__(self):
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
def __sub__(self, other: "IntLikeType") -> "SymInt":
|
||||
raise TypeError("type stub not overridden")
|
||||
|
||||
def __repr__(self):
|
||||
return self.node._graph_repr()
|
||||
|
||||
|
@ -165,6 +165,7 @@ def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
|
||||
assert is_traceable_wrapper_subclass(t)
|
||||
|
||||
attrs, ctx = t.__tensor_flatten__()
|
||||
assert isinstance(t, torch.Tensor)
|
||||
for attr in attrs:
|
||||
inner = getattr(t, attr)
|
||||
if inner.dim() == t.dim():
|
||||
|
@ -83,6 +83,7 @@ def fakify(
|
||||
constraint_sizes=[None] * n_dims,
|
||||
)
|
||||
t_id = id(t)
|
||||
assert mode.shape_env is not None
|
||||
if t_id in t_constraints:
|
||||
for i, constraint in t_constraints[t_id].items():
|
||||
symbolic_context.constraint_sizes[i] = constraint.constraint_range
|
||||
@ -256,6 +257,7 @@ def produce_guards_and_solve_constraints(
|
||||
_disable_forced_specializations: if True, avoids forced specializations
|
||||
"""
|
||||
shape_env = fake_mode.shape_env
|
||||
assert shape_env is not None
|
||||
assert shape_env.tracked_fakes is not None
|
||||
|
||||
placeholders = [tf.fake for tf in shape_env.tracked_fakes]
|
||||
@ -322,6 +324,7 @@ def make_constraints(
|
||||
"""
|
||||
|
||||
shape_env = fake_mode.shape_env
|
||||
assert shape_env is not None
|
||||
inline_constraints = gm.meta.get("inline_constraints", [])
|
||||
range_constraints = {
|
||||
symbol: inline_constraints[symbol] for symbol in inline_constraints
|
||||
|
@ -12,7 +12,7 @@ import pprint
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.dlpack
|
||||
@ -1450,7 +1450,7 @@ Expected metadata: {str(expected_tangent_metadata)}
|
||||
|
||||
Runtime metadata: {str(runtime_tangent_metadata)}
|
||||
|
||||
shape: {str(x.shape)}
|
||||
shape: {str(cast(torch.Tensor, x).shape)}
|
||||
To fix this, your tensor subclass must implement the dunder method __force_to_same_metadata__.
|
||||
"""
|
||||
)
|
||||
@ -1830,6 +1830,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
||||
)
|
||||
assert CompiledFunction.metadata.traced_tangent_metas is not None
|
||||
all_args = [
|
||||
(
|
||||
AOTDispatchAutograd.coerce_runtime_tangent(
|
||||
t,
|
||||
CompiledFunction.metadata.traced_tangent_metas[
|
||||
@ -1838,6 +1839,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
||||
)
|
||||
if tangents_start_idx <= i < tangents_end_idx
|
||||
else t
|
||||
)
|
||||
for i, t in enumerate(all_args)
|
||||
]
|
||||
all_args = unwrap_tensor_subclasses(
|
||||
@ -1849,9 +1851,11 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
||||
# Make the tangents contiguous. Note that we must do this after subclass desugaring
|
||||
# because inputs to inductor have to be contiguous
|
||||
all_args = [
|
||||
(
|
||||
AOTDispatchAutograd._force_contiguous(t)
|
||||
if (tangents_start_idx <= i < tangents_end_idx)
|
||||
else t
|
||||
)
|
||||
for i, t in enumerate(all_args)
|
||||
]
|
||||
|
||||
|
@ -5,6 +5,7 @@ AOTAutograd's responsibility is to trace through all pytorch capabilities that l
|
||||
and this includes tensor subclasses that implement __torch_dispatch__.
|
||||
"""
|
||||
|
||||
import typing
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch.utils._pytree as pytree
|
||||
@ -115,7 +116,7 @@ def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
|
||||
xs_inner = []
|
||||
for x in xs:
|
||||
if is_traceable_wrapper_subclass(x):
|
||||
xs_inner.extend(get_plain_tensors(x))
|
||||
xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x)))
|
||||
else:
|
||||
xs_inner.append(x)
|
||||
return xs_inner
|
||||
|
@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
|
||||
functionalize_rng_ops = False
|
||||
|
||||
# can be useful for debugging if we are incorrectly creating meta fake tensors
|
||||
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
|
||||
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"
|
||||
|
||||
# Enables optional asserts in hotpath code to check for errors. If
|
||||
# you are seeing weird accuracy problems, try turning this on.
|
||||
@ -24,7 +24,7 @@ fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", True)
|
||||
# but it is on by default for aot_eager.
|
||||
debug_assert = False
|
||||
|
||||
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", False)
|
||||
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"
|
||||
|
||||
# Today, if you are in a situation where there is "false aliasing"
|
||||
# (e.g. you have a bunch of model parameters that all alias the same underlying buffer),
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -140,7 +140,8 @@ def _move_states_to_device(
|
||||
raise AssertionError(
|
||||
f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}"
|
||||
)
|
||||
if is_traceable_wrapper_subclass(tensor):
|
||||
tensor_ = tensor
|
||||
if is_traceable_wrapper_subclass(tensor_):
|
||||
with torch.no_grad(): # avoid autograd increasing C++ refcount by 1
|
||||
tensor_on_device = nn.Parameter(tensor.to(device))
|
||||
torch.utils.swap_tensors(tensor, tensor_on_device)
|
||||
|
@ -1700,6 +1700,7 @@ def _export_for_training(
|
||||
|
||||
# The unbacked symint symbols are updated in aot_export
|
||||
# so we serialize them here instead of inside dynamo.
|
||||
assert fake_mode.shape_env is not None
|
||||
gm.meta["inline_constraints"] = {
|
||||
k: v
|
||||
for k, v in fake_mode.shape_env.var_to_range.items()
|
||||
@ -1884,6 +1885,7 @@ def _export(
|
||||
|
||||
# The unbacked symint symbols are updated in aot_export
|
||||
# so we serialize them here instead of inside dynamo.
|
||||
assert fake_mode.shape_env is not None
|
||||
gm.meta["inline_constraints"] = {
|
||||
k: v
|
||||
for k, v in fake_mode.shape_env.var_to_range.items()
|
||||
|
@ -1649,6 +1649,7 @@ class _MakefxTracer:
|
||||
return self.fake_tensor_mode.from_tensor(x, source=source)
|
||||
# NB: don't match on bools
|
||||
elif type(x) is int and self.tracing_mode == "symbolic":
|
||||
assert self.fake_tensor_mode.shape_env is not None, "shape_env should be set if tracing with 'symbolic'"
|
||||
return self.fake_tensor_mode.shape_env.create_symintnode(
|
||||
self.fake_tensor_mode.shape_env.create_symbol(x, source, positive=None),
|
||||
hint=x,
|
||||
|
@ -17,6 +17,7 @@ from builtins import ( # noqa: F401
|
||||
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
from torch import SymInt
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -40,6 +41,7 @@ _device = torch.device
|
||||
_qscheme = torch.qscheme
|
||||
_layout = torch.layout
|
||||
_size = Union[torch.Size, List[builtins.int], Tuple[builtins.int, ...]]
|
||||
_symsize = Union[torch.Size, Sequence[Union[_int, SymInt]]]
|
||||
_dispatchkey = Union[builtins.str, torch._C.DispatchKey]
|
||||
|
||||
# int or SymInt
|
||||
|
@ -3,7 +3,7 @@ import contextlib
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Sequence, Tuple, overload
|
||||
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
import torch
|
||||
|
Reference in New Issue
Block a user