FakeTensor cache SymInt support (#127596)

Adds support for SymInts in the FakeTensor cache.

A few notes:
1. When a SymInt is present in the input key for a FakeTensor operation we cache on the ShapeEnv instead of the global FakeTensorMode cache. This is necessary so we don't have to remember and re-check the guards the operation depends on. It reduces the number of cache hits, but there are diminishing returns on how much work we can do before the cache becomes more of a burden than a gain.
2. We need to be careful that when a cached output SymInt is a direct copy of an input, a cache hit copies the SymNode from the input to the output. This is important because the fx-graph building code uses SymNode ids while constructing the graph, so building a same-content-but-different-id SymNode will fail.
3. In the cache key we store SymInts as a `_PySymInputStub`. These represent SymInt (and friends) but support `__hash__` and `__eq__` (which SymInt does not); see the sketch after this list.
4. In the cache entry we store SymInts as a `_SymIntOutputStub`.
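
For intuition, here is a minimal sketch (illustrative names only, not the PR's actual classes) of why the input stub is needed: `SymInt.__eq__` produces a symbolic expression rather than a plain bool, so SymInts can't be used directly in a hashed cache key. The stub wraps the SymInt and delegates hashing and equality to the SymNode's contents via the `_value_hash()` / `_value_eq()` helpers this PR adds to `SymNode`:
```python
import torch

# Illustrative sketch only; the real implementation is _PySymInputStub in
# torch/_subclasses/_fake_tensor_utils.py and also handles back-references
# and ShapeEnv stripping.
class SymIntKeyStub:
    """Wrap a SymInt so a cache key can hash and compare it by value."""

    def __init__(self, sym_int: torch.SymInt) -> None:
        # Keep the original SymInt so a cache hit can hand its SymNode back
        # to the output (see note 2 above).
        self.sym_int = sym_int

    def __hash__(self) -> int:
        # Hash the node's contents (expr, pytype, hint, ...) rather than the
        # SymInt object's identity.
        return self.sym_int.node._value_hash()

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SymIntKeyStub) and self.sym_int.node._value_eq(
            other.sym_int.node
        )
```
The real stub additionally replaces repeated SymNodes with back-references into the key and strips the ShapeEnv before the entry is stored, so the ShapeEnv's per-instance cache doesn't create reference cycles back to the ShapeEnv itself.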

Perf example:
```
python benchmarks/dynamo/timm_models.py --ci --accuracy --timing
--explain --inductor --dynamic-shapes --dynamic-batch-only --device cuda
--training --amp --total-partitions 2 --partition-id 0 --output
/tmp/training_timm_models.csv --filter crossvit_9_240
```
FakeTensor cache stats before:
```
INFO: FakeTensor cache stats:
INFO:   cache_hits: 68137
INFO:   cache_misses: 837
INFO:   cache_bypasses:
INFO:     symbolic shape:            48224
INFO:     CompositeImplicitAutograd: 917
INFO:     non-fake tensor:           70
INFO:     non-FakeTensor output:     62
INFO:     non-builtin:               8
INFO:     dynamic output shape:      1
```
and after:
```
INFO: FakeTensor cache stats:
INFO:   cache_hits: 88187
INFO:   cache_misses: 14233
INFO:   cache_bypasses:
INFO:     CompositeImplicitAutograd: 1037
INFO:     non-FakeTensor output:     602
INFO:     non-fake tensor:           70
INFO:     unsafe view:               36
INFO:     non-builtin:               8
INFO:     dynamic output shape:      1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127596
Approved by: https://github.com/eellison
ghstack dependencies: #131014, #129780
Author: Aaron Orenstein
Date: 2024-07-18 11:29:14 -07:00
Committed by: PyTorch MergeBot
Commit: b193894b94 (parent: ebce85172e)
14 changed files with 484 additions and 102 deletions


@ -9856,7 +9856,7 @@ ShapeEnv not equal: field values don't match:
if not forward_deterministic and backward_deterministic:
with self.assertRaisesRegex(
RuntimeError,
"^This compiled backward function is being run with torch\.use_deterministic_algorithms",
r"^This compiled backward function is being run with torch\.use_deterministic_algorithms",
):
res.backward(grad)


@ -31,6 +31,7 @@ from torch._subclasses.fake_tensor import (
FakeTensorMode,
unset_fake_temporarily,
UnsupportedOperatorException,
_CacheKeyState
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
@ -1611,9 +1612,10 @@ class FakeTensorDispatchCache(TestCase):
cache keys for inputs x and y are the same, but z is different.
"""
func = aten.add.Tensor
key_x = fm._cache_key(func, [x], {})
key_y = fm._cache_key(func, [y], {})
key_z = fm._cache_key(func, [z], {})
state = _CacheKeyState()
key_x = fm._cache_key(state, func, [x], {})
key_y = fm._cache_key(state, func, [y], {})
key_z = fm._cache_key(state, func, [z], {})
self.assertEqual(key_x, key_y)
self.assertNotEqual(key_x, key_z)


@ -1186,7 +1186,7 @@ def gen_pyi(
"is_mkldnn": ["is_mkldnn: _bool"],
"is_vulkan": ["is_vulkan: _bool"],
"is_ipu": ["is_ipu: _bool"],
"storage_offset": ["def storage_offset(self) -> _int: ..."],
"storage_offset": ["def storage_offset(self) -> Union[_int, SymInt]: ..."],
"to": [
(
f"def to(self, {args}, non_blocking: _bool = False, copy: _bool = False, *, "
@ -1204,7 +1204,7 @@ def gen_pyi(
],
"set_": [
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
"offset: _int, size: _symsize, stride: _symsize) -> Tensor: ...",
"offset: IntLikeType, size: _symsize, stride: _symsize) -> Tensor: ...",
"def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
],
"split": [


@ -44,6 +44,7 @@ from torch.types import (
Device,
Number,
Storage,
IntLikeType,
_bool,
_bytes,
_complex,


@ -3,7 +3,8 @@ from dataclasses import dataclass
from typing import Union
import torch
from torch.fx.experimental.proxy_tensor import py_sym_types, SymBool, SymFloat, SymInt
from torch import SymBool, SymFloat, SymInt
from torch.types import py_sym_types
@dataclass


@ -13,6 +13,7 @@ DISTRIBUTED = [
]
register_log("dynamo", ["torch._dynamo", *DYNAMIC])
register_log("fake_tensor", ["torch._subclasses.fake_tensor"])
register_log("aot", ["torch._functorch.aot_autograd", "torch._functorch._aot_autograd"])
register_log("autograd", "torch.autograd")
register_log("inductor", ["torch._inductor", "torch._inductor.cudagraph_trees"])


@ -1835,7 +1835,7 @@ def are_strides_like_channels_last(
for d in dim_order:
if guard_size_oblivious(shape[d] == 0):
return False
if strides[d] < min:
if guard_size_oblivious(strides[d] < min):
return False
if d == 0 and min == strides[1]:
return False


@ -0,0 +1,257 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union
import torch
from torch import SymInt
from torch.fx.experimental.sym_node import SymNode
from torch.types import py_sym_types, PySymType
from torch.utils._backport_slots import dataclass_slots
if TYPE_CHECKING:
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from .fake_tensor import _DispatchCacheKey, _MetadataIntLike
@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymNode:
"""
Represents a SymNode without the associated ShapeEnv
"""
# n.b. keep the same protocol as SymNode
_expr: sympy.Expr
pytype: type
_hint: Optional[Union[int, float, bool]]
constant: Optional[Union[int, float, bool]]
fx_node: torch.fx.Node
@staticmethod
def from_node(node: SymNode) -> _DeconstructedSymNode:
return _DeconstructedSymNode(
node._expr, node.pytype, node._hint, node.constant, node.fx_node
)
def extract(self, shape_env: ShapeEnv) -> SymNode:
return SymNode(
self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
)
def __str__(self) -> str:
return str(self._expr)
def __repr__(self) -> str:
return f"_DeconstructedSymNode{{{self._expr!r}, {self.pytype!r}, {self._hint!r}, {self.constant!r}, {self.fx_node!r}}}"
def __eq__(self, other: object) -> bool:
raise NotImplementedError
def __hash__(self) -> int:
raise NotImplementedError
# _value_eq to match SymNode
def _value_eq(self, other: object) -> bool:
if isinstance(other, (SymNode, _DeconstructedSymNode)):
return (
self._expr == other._expr
and self.pytype == other.pytype
and self._hint == other._hint
and self.constant == other.constant
and self.fx_node == other.fx_node
)
else:
return False
# _value_hash to match SymNode
def _value_hash(self) -> int:
return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymType:
"""
Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
"""
ty: Type[PySymType]
node: _DeconstructedSymNode
@staticmethod
def from_sym_type(value: PySymType) -> _DeconstructedSymType:
return _DeconstructedSymType(type(value), value.node)
def extract(self, shape_env: ShapeEnv) -> PySymType:
return self.ty(self.node.extract(shape_env))
def __str__(self) -> str:
return f"{self.ty}({self.node})"
def __repr__(self) -> str:
return f"_DeconstructedSymType({self.ty}, {self.node!r})"
def __eq__(self, other: object) -> bool:
return NotImplemented
def __hash__(self) -> int:
return NotImplemented
@dataclass_slots
@dataclass(frozen=True)
class _InputBackref:
value: int
@dataclass_slots
@dataclass
class _PySymInputStub:
"""
Represents a SymInt in the cached key. Needed because SymInt doesn't
support __eq__ or __hash__ directly.
"""
# value can be:
# PySymType: This is the 'normal' SymInt value, wrapped so we can use
# hash/eq as value hash/eq (normally SymInt does object
# hash/eq).
# _DeconstructedSymType: This is used when storing the _PySymInputStub in
# the cache to avoid cyclic ShapeEnv references.
# _InputBackref: This is a back-reference to a previous _PySymInputStub in
# the key.
value: Union[PySymType, _DeconstructedSymType, _InputBackref]
def __init__(
self, value: Union[PySymType, _DeconstructedSymType, _InputBackref]
) -> None:
# For inputs (values in the `key`) we need to keep the PySymType intact
# - this way if we need to reuse it as an output we can properly copy
# the original value.
self.value = value
def strip_shape_env(self) -> None:
if isinstance(self.value, py_sym_types):
self.value = _DeconstructedSymType.from_sym_type(self.value)
def extract(self, shape_env: ShapeEnv) -> PySymType:
if isinstance(self.value, _DeconstructedSymType):
return self.value.extract(shape_env)
else:
# We should never see an _InputBackref here - anyone extracting a
# value should be pulling from the original entry (the one this
# backref points at).
assert not isinstance(self.value, _InputBackref)
return self.value
def __str__(self) -> str:
return str(self.value)
def __repr__(self) -> str:
return f"_PySymInputStub({self.value!r})"
def __eq__(self, other: object) -> bool:
if not isinstance(other, _PySymInputStub):
return False
elif isinstance(self.value, _InputBackref) or isinstance(
other.value, _InputBackref
):
return self.value == other.value
else:
return self.value.node._value_eq(other.value.node)
def __hash__(self) -> int:
if isinstance(self.value, _InputBackref):
return hash(self.value)
else:
return self.value.node._value_hash()
@dataclass_slots
@dataclass
class _SymIntOutputStub:
"""
Represents a SymInt in the cached output.
"""
# This is either an `int` which represents the index in the key to copy the
# SymNode from or it's the deconstructed SymNode itself.
value: Union[int, _DeconstructedSymNode]
def __init__(self, value: SymInt, key_path: Optional[int]) -> None:
if key_path is None:
self.value = _DeconstructedSymNode.from_node(value.node)
else:
self.value = key_path
def extract(self, key: _DispatchCacheKey, shape_env: ShapeEnv) -> SymInt:
if isinstance(self.value, _DeconstructedSymNode):
return SymInt(self.value.extract(shape_env))
else:
src = key.key[self.value]
assert isinstance(src, _PySymInputStub) and isinstance(src.value, SymInt)
return src.value
def __repr__(self) -> str:
return f"_SymIntOutputStub({self.value!r})"
def __eq__(self, other: object) -> bool:
raise NotImplementedError
def __hash__(self) -> int:
raise NotImplementedError
@dataclass_slots
@dataclass
class _CacheKeyState:
"""
State used while building our cache key.
"""
# We track the SymNodes so when we get the output we can see if it exactly
# matches one of the inputs so we can uncache it properly.
sym_node_lookup: Dict[int, int] # id(SymNode) -> index
# There are cases where we're asked to perform an op when we have no
# ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
# ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it
# here.
shape_env: Optional[ShapeEnv]
def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
self.sym_node_lookup = {}
self.shape_env = shape_env
def cache_on_shape_env(self) -> bool:
"""
Returns true if the CacheKey needs to be cached on the ShapeEnv
rather than the global cache.
If our inputs contain a SymNode then we can't cache this operation on
the global cache because the cached output will implicitly depend on
guard values which might not be true on some other ShapeEnv. So unless
we're also going to cache the guards we need to cache this operation on
the ShapeEnv instead of globally.
"""
return bool(self.sym_node_lookup)
def convert_sym_int(self, result: List[object], arg: SymInt) -> None:
node_id = id(arg.node)
if node_id in self.sym_node_lookup:
result.append(_InputBackref(self.sym_node_lookup[node_id]))
else:
self.sym_node_lookup[node_id] = len(result)
if self.shape_env is None:
self.shape_env = arg.node.shape_env
result.append(_PySymInputStub(arg))
def convert_output(self, arg: _MetadataIntLike) -> _MetadataIntLike:
if isinstance(arg, SymInt):
return _SymIntOutputStub(arg, self.sym_node_lookup.get(id(arg.node), None))
else:
return arg


@ -1,5 +1,6 @@
from __future__ import annotations
import atexit
import contextlib
import dataclasses
import functools
@ -34,7 +35,6 @@ from typing_extensions import Self, TypeGuard
from weakref import ReferenceType
import torch
import torch._custom_op
from torch import SymBool, SymFloat, SymInt, Tensor
from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
@ -51,6 +51,7 @@ from torch.fx.immutable_collections import immutable_dict
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
from torch.types import IntLikeType, py_sym_types
from torch.utils._backport_slots import dataclass_slots
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import (
@ -60,6 +61,7 @@ from torch.utils._python_dispatch import (
from torch.utils._pytree import PyTree, tree_map, tree_map_, TreeSpec
from torch.utils._stats import count
from torch.utils._traceback import CapturedTraceback
from ._fake_tensor_utils import _CacheKeyState, _PySymInputStub, _SymIntOutputStub
if TYPE_CHECKING:
from types import TracebackType
@ -67,7 +69,6 @@ if TYPE_CHECKING:
from torch._guards import Source
from torch._ops import OpOverload
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
from torch.types import IntLikeType
log = logging.getLogger(__name__)
@ -880,21 +881,24 @@ class FakeTensor(Tensor):
return out
_MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"]
@dataclass_slots
@dataclass(frozen=True)
@dataclass
class TensorMetadata:
"""
The Tensor metadata relevant to hashing FakeTensors when caching.
"""
dtype: torch.dtype
shape: torch.Size
stride: Tuple[IntLikeType, ...]
shape: Tuple[_MetadataIntLike, ...]
stride: Tuple[_MetadataIntLike, ...]
device: torch.device
layout: torch.layout
memory_format: Optional[torch.memory_format]
storage_offset: int
storage_bytes: Optional[int]
storage_offset: _MetadataIntLike
storage_bytes: Optional[_MetadataIntLike]
requires_grad: bool
is_quantized: bool
is_conj: bool
@ -905,11 +909,22 @@ class TensorMetadata:
dense_dim: Optional[int]
sparse_dim: Optional[int]
def _flatten_into(self, result: List[object], mode: FakeTensorMode) -> None:
def _flatten_into(
self,
result: List[object],
mode: FakeTensorMode,
state: _CacheKeyState,
) -> None:
# Flatten the TensorMetadata out into `result`. Make sure to call
# state.convert_sym_int() on any SymInts.
for field in dataclasses.fields(self):
value = getattr(self, field.name)
if isinstance(value, (tuple, list, torch.Size)):
mode._prep_args_for_hash(result, value)
# This will recursively flatten the iterable, calling
# convert_sym_int() as necessary.
mode._prep_args_for_hash(result, value, state)
elif isinstance(value, SymInt):
state.convert_sym_int(result, value)
else:
result.append(value)
@ -919,64 +934,81 @@ def extract_tensor_metadata(t: Tensor) -> TensorMetadata:
Extract the TensorMetadata of a tensor.
"""
memory_format: Optional[torch.memory_format] = suggest_memory_format(t)
if is_sparse_any(t) or not t.is_contiguous(memory_format=memory_format):
# Don't call is_contiguous() on a Tensor which has symbolic sizes or things
# will go badly (guards will be messed up?)
if (
t._has_symbolic_sizes_strides
or is_sparse_any(t)
or not t.is_contiguous(memory_format=memory_format)
):
memory_format = None
storage_offset = t.storage_offset()
return TensorMetadata(
dtype=t.dtype,
shape=t.shape,
stride=t.stride() if t.layout == torch.strided else (),
device=t.device,
layout=t.layout,
memory_format=memory_format,
storage_offset=storage_offset,
t.dtype,
t.shape,
t.stride() if t.layout == torch.strided else (),
t.device,
t.layout,
memory_format,
storage_offset,
# Only set storage_bytes for tensors that have storage (not sparse)
storage_bytes=t.untyped_storage().nbytes() if not t.is_sparse else None,
requires_grad=t.requires_grad,
is_quantized=t.is_quantized,
is_conj=t.is_conj(),
is_neg=t.is_neg(),
is_inference=t.is_inference(),
is_sparse=t.is_sparse,
is_coalesced=t.is_coalesced() if t.is_sparse else None,
dense_dim=t.dense_dim() if t.is_sparse else None,
sparse_dim=t.sparse_dim() if t.is_sparse else None,
t.untyped_storage().nbytes() if not t.is_sparse else None,
t.requires_grad,
t.is_quantized,
t.is_conj(),
t.is_neg(),
t.is_inference(),
t.is_sparse,
t.is_coalesced() if t.is_sparse else None,
t.dense_dim() if t.is_sparse else None,
t.sparse_dim() if t.is_sparse else None,
)
class _DispatchCacheKey(list):
@dataclass_slots
@dataclass
class _DispatchCacheKey:
"""
Key for the FakeTensor dispatch cache. Inspired by (copied from)
_HashedSeq from the functools.lru_cache implementation.
Key for the FakeTensor dispatch cache.
"""
__slots__ = "hashvalue" # noqa: PLC0205
key: Tuple[object, ...]
hashvalue: int
def __init__(
self, tup: Tuple[object, ...], hash: Callable[[object], int] = hash
) -> None:
self[:] = tup
def __init__(self, tup: Tuple[object, ...]) -> None:
self.key = tup
self.hashvalue = hash(tup)
def __hash__(self) -> int: # type: ignore[override]
def __eq__(self, other: object) -> bool:
return isinstance(other, _DispatchCacheKey) and self.key == other.key
def __hash__(self) -> int:
return self.hashvalue
def strip_shape_env(self) -> None:
# We need to strip the ShapeEnv from any values before we store in the
# cache so the cache doesn't keep our ShapeEnvs alive.
for v in self.key:
if isinstance(v, _PySymInputStub):
v.strip_shape_env()
@dataclass_slots
@dataclass(frozen=True)
class _DispatchCacheEntry:
"""
Entry type for the FakeTensor dispatch cache. Accounts for two possibilities:
1) The op is inplace, and a hit means we need to alias the argument at a given
index. 2) We need to synthesize a new FakeTensor given tensor metadata. For view
ops, we further capture the index of the arg to alias.
1) The op is inplace, and a hit means we need to alias the argument at a
given index.
2) We need to synthesize a new FakeTensor given tensor metadata. For view
ops, we further capture the index of the arg to alias.
"""
inplace_idx: Optional[int] = None
metadata: Optional[TensorMetadata] = None
view_idx: Optional[int] = None
inplace_idx: Optional[int]
metadata: Optional[TensorMetadata]
view_idx: Optional[int]
@dataclass_slots
@ -1233,10 +1265,16 @@ class FakeTensorMode(TorchDispatchMode):
"""
output: object = _UNASSIGNED
try:
key = self._cache_key(func, args, kwargs)
entry = FakeTensorMode.cache.get(key, None)
state = _CacheKeyState(self.shape_env)
key = self._cache_key(state, func, args, kwargs)
if state.cache_on_shape_env():
assert state.shape_env is not None
cache = state.shape_env.fake_tensor_cache
else:
cache = FakeTensorMode.cache
entry = cache.get(key, None)
if entry is not None:
output = self._output_from_cache_entry(entry, func, args)
output = self._output_from_cache_entry(state, entry, key, func, args)
FakeTensorMode.cache_hits += 1
if self.cache_crosscheck_enabled:
# For debugging / testing: Validate that the output synthesized
@ -1245,8 +1283,9 @@ class FakeTensorMode(TorchDispatchMode):
else:
self._validate_cache_key(func, args, kwargs)
output = self._dispatch_impl(func, types, args, kwargs)
entry = self._make_cache_entry(key, func, args, kwargs, output)
FakeTensorMode.cache[key] = entry
entry = self._make_cache_entry(state, key, func, args, kwargs, output)
key.strip_shape_env()
cache[key] = entry
FakeTensorMode.cache_misses += 1
except _BypassDispatchCache as e:
FakeTensorMode.cache_bypasses[e.reason] += 1
@ -1258,6 +1297,7 @@ class FakeTensorMode(TorchDispatchMode):
def _cache_key(
self,
state: _CacheKeyState,
func: OpOverload,
args: Sequence[object],
kwargs: Mapping[str, object],
@ -1284,9 +1324,9 @@ class FakeTensorMode(TorchDispatchMode):
]
# Translate any FakeTensor args to metadata.
if args:
self._prep_args_for_hash(key_values, args)
self._prep_args_for_hash(key_values, args, state)
if kwargs:
self._prep_args_for_hash(key_values, kwargs)
self._prep_args_for_hash(key_values, kwargs, state)
return _DispatchCacheKey(tuple(key_values))
def _validate_cache_key(
@ -1335,6 +1375,7 @@ class FakeTensorMode(TorchDispatchMode):
self,
result: List[object],
args: Union[Mapping[str, object], Sequence[object], Iterable[object]],
state: _CacheKeyState,
) -> None:
"""
Translate the provided args into a form suitable for caching at FakeTensor
@ -1343,16 +1384,14 @@ class FakeTensorMode(TorchDispatchMode):
unsupported cases that should bypass caching.
"""
if isinstance(args, dict):
self._prep_args_for_hash(result, args.keys())
self._prep_args_for_hash(result, args.values())
self._prep_args_for_hash(result, args.keys(), state)
self._prep_args_for_hash(result, args.values(), state)
return
for arg in args:
if isinstance(arg, FakeTensor):
if not self.is_our_fake(arg):
raise _BypassDispatchCache("not our fake")
if arg._has_symbolic_sizes_strides:
raise _BypassDispatchCache("symbolic shape")
if arg.constant is not None:
raise _BypassDispatchCache("constant attribute")
if arg.is_sparse:
@ -1366,18 +1405,18 @@ class FakeTensorMode(TorchDispatchMode):
# Does this subsume arg.is_sparse?
raise _BypassDispatchCache("sparse tensor layout")
# sparse tensors don't have storage, so check is after
if isinstance(arg.untyped_storage().nbytes(), SymInt):
raise _BypassDispatchCache("symbolic nbytes")
if is_sparse_compressed(arg):
raise _BypassDispatchCache("sparse compressed tensor")
metadata = extract_tensor_metadata(arg)
metadata._flatten_into(result, self)
metadata._flatten_into(result, self, state)
elif isinstance(arg, Tensor):
raise _BypassDispatchCache("non-fake tensor")
elif isinstance(arg, (SymBool, SymInt, SymFloat)):
elif isinstance(arg, SymInt):
state.convert_sym_int(result, arg)
elif isinstance(arg, (SymBool, SymFloat)):
raise _BypassDispatchCache("symbolic shape")
elif isinstance(arg, (list, tuple, dict)):
self._prep_args_for_hash(result, arg)
self._prep_args_for_hash(result, arg, state)
else:
# It's important to capture the type of the arg since, e.g., 1 and 1.0
# hash to the same value, but can produce different dtypes for the
@ -1387,6 +1426,7 @@ class FakeTensorMode(TorchDispatchMode):
def _make_cache_entry(
self,
state: _CacheKeyState,
key: _DispatchCacheKey,
func: OpOverload,
args: Sequence[object],
@ -1439,8 +1479,19 @@ class FakeTensorMode(TorchDispatchMode):
view_idx = idxs[0]
metadata = extract_tensor_metadata(output)
metadata.shape = tuple(state.convert_output(v) for v in metadata.shape)
metadata.stride = tuple(state.convert_output(v) for v in metadata.stride)
metadata.storage_offset = state.convert_output(metadata.storage_offset)
metadata.storage_bytes = (
None
if metadata.storage_bytes is None
else state.convert_output(metadata.storage_bytes)
)
entry = _DispatchCacheEntry(
inplace_idx=None, metadata=metadata, view_idx=view_idx
inplace_idx=None,
metadata=metadata,
view_idx=view_idx,
)
# N.B.: Some checks for bypassing the cache would be performed on the
@ -1448,7 +1499,7 @@ class FakeTensorMode(TorchDispatchMode):
# we can synthesize a tensor here and do the checks on that instance.
# This approach keeps the (more frequent) cache-hit path as lightweight
# as possible.
synth_output = self._output_from_cache_entry(entry, func, args)
synth_output = self._output_from_cache_entry(state, entry, key, func, args)
# Make sure the dispatch_key_set from the synthesized output tensor will
# be the same.
@ -1460,7 +1511,12 @@ class FakeTensorMode(TorchDispatchMode):
return entry
def _output_from_cache_entry(
self, entry: _DispatchCacheEntry, func: OpOverload, args: Sequence[object]
self,
state: _CacheKeyState,
entry: _DispatchCacheEntry,
key: _DispatchCacheKey,
func: OpOverload,
args: Sequence[object],
) -> Optional[FakeTensor]:
"""
Create a new FakeTensor from the cache entry.
@ -1478,40 +1534,65 @@ class FakeTensorMode(TorchDispatchMode):
assert not metadata.is_sparse
empty = torch.empty_strided(
metadata.shape,
metadata.stride,
dtype=metadata.dtype,
layout=metadata.layout,
device="meta",
requires_grad=metadata.requires_grad,
def check_value(
value: _MetadataIntLike, state: _CacheKeyState
) -> Union[IntLikeType]:
if isinstance(value, _SymIntOutputStub):
assert state.shape_env is not None
return value.extract(key, state.shape_env)
else:
assert not isinstance(value, _PySymInputStub)
return value
shape = tuple(check_value(v, state) for v in metadata.shape)
stride = tuple(check_value(v, state) for v in metadata.stride)
storage_offset = check_value(metadata.storage_offset, state)
storage_bytes = (
None
if metadata.storage_bytes is None
else check_value(metadata.storage_bytes, state)
)
maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext
if self.shape_env is not None:
maybe_suppress = self.shape_env.suppress_guards
with in_kernel_invocation_manager(self), maybe_suppress():
empty = torch.empty_strided(
shape,
stride,
dtype=metadata.dtype,
layout=metadata.layout,
device="meta",
requires_grad=metadata.requires_grad,
)
if metadata.is_conj:
torch._C._set_conj(empty, True)
if metadata.is_neg:
torch._C._set_neg(empty, True)
maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext
if self.shape_env is not None:
maybe_suppress = self.shape_env.suppress_guards
if func.is_view:
# For view ops, the storage should be the same as the tensor input.
view_arg = args[cast(int, entry.view_idx)]
assert isinstance(view_arg, FakeTensor)
storage = view_arg.untyped_storage()
with in_kernel_invocation_manager(self), maybe_suppress():
empty.set_(
storage, metadata.storage_offset, metadata.shape, metadata.stride
)
elif metadata.storage_offset != 0:
empty.set_(storage, storage_offset, shape, stride)
elif storage_offset != 0:
storage = empty.untyped_storage()
with in_kernel_invocation_manager(self), maybe_suppress():
empty.set_(
storage, metadata.storage_offset, metadata.shape, metadata.stride
)
if metadata.storage_bytes == 0:
empty.set_(storage, storage_offset, shape, stride)
if isinstance(storage_bytes, SymInt):
# Do it this way so we don't import symbolic_shapes (which imports
# expensive sympy) unless we have to.
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
zero_bytes = guard_size_oblivious(storage_bytes == 0)
else:
zero_bytes = storage_bytes == 0
if zero_bytes:
empty.untyped_storage().resize_(0)
return FakeTensor(self, empty, metadata.device)
@ -1729,7 +1810,7 @@ class FakeTensorMode(TorchDispatchMode):
def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
if isinstance(t, FakeTensor):
return t.real_tensor
elif isinstance(t, SymTypes):
elif isinstance(t, py_sym_types):
assert self.shape_env is not None
return t.node.pytype(
t.node.expr.xreplace(self.shape_env.var_to_val).xreplace(
@ -1742,7 +1823,6 @@ class FakeTensorMode(TorchDispatchMode):
from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
free_unbacked_symbols,
SymTypes,
)
nil = object()
@ -1787,7 +1867,7 @@ class FakeTensorMode(TorchDispatchMode):
if isinstance(t, FakeTensor):
# NB: unconditionally overwrite
t.real_tensor = real_t
elif isinstance(t, SymTypes) and free_unbacked_symbols(t):
elif isinstance(t, py_sym_types) and free_unbacked_symbols(t):
if isinstance(t.node.expr, sympy.Symbol):
assert self.shape_env is not None
self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t)
@ -2276,3 +2356,16 @@ from torch._subclasses.fake_impls import ( # noqa: F401
op_implementations_checks,
stride_incorrect_op,
)
@atexit.register
def dump_cache_stats() -> None:
log.info("FakeTensor cache stats:")
log.info(" cache_hits: %s", FakeTensorMode.cache_hits)
log.info(" cache_misses: %s", FakeTensorMode.cache_misses)
bypasses = FakeTensorMode.cache_bypasses
if bypasses:
log.info(" cache_bypasses:")
width = max(len(k) for k in bypasses)
for k, v in sorted(bypasses.items(), key=lambda i: -i[1]):
log.info(" %-*s %s", width + 1, f"{k}:", v)


@ -237,7 +237,7 @@ class MetaTensorDescriber:
# NB: We actually don't use storage to do views, but might as well
# put it in for accuracy
storage = self.describe_storage(t.untyped_storage(), trace=trace)
storage_offset = t.storage_offset()
storage_offset = t.storage_offset() # type: ignore[assignment]
stride = None
if not (


@ -26,7 +26,7 @@ from .sym_node import SymNode
from collections import defaultdict
from contextlib import contextmanager, nullcontext, AbstractContextManager, ExitStack
from dataclasses import dataclass
from torch import SymInt, SymFloat, SymBool, Tensor
from torch import SymInt, SymBool, Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
@ -111,8 +111,7 @@ class _NoDefault:
no_default = _NoDefault()
py_sym_types = (SymInt, SymFloat, SymBool)
PySymType = Union[SymInt, SymFloat, SymBool]
from torch.types import py_sym_types, PySymType
class _HasMeta(Protocol):
meta: Dict[str, PySymType]


@ -47,7 +47,7 @@ sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
__all__ = ["SymNode", "method_to_operator", "magic_methods"]
SymTypes = (SymInt, SymFloat, SymBool)
from torch.types import py_sym_types as SymTypes
def _to_symtype(t):
@ -123,7 +123,7 @@ class SymNode:
"Cannot create SymNode of type "
f"{pytype} with incompatible hint of type {type(hint)}"
)
if self.shape_env._translation_validation_enabled:
if self.shape_env and self.shape_env._translation_validation_enabled:
# This is technically not TV, but this assert is expensive so
# let's only do it when we're already doing expensive things
computed_hint = compute_hint()
@ -138,15 +138,30 @@ class SymNode:
# Record the FX node of the current node if we are doing translation
# validation. They will be used for building the input assertions for
# the translation validation problem.
self.fx_node = (
fx_node if self.shape_env._translation_validation_enabled else None
tx_validation_en = (
self.shape_env and self.shape_env._translation_validation_enabled
)
self.fx_node = tx_validation_en and fx_node
def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
return SymNode(
self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
)
def _value_eq(self, other: "SymNode") -> bool:
# Purposely don't include the shape_env in the eq.
return (
self._expr == other._expr
and self.pytype == other.pytype
and self._hint == other._hint
and self.constant == other.constant
and self.fx_node == other.fx_node
)
def _value_hash(self) -> int:
# Purposely don't include the shape_env in the hash.
return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
@property
def expr(self):
return self.shape_env.replace(self._expr)
@ -248,7 +263,7 @@ class SymNode:
def __repr__(self):
rep = [
f"SymNode({self.expr}, shape_env={self.shape_env}, pytype={self.pytype}",
f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
]
if self._hint is not None:
rep.append(f"hint={self._hint}")


@ -2387,6 +2387,16 @@ class ShapeEnv:
[ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else []
)
# FakeTensor per-ShapeEnv operation cache. This is used for caching
# operations that contain symbolic shapes which have guards on the
# ShapeEnv (so are ShapeEnv-dependent).
#
# NOTE: It's important that SymNodes in this cache have their ShapeEnv
# stripped otherwise you end up with cycles which can only be cleaned
# with the GC.
self.fake_tensor_cache: Dict[torch._subclasses.fake_tensor._DispatchCacheKey,
torch._subclasses.fake_tensor._DispatchCacheEntry] = {}
# Pro-tip: if you add new field to ShapeEnv, this affects some accept
# tests. Accept their output with:
#
@ -2684,7 +2694,7 @@ class ShapeEnv:
elif key == "name_to_node":
# Compare just the set of keys is the same.
return set(value.keys())
elif key in ["symbol_guard_counter", "pending_fresh_unbacked_symbols"]:
elif key in ("symbol_guard_counter", "pending_fresh_unbacked_symbols", "fake_tensor_cache"):
# Skip this for comparisons
return None
return value


@ -17,7 +17,7 @@ from builtins import ( # noqa: F401
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
import torch
from torch import SymInt
from torch import SymBool, SymFloat, SymInt
if TYPE_CHECKING:
@ -47,6 +47,9 @@ _dispatchkey = Union[builtins.str, torch._C.DispatchKey]
# int or SymInt
IntLikeType = Union[_int, torch.SymInt]
py_sym_types = (SymInt, SymFloat, SymBool)
PySymType = Union[SymInt, SymFloat, SymBool]
# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float, builtins.bool]