Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
FakeTensor cache SymInt support (#127596)
Adds support for SymInts in the FakeTensor cache. A few notes:

1. When a SymInt is present in the input key for a FakeTensor operation, we cache on the ShapeEnv instead of using the FakeTensorMode cache. This is necessary so we don't have to remember and check the guards. It reduces the cache hits, but there is a diminishing return on how much work we can do before the cache becomes more of a burden than a gain.
2. When a cached output SymInt is a direct copy of an input SymInt, a cache hit must copy the SymNode from the input to the output. This is important because the fx-graph building code uses SymNode ids while building the graph, so constructing a same-content-but-different-id SymNode will fail.
3. In the cache key we store SymInts as a _PySymInputStub. These represent SymInt (and friends) but support `__hash__` and `__eq__` (which SymInt does not).
4. In the cache entry we store SymInts as a _SymIntOutputStub.

Perf example:
```
python benchmarks/dynamo/timm_models.py --ci --accuracy --timing --explain --inductor --dynamic-shapes --dynamic-batch-only --device cuda --training --amp --total-partitions 2 --partition-id 0 --output /tmp/training_timm_models.csv --filter crossvit_9_240
```

FakeTensor cache stats before:
```
INFO: FakeTensor cache stats:
INFO:   cache_hits: 68137
INFO:   cache_misses: 837
INFO:   cache_bypasses:
INFO:     symbolic shape: 48224
INFO:     CompositeImplicitAutograd: 917
INFO:     non-fake tensor: 70
INFO:     non-FakeTensor output: 62
INFO:     non-builtin: 8
INFO:     dynamic output shape: 1
```

and after:
```
INFO: FakeTensor cache stats:
INFO:   cache_hits: 88187
INFO:   cache_misses: 14233
INFO:   cache_bypasses:
INFO:     CompositeImplicitAutograd: 1037
INFO:     non-FakeTensor output: 602
INFO:     non-fake tensor: 70
INFO:     unsafe view: 36
INFO:     non-builtin: 8
INFO:     dynamic output shape: 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127596
Approved by: https://github.com/eellison
ghstack dependencies: #131014, #129780
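The sketch below illustrates the idea behind point 3 with a small, self-contained example: a wrapper that hashes and compares a Sym-like value by its node's value fields rather than by object identity, so it can be used as part of a dict-based cache key. The class and field names here are illustrative stand-ins, not the actual PyTorch implementation.
```
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class FakeSymNode:
    """Stand-in for a SymNode: just enough state to compare by value."""
    expr: str                # symbolic expression, e.g. "s0 + 1"
    pytype: type = int
    hint: Any = None

    def value_fields(self) -> Tuple[Any, ...]:
        # The fields that define the node's *value* (object identity is ignored).
        return (self.expr, self.pytype, self.hint)


class PySymInputStubSketch:
    """Wraps a Sym-like value so it can be used as a hashable cache-key part."""

    def __init__(self, node: FakeSymNode) -> None:
        self.node = node

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, PySymInputStubSketch)
            and self.node.value_fields() == other.node.value_fields()
        )

    def __hash__(self) -> int:
        return hash(self.node.value_fields())


# Two distinct node objects with the same symbolic value hash and compare
# equal, so a cache keyed on the stub gets a hit for the second lookup.
cache = {}
a = PySymInputStubSketch(FakeSymNode("s0 + 1"))
b = PySymInputStubSketch(FakeSymNode("s0 + 1"))
cache[a] = "cached result"
assert cache[b] == "cached result"
```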
Committed by: PyTorch MergeBot
Parent: ebce85172e
Commit: b193894b94
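Complementing the sketch above, this second hedged sketch illustrates point 2: the cache entry records the index of the input whose SymNode the output reuses, so a cache hit hands back the original object (preserving its node identity) instead of rebuilding an equal-but-distinct one. The names are again illustrative, not the real _SymIntOutputStub.
```
from typing import Any, List, Optional


class SymIntOutputStubSketch:
    """Cache-entry placeholder for an output that may alias an input."""

    def __init__(self, value: Any, input_index: Optional[int]) -> None:
        # Remember either which key slot to copy from, or the value itself.
        self.input_index = input_index
        self.value = value

    def extract(self, key: List[Any]) -> Any:
        if self.input_index is not None:
            # Reuse the exact input object so its node id() is unchanged.
            return key[self.input_index]
        return self.value


# On a cache hit the stub hands back the very object stored in the key.
key = ["sym_int_placeholder", 42]
stub = SymIntOutputStubSketch(value=None, input_index=0)
hit = stub.extract(key)
assert hit is key[0]
```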
@@ -9856,7 +9856,7 @@ ShapeEnv not equal: field values don't match:
        if not forward_deterministic and backward_deterministic:
            with self.assertRaisesRegex(
                RuntimeError,
-                "^This compiled backward function is being run with torch\.use_deterministic_algorithms",
+                r"^This compiled backward function is being run with torch\.use_deterministic_algorithms",
            ):
                res.backward(grad)
@@ -31,6 +31,7 @@ from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    unset_fake_temporarily,
    UnsupportedOperatorException,
+    _CacheKeyState
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import (
@@ -1611,9 +1612,10 @@ class FakeTensorDispatchCache(TestCase):
        cache keys for inputs x and y are the same, but z is different.
        """
        func = aten.add.Tensor
-        key_x = fm._cache_key(func, [x], {})
-        key_y = fm._cache_key(func, [y], {})
-        key_z = fm._cache_key(func, [z], {})
+        state = _CacheKeyState()
+        key_x = fm._cache_key(state, func, [x], {})
+        key_y = fm._cache_key(state, func, [y], {})
+        key_z = fm._cache_key(state, func, [z], {})

        self.assertEqual(key_x, key_y)
        self.assertNotEqual(key_x, key_z)
@@ -1186,7 +1186,7 @@ def gen_pyi(
            "is_mkldnn": ["is_mkldnn: _bool"],
            "is_vulkan": ["is_vulkan: _bool"],
            "is_ipu": ["is_ipu: _bool"],
-            "storage_offset": ["def storage_offset(self) -> _int: ..."],
+            "storage_offset": ["def storage_offset(self) -> Union[_int, SymInt]: ..."],
            "to": [
                (
                    f"def to(self, {args}, non_blocking: _bool = False, copy: _bool = False, *, "
@@ -1204,7 +1204,7 @@ def gen_pyi(
            ],
            "set_": [
                "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage], "
-                "offset: _int, size: _symsize, stride: _symsize) -> Tensor: ...",
+                "offset: IntLikeType, size: _symsize, stride: _symsize) -> Tensor: ...",
                "def set_(self, storage: Union[Storage, TypedStorage, UntypedStorage]) -> Tensor: ...",
            ],
            "split": [
@@ -44,6 +44,7 @@ from torch.types import (
    Device,
    Number,
    Storage,
+    IntLikeType,
    _bool,
    _bytes,
    _complex,
@@ -3,7 +3,8 @@ from dataclasses import dataclass
from typing import Union

import torch
-from torch.fx.experimental.proxy_tensor import py_sym_types, SymBool, SymFloat, SymInt
+from torch import SymBool, SymFloat, SymInt
+from torch.types import py_sym_types


@dataclass
@@ -13,6 +13,7 @@ DISTRIBUTED = [
]

register_log("dynamo", ["torch._dynamo", *DYNAMIC])
+register_log("fake_tensor", ["torch._subclasses.fake_tensor"])
register_log("aot", ["torch._functorch.aot_autograd", "torch._functorch._aot_autograd"])
register_log("autograd", "torch.autograd")
register_log("inductor", ["torch._inductor", "torch._inductor.cudagraph_trees"])
@@ -1835,7 +1835,7 @@ def are_strides_like_channels_last(
    for d in dim_order:
        if guard_size_oblivious(shape[d] == 0):
            return False
-        if strides[d] < min:
+        if guard_size_oblivious(strides[d] < min):
            return False
        if d == 0 and min == strides[1]:
            return False
torch/_subclasses/_fake_tensor_utils.py (new file, 257 lines)
@@ -0,0 +1,257 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List, Optional, Type, TYPE_CHECKING, Union

import torch

from torch import SymInt
from torch.fx.experimental.sym_node import SymNode
from torch.types import py_sym_types, PySymType
from torch.utils._backport_slots import dataclass_slots

if TYPE_CHECKING:
    import sympy

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from .fake_tensor import _DispatchCacheKey, _MetadataIntLike


@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymNode:
    """
    Represents a SymNode without the associated ShapeEnv
    """

    # n.b. keep the same protocol as SymNode
    _expr: sympy.Expr
    pytype: type
    _hint: Optional[Union[int, float, bool]]
    constant: Optional[Union[int, float, bool]]
    fx_node: torch.fx.Node

    @staticmethod
    def from_node(node: SymNode) -> _DeconstructedSymNode:
        return _DeconstructedSymNode(
            node._expr, node.pytype, node._hint, node.constant, node.fx_node
        )

    def extract(self, shape_env: ShapeEnv) -> SymNode:
        return SymNode(
            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
        )

    def __str__(self) -> str:
        return str(self._expr)

    def __repr__(self) -> str:
        return f"_DeconstructedSymNode{{{self._expr!r}, {self.pytype!r}, {self._hint!r}, {self.constant!r}, {self.fx_node!r}}}"

    def __eq__(self, other: object) -> bool:
        raise NotImplementedError

    def __hash__(self) -> int:
        raise NotImplementedError

    # _value_eq to match SymNode
    def _value_eq(self, other: object) -> bool:
        if isinstance(other, (SymNode, _DeconstructedSymNode)):
            return (
                self._expr == other._expr
                and self.pytype == other.pytype
                and self._hint == other._hint
                and self.constant == other.constant
                and self.fx_node == other.fx_node
            )
        else:
            return False

    # _value_hash to match SymNode
    def _value_hash(self) -> int:
        return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))


@dataclass_slots
@dataclass(frozen=True)
class _DeconstructedSymType:
    """
    Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
    """

    ty: Type[PySymType]
    node: _DeconstructedSymNode

    @staticmethod
    def from_sym_type(value: PySymType) -> _DeconstructedSymType:
        return _DeconstructedSymType(type(value), value.node)

    def extract(self, shape_env: ShapeEnv) -> PySymType:
        return self.ty(self.node.extract(shape_env))

    def __str__(self) -> str:
        return f"{self.ty}({self.node})"

    def __repr__(self) -> str:
        return f"_DeconstructedSymType({self.ty}, {self.node!r})"

    def __eq__(self, other: object) -> bool:
        return NotImplemented

    def __hash__(self) -> int:
        return NotImplemented


@dataclass_slots
@dataclass(frozen=True)
class _InputBackref:
    value: int


@dataclass_slots
@dataclass
class _PySymInputStub:
    """
    Represents a SymInt in the cached key. Needed because SymInt doesn't
    support __eq__ or __hash__ directly.
    """

    # value can be:
    #   PySymType: This is the 'normal' SymInt value, wrapped so we can use
    #              hash/eq as value hash/eq (normally SymInt does object
    #              hash/eq).
    #   _DeconstructedSymType: This is used when storing the _PySymInputStub in
    #              the cache to avoid cyclic ShapeEnv references.
    #   _InputBackref: This is a back-reference to a previous _PySymInputStub in
    #              the key.
    value: Union[PySymType, _DeconstructedSymType, _InputBackref]

    def __init__(
        self, value: Union[PySymType, _DeconstructedSymType, _InputBackref]
    ) -> None:
        # For inputs (values in the `key`) we need to keep the PySymType intact
        # - this way if we need to reuse it as an output we can properly copy
        # the original value.
        self.value = value

    def strip_shape_env(self) -> None:
        if isinstance(self.value, py_sym_types):
            self.value = _DeconstructedSymType.from_sym_type(self.value)

    def extract(self, shape_env: ShapeEnv) -> PySymType:
        if isinstance(self.value, _DeconstructedSymType):
            return self.value.extract(shape_env)
        else:
            # We should never see an _InputBackref here - anyone extracting a
            # value should be pulling from the original entry (the one this
            # backref points at).
            assert not isinstance(self.value, _InputBackref)
            return self.value

    def __str__(self) -> str:
        return str(self.value)

    def __repr__(self) -> str:
        return f"_PySymInputStub({self.value!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _PySymInputStub):
            return False
        elif isinstance(self.value, _InputBackref) or isinstance(
            other.value, _InputBackref
        ):
            return self.value == other.value
        else:
            return self.value.node._value_eq(other.value.node)

    def __hash__(self) -> int:
        if isinstance(self.value, _InputBackref):
            return hash(self.value)
        else:
            return self.value.node._value_hash()


@dataclass_slots
@dataclass
class _SymIntOutputStub:
    """
    Represents a SymInt in the cached output.
    """

    # This is either an `int` which represents the index in the key to copy the
    # SymNode from or it's the deconstructed SymNode itself.
    value: Union[int, _DeconstructedSymNode]

    def __init__(self, value: SymInt, key_path: Optional[int]) -> None:
        if key_path is None:
            self.value = _DeconstructedSymNode.from_node(value.node)
        else:
            self.value = key_path

    def extract(self, key: _DispatchCacheKey, shape_env: ShapeEnv) -> SymInt:
        if isinstance(self.value, _DeconstructedSymNode):
            return SymInt(self.value.extract(shape_env))
        else:
            src = key.key[self.value]
            assert isinstance(src, _PySymInputStub) and isinstance(src.value, SymInt)
            return src.value

    def __repr__(self) -> str:
        return f"_SymIntOutputStub({self.value!r})"

    def __eq__(self, other: object) -> bool:
        raise NotImplementedError

    def __hash__(self) -> int:
        raise NotImplementedError


@dataclass_slots
@dataclass
class _CacheKeyState:
    """
    State used while building our cache key.
    """

    # We track the SymNodes so when we get the output we can see if it exactly
    # matches one of the inputs so we can uncache it properly.
    sym_node_lookup: Dict[int, int]  # id(SymNode) -> index

    # There are cases where we're asked to perform an op when we have no
    # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
    # ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it
    # here.
    shape_env: Optional[ShapeEnv]

    def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
        self.sym_node_lookup = {}
        self.shape_env = shape_env

    def cache_on_shape_env(self) -> bool:
        """
        Returns true if the CacheKey needs to be cached on the ShapeEnv
        rather than the global cache.

        If our inputs contain a SymNode then we can't cache this operation on
        the global cache because the cached output will implicitly depend on
        guard values which might not be true on some other ShapeEnv. So unless
        we're also going to cache the guards we need to cache this operation on
        the ShapeEnv instead of globally.
        """
        return bool(self.sym_node_lookup)

    def convert_sym_int(self, result: List[object], arg: SymInt) -> None:
        node_id = id(arg.node)
        if node_id in self.sym_node_lookup:
            result.append(_InputBackref(self.sym_node_lookup[node_id]))
        else:
            self.sym_node_lookup[node_id] = len(result)
            if self.shape_env is None:
                self.shape_env = arg.node.shape_env
            result.append(_PySymInputStub(arg))

    def convert_output(self, arg: _MetadataIntLike) -> _MetadataIntLike:
        if isinstance(arg, SymInt):
            return _SymIntOutputStub(arg, self.sym_node_lookup.get(id(arg.node), None))
        else:
            return arg
@@ -1,5 +1,6 @@
from __future__ import annotations

+import atexit
import contextlib
import dataclasses
import functools
@@ -34,7 +35,6 @@ from typing_extensions import Self, TypeGuard
from weakref import ReferenceType

import torch
import torch._custom_op
-
from torch import SymBool, SymFloat, SymInt, Tensor
from torch._C._functorch import is_functorch_wrapped_tensor, is_legacy_batchedtensor
@@ -51,6 +51,7 @@ from torch.fx.immutable_collections import immutable_dict
from torch.fx.operator_schemas import normalize_function
from torch.multiprocessing.reductions import StorageWeakRef
from torch.overrides import TorchFunctionMode
+from torch.types import IntLikeType, py_sym_types
from torch.utils._backport_slots import dataclass_slots
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import (
@@ -60,6 +61,7 @@ from torch.utils._python_dispatch import (
from torch.utils._pytree import PyTree, tree_map, tree_map_, TreeSpec
from torch.utils._stats import count
from torch.utils._traceback import CapturedTraceback
+from ._fake_tensor_utils import _CacheKeyState, _PySymInputStub, _SymIntOutputStub

if TYPE_CHECKING:
    from types import TracebackType
@@ -67,7 +69,6 @@ if TYPE_CHECKING:
    from torch._guards import Source
    from torch._ops import OpOverload
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
-    from torch.types import IntLikeType

log = logging.getLogger(__name__)

@@ -880,21 +881,24 @@ class FakeTensor(Tensor):
        return out


+_MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"]
+
+
@dataclass_slots
-@dataclass(frozen=True)
+@dataclass
class TensorMetadata:
    """
    The Tensor metadata relevant to hashing FakeTensors when caching.
    """

    dtype: torch.dtype
-    shape: torch.Size
-    stride: Tuple[IntLikeType, ...]
+    shape: Tuple[_MetadataIntLike, ...]
+    stride: Tuple[_MetadataIntLike, ...]
    device: torch.device
    layout: torch.layout
    memory_format: Optional[torch.memory_format]
-    storage_offset: int
-    storage_bytes: Optional[int]
+    storage_offset: _MetadataIntLike
+    storage_bytes: Optional[_MetadataIntLike]
    requires_grad: bool
    is_quantized: bool
    is_conj: bool
@@ -905,11 +909,22 @@ class TensorMetadata:
    dense_dim: Optional[int]
    sparse_dim: Optional[int]

-    def _flatten_into(self, result: List[object], mode: FakeTensorMode) -> None:
+    def _flatten_into(
+        self,
+        result: List[object],
+        mode: FakeTensorMode,
+        state: _CacheKeyState,
+    ) -> None:
+        # Flatten the TensorMetadata out into `result`. Make sure to call
+        # state.convert_sym_int() on any SymInts.
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if isinstance(value, (tuple, list, torch.Size)):
-                mode._prep_args_for_hash(result, value)
+                # This will recursively flatten the iterable, calling
+                # convert_sym_int() as necessary.
+                mode._prep_args_for_hash(result, value, state)
+            elif isinstance(value, SymInt):
+                state.convert_sym_int(result, value)
            else:
                result.append(value)
@@ -919,64 +934,81 @@ def extract_tensor_metadata(t: Tensor) -> TensorMetadata:
    Extract the TensorMetadata of a tensor.
    """
    memory_format: Optional[torch.memory_format] = suggest_memory_format(t)
-    if is_sparse_any(t) or not t.is_contiguous(memory_format=memory_format):
+    # Don't call is_contiguous() on a Tensor which has symbolic sizes or things
+    # will go badly (guards will be messed up?)
+    if (
+        t._has_symbolic_sizes_strides
+        or is_sparse_any(t)
+        or not t.is_contiguous(memory_format=memory_format)
+    ):
        memory_format = None

    storage_offset = t.storage_offset()

    return TensorMetadata(
-        dtype=t.dtype,
-        shape=t.shape,
-        stride=t.stride() if t.layout == torch.strided else (),
-        device=t.device,
-        layout=t.layout,
-        memory_format=memory_format,
-        storage_offset=storage_offset,
+        t.dtype,
+        t.shape,
+        t.stride() if t.layout == torch.strided else (),
+        t.device,
+        t.layout,
+        memory_format,
+        storage_offset,
        # Only set storage_bytes for tensors that have storage (not sparse)
-        storage_bytes=t.untyped_storage().nbytes() if not t.is_sparse else None,
-        requires_grad=t.requires_grad,
-        is_quantized=t.is_quantized,
-        is_conj=t.is_conj(),
-        is_neg=t.is_neg(),
-        is_inference=t.is_inference(),
-        is_sparse=t.is_sparse,
-        is_coalesced=t.is_coalesced() if t.is_sparse else None,
-        dense_dim=t.dense_dim() if t.is_sparse else None,
-        sparse_dim=t.sparse_dim() if t.is_sparse else None,
+        t.untyped_storage().nbytes() if not t.is_sparse else None,
+        t.requires_grad,
+        t.is_quantized,
+        t.is_conj(),
+        t.is_neg(),
+        t.is_inference(),
+        t.is_sparse,
+        t.is_coalesced() if t.is_sparse else None,
+        t.dense_dim() if t.is_sparse else None,
+        t.sparse_dim() if t.is_sparse else None,
    )


-class _DispatchCacheKey(list):
+@dataclass_slots
+@dataclass
+class _DispatchCacheKey:
    """
-    Key for the FakeTensor dispatch cache. Inspired by (copied from)
-    _HashedSeq from the functools.lru_cache implementation.
+    Key for the FakeTensor dispatch cache.
    """

-    __slots__ = "hashvalue"  # noqa: PLC0205
+    key: Tuple[object, ...]
+    hashvalue: int

-    def __init__(
-        self, tup: Tuple[object, ...], hash: Callable[[object], int] = hash
-    ) -> None:
-        self[:] = tup
+    def __init__(self, tup: Tuple[object, ...]) -> None:
+        self.key = tup
        self.hashvalue = hash(tup)

-    def __hash__(self) -> int:  # type: ignore[override]
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, _DispatchCacheKey) and self.key == other.key
+
+    def __hash__(self) -> int:
        return self.hashvalue

+    def strip_shape_env(self) -> None:
+        # We need to strip the ShapeEnv from any values before we store in the
+        # cache so the cache doesn't keep our ShapeEnvs alive.
+        for v in self.key:
+            if isinstance(v, _PySymInputStub):
+                v.strip_shape_env()
+

@dataclass_slots
@dataclass(frozen=True)
class _DispatchCacheEntry:
    """
    Entry type for the FakeTensor dispatch cache. Accounts for two possibilities:
-    1) The op is inplace, and a hit means we need to alias the argument at a given
-    index. 2) We need to synthesize a new FakeTensor given tensor metadata. For view
-    ops, we further capture the index of the arg to alias.
+    1) The op is inplace, and a hit means we need to alias the argument at a
+       given index.
+    2) We need to synthesize a new FakeTensor given tensor metadata. For view
+       ops, we further capture the index of the arg to alias.
    """

-    inplace_idx: Optional[int] = None
-    metadata: Optional[TensorMetadata] = None
-    view_idx: Optional[int] = None
+    inplace_idx: Optional[int]
+    metadata: Optional[TensorMetadata]
+    view_idx: Optional[int]


@dataclass_slots
@@ -1233,10 +1265,16 @@ class FakeTensorMode(TorchDispatchMode):
        """
        output: object = _UNASSIGNED
        try:
-            key = self._cache_key(func, args, kwargs)
-            entry = FakeTensorMode.cache.get(key, None)
+            state = _CacheKeyState(self.shape_env)
+            key = self._cache_key(state, func, args, kwargs)
+            if state.cache_on_shape_env():
+                assert state.shape_env is not None
+                cache = state.shape_env.fake_tensor_cache
+            else:
+                cache = FakeTensorMode.cache
+            entry = cache.get(key, None)
            if entry is not None:
-                output = self._output_from_cache_entry(entry, func, args)
+                output = self._output_from_cache_entry(state, entry, key, func, args)
                FakeTensorMode.cache_hits += 1
                if self.cache_crosscheck_enabled:
                    # For debugging / testing: Validate that the output synthesized
@@ -1245,8 +1283,9 @@ class FakeTensorMode(TorchDispatchMode):
            else:
                self._validate_cache_key(func, args, kwargs)
                output = self._dispatch_impl(func, types, args, kwargs)
-                entry = self._make_cache_entry(key, func, args, kwargs, output)
-                FakeTensorMode.cache[key] = entry
+                entry = self._make_cache_entry(state, key, func, args, kwargs, output)
+                key.strip_shape_env()
+                cache[key] = entry
                FakeTensorMode.cache_misses += 1
        except _BypassDispatchCache as e:
            FakeTensorMode.cache_bypasses[e.reason] += 1
@@ -1258,6 +1297,7 @@ class FakeTensorMode(TorchDispatchMode):

    def _cache_key(
        self,
+        state: _CacheKeyState,
        func: OpOverload,
        args: Sequence[object],
        kwargs: Mapping[str, object],
@@ -1284,9 +1324,9 @@ class FakeTensorMode(TorchDispatchMode):
        ]
        # Translate any FakeTensor args to metadata.
        if args:
-            self._prep_args_for_hash(key_values, args)
+            self._prep_args_for_hash(key_values, args, state)
        if kwargs:
-            self._prep_args_for_hash(key_values, kwargs)
+            self._prep_args_for_hash(key_values, kwargs, state)
        return _DispatchCacheKey(tuple(key_values))

    def _validate_cache_key(
@@ -1335,6 +1375,7 @@ class FakeTensorMode(TorchDispatchMode):
        self,
        result: List[object],
        args: Union[Mapping[str, object], Sequence[object], Iterable[object]],
+        state: _CacheKeyState,
    ) -> None:
        """
        Translate the provided args into a form suitable for caching at FakeTensor
@@ -1343,16 +1384,14 @@ class FakeTensorMode(TorchDispatchMode):
        unsupported cases that should bypass caching.
        """
        if isinstance(args, dict):
-            self._prep_args_for_hash(result, args.keys())
-            self._prep_args_for_hash(result, args.values())
+            self._prep_args_for_hash(result, args.keys(), state)
+            self._prep_args_for_hash(result, args.values(), state)
            return

        for arg in args:
            if isinstance(arg, FakeTensor):
                if not self.is_our_fake(arg):
                    raise _BypassDispatchCache("not our fake")
-                if arg._has_symbolic_sizes_strides:
-                    raise _BypassDispatchCache("symbolic shape")
                if arg.constant is not None:
                    raise _BypassDispatchCache("constant attribute")
                if arg.is_sparse:
@@ -1366,18 +1405,18 @@ class FakeTensorMode(TorchDispatchMode):
                    # Does this subsume arg.is_sparse?
                    raise _BypassDispatchCache("sparse tensor layout")
                # sparse tensors don't have storage, so check is after
-                if isinstance(arg.untyped_storage().nbytes(), SymInt):
-                    raise _BypassDispatchCache("symbolic nbytes")
                if is_sparse_compressed(arg):
                    raise _BypassDispatchCache("sparse compressed tensor")
                metadata = extract_tensor_metadata(arg)
-                metadata._flatten_into(result, self)
+                metadata._flatten_into(result, self, state)
            elif isinstance(arg, Tensor):
                raise _BypassDispatchCache("non-fake tensor")
-            elif isinstance(arg, (SymBool, SymInt, SymFloat)):
+            elif isinstance(arg, SymInt):
+                state.convert_sym_int(result, arg)
+            elif isinstance(arg, (SymBool, SymFloat)):
                raise _BypassDispatchCache("symbolic shape")
            elif isinstance(arg, (list, tuple, dict)):
-                self._prep_args_for_hash(result, arg)
+                self._prep_args_for_hash(result, arg, state)
            else:
                # It's important to capture the type of the arg since, e.g., 1 and 1.0
                # hash to the same value, but can produce different dtypes for the
@@ -1387,6 +1426,7 @@ class FakeTensorMode(TorchDispatchMode):

    def _make_cache_entry(
        self,
+        state: _CacheKeyState,
        key: _DispatchCacheKey,
        func: OpOverload,
        args: Sequence[object],
@@ -1439,8 +1479,19 @@ class FakeTensorMode(TorchDispatchMode):
            view_idx = idxs[0]

        metadata = extract_tensor_metadata(output)
+        metadata.shape = tuple(state.convert_output(v) for v in metadata.shape)
+        metadata.stride = tuple(state.convert_output(v) for v in metadata.stride)
+        metadata.storage_offset = state.convert_output(metadata.storage_offset)
+        metadata.storage_bytes = (
+            None
+            if metadata.storage_bytes is None
+            else state.convert_output(metadata.storage_bytes)
+        )
+
        entry = _DispatchCacheEntry(
-            inplace_idx=None, metadata=metadata, view_idx=view_idx
+            inplace_idx=None,
+            metadata=metadata,
+            view_idx=view_idx,
        )

        # N.B.: Some checks for bypassing the cache would be performed on the
@@ -1448,7 +1499,7 @@ class FakeTensorMode(TorchDispatchMode):
        # we can synthesize a tensor here and do the checks on that instance.
        # This approach keeps the (more frequent) cache-hit path as lightweight
        # as possible.
-        synth_output = self._output_from_cache_entry(entry, func, args)
+        synth_output = self._output_from_cache_entry(state, entry, key, func, args)

        # Make sure the dispatch_key_set from the synthesized output tensor will
        # be the same.
@@ -1460,7 +1511,12 @@ class FakeTensorMode(TorchDispatchMode):
        return entry

    def _output_from_cache_entry(
-        self, entry: _DispatchCacheEntry, func: OpOverload, args: Sequence[object]
+        self,
+        state: _CacheKeyState,
+        entry: _DispatchCacheEntry,
+        key: _DispatchCacheKey,
+        func: OpOverload,
+        args: Sequence[object],
    ) -> Optional[FakeTensor]:
        """
        Create a new FakeTensor from the cache entry.
@@ -1478,40 +1534,65 @@ class FakeTensorMode(TorchDispatchMode):

        assert not metadata.is_sparse

-        empty = torch.empty_strided(
-            metadata.shape,
-            metadata.stride,
-            dtype=metadata.dtype,
-            layout=metadata.layout,
-            device="meta",
-            requires_grad=metadata.requires_grad,
-        )
+        def check_value(
+            value: _MetadataIntLike, state: _CacheKeyState
+        ) -> Union[IntLikeType]:
+            if isinstance(value, _SymIntOutputStub):
+                assert state.shape_env is not None
+                return value.extract(key, state.shape_env)
+            else:
+                assert not isinstance(value, _PySymInputStub)
+                return value
+
+        shape = tuple(check_value(v, state) for v in metadata.shape)
+        stride = tuple(check_value(v, state) for v in metadata.stride)
+        storage_offset = check_value(metadata.storage_offset, state)
+        storage_bytes = (
+            None
+            if metadata.storage_bytes is None
+            else check_value(metadata.storage_bytes, state)
+        )
+
+        maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext
+        if self.shape_env is not None:
+            maybe_suppress = self.shape_env.suppress_guards
+
+        with in_kernel_invocation_manager(self), maybe_suppress():
+            empty = torch.empty_strided(
+                shape,
+                stride,
+                dtype=metadata.dtype,
+                layout=metadata.layout,
+                device="meta",
+                requires_grad=metadata.requires_grad,
+            )

        if metadata.is_conj:
            torch._C._set_conj(empty, True)
        if metadata.is_neg:
            torch._C._set_neg(empty, True)

-        maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext
-        if self.shape_env is not None:
-            maybe_suppress = self.shape_env.suppress_guards
-
        if func.is_view:
            # For view ops, the storage should be the same as the tensor input.
            view_arg = args[cast(int, entry.view_idx)]
            assert isinstance(view_arg, FakeTensor)
            storage = view_arg.untyped_storage()
            with in_kernel_invocation_manager(self), maybe_suppress():
-                empty.set_(
-                    storage, metadata.storage_offset, metadata.shape, metadata.stride
-                )
-        elif metadata.storage_offset != 0:
+                empty.set_(storage, storage_offset, shape, stride)
+        elif storage_offset != 0:
            storage = empty.untyped_storage()
            with in_kernel_invocation_manager(self), maybe_suppress():
-                empty.set_(
-                    storage, metadata.storage_offset, metadata.shape, metadata.stride
-                )
+                empty.set_(storage, storage_offset, shape, stride)

-        if metadata.storage_bytes == 0:
+        if isinstance(storage_bytes, SymInt):
+            # Do it this way so we don't import symbolic_shapes (which imports
+            # expensive sympy) unless we have to.
+            from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
+            zero_bytes = guard_size_oblivious(storage_bytes == 0)
+        else:
+            zero_bytes = storage_bytes == 0
+        if zero_bytes:
            empty.untyped_storage().resize_(0)

        return FakeTensor(self, empty, metadata.device)
@@ -1729,7 +1810,7 @@ class FakeTensorMode(TorchDispatchMode):
        def maybe_to_real_tensor(t: T) -> Optional[Union[T, Tensor]]:
            if isinstance(t, FakeTensor):
                return t.real_tensor
-            elif isinstance(t, SymTypes):
+            elif isinstance(t, py_sym_types):
                assert self.shape_env is not None
                return t.node.pytype(
                    t.node.expr.xreplace(self.shape_env.var_to_val).xreplace(
@@ -1742,7 +1823,6 @@ class FakeTensorMode(TorchDispatchMode):
        from torch.fx.experimental.symbolic_shapes import (
            compute_unbacked_bindings,
            free_unbacked_symbols,
-            SymTypes,
        )

        nil = object()
@@ -1787,7 +1867,7 @@ class FakeTensorMode(TorchDispatchMode):
            if isinstance(t, FakeTensor):
                # NB: unconditionally overwrite
                t.real_tensor = real_t
-            elif isinstance(t, SymTypes) and free_unbacked_symbols(t):
+            elif isinstance(t, py_sym_types) and free_unbacked_symbols(t):
                if isinstance(t.node.expr, sympy.Symbol):
                    assert self.shape_env is not None
                    self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t)
@@ -2276,3 +2356,16 @@ from torch._subclasses.fake_impls import (  # noqa: F401
    op_implementations_checks,
    stride_incorrect_op,
)
+
+
+@atexit.register
+def dump_cache_stats() -> None:
+    log.info("FakeTensor cache stats:")
+    log.info("  cache_hits: %s", FakeTensorMode.cache_hits)
+    log.info("  cache_misses: %s", FakeTensorMode.cache_misses)
+    bypasses = FakeTensorMode.cache_bypasses
+    if bypasses:
+        log.info("  cache_bypasses:")
+        width = max(len(k) for k in bypasses)
+        for k, v in sorted(bypasses.items(), key=lambda i: -i[1]):
+            log.info("    %-*s %s", width + 1, f"{k}:", v)
@@ -237,7 +237,7 @@ class MetaTensorDescriber:
            # NB: We actually don't use storage to do views, but might as well
            # put it in for accuracy
            storage = self.describe_storage(t.untyped_storage(), trace=trace)
-            storage_offset = t.storage_offset()
+            storage_offset = t.storage_offset()  # type: ignore[assignment]

        stride = None
        if not (
@@ -26,7 +26,7 @@ from .sym_node import SymNode
from collections import defaultdict
from contextlib import contextmanager, nullcontext, AbstractContextManager, ExitStack
from dataclasses import dataclass
-from torch import SymInt, SymFloat, SymBool, Tensor
+from torch import SymInt, SymBool, Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._library.fake_class_registry import FakeScriptObject
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
@@ -111,8 +111,7 @@ class _NoDefault:

no_default = _NoDefault()

-py_sym_types = (SymInt, SymFloat, SymBool)
-PySymType = Union[SymInt, SymFloat, SymBool]
+from torch.types import py_sym_types, PySymType

class _HasMeta(Protocol):
    meta: Dict[str, PySymType]
@@ -47,7 +47,7 @@ sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
__all__ = ["SymNode", "method_to_operator", "magic_methods"]


-SymTypes = (SymInt, SymFloat, SymBool)
+from torch.types import py_sym_types as SymTypes


def _to_symtype(t):
@@ -123,7 +123,7 @@ class SymNode:
                    "Cannot create SymNode of type "
                    f"{pytype} with incompatible hint of type {type(hint)}"
                )
-            if self.shape_env._translation_validation_enabled:
+            if self.shape_env and self.shape_env._translation_validation_enabled:
                # This is technically not TV, but this assert is expensive so
                # let's only do it when we're already doing expensive things
                computed_hint = compute_hint()
@@ -138,15 +138,30 @@ class SymNode:
        # Record the FX node of the current node if we are doing translation
        # validation. They will be used for building the input assertions for
        # the translation validation problem.
-        self.fx_node = (
-            fx_node if self.shape_env._translation_validation_enabled else None
+        tx_validation_en = (
+            self.shape_env and self.shape_env._translation_validation_enabled
        )
+        self.fx_node = tx_validation_en and fx_node
+
+    def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
+        return SymNode(
+            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
+        )
+
+    def _value_eq(self, other: "SymNode") -> bool:
+        # Purposely don't include the shape_env in the eq.
+        return (
+            self._expr == other._expr
+            and self.pytype == other.pytype
+            and self._hint == other._hint
+            and self.constant == other.constant
+            and self.fx_node == other.fx_node
+        )
+
+    def _value_hash(self) -> int:
+        # Purposely don't include the shape_env in the hash.
+        return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))

    @property
    def expr(self):
        return self.shape_env.replace(self._expr)
@@ -248,7 +263,7 @@ class SymNode:

    def __repr__(self):
        rep = [
-            f"SymNode({self.expr}, shape_env={self.shape_env}, pytype={self.pytype}",
+            f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
        ]
        if self._hint is not None:
            rep.append(f"hint={self._hint}")
@@ -2387,6 +2387,16 @@ class ShapeEnv:
            [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] if self.should_record_events else []
        )

+        # FakeTensor per-ShapeEnv operation cache. This is used for caching
+        # operations that contain symbolic shapes which have guards on the
+        # ShapeEnv (so are ShapeEnv-dependent).
+        #
+        # NOTE: It's important that SymNodes in this cache have their ShapeEnv
+        # stripped otherwise you end up with cycles which can only be cleaned
+        # with the GC.
+        self.fake_tensor_cache: Dict[torch._subclasses.fake_tensor._DispatchCacheKey,
+                                     torch._subclasses.fake_tensor._DispatchCacheEntry] = {}
+
        # Pro-tip: if you add new field to ShapeEnv, this affects some accept
        # tests. Accept their output with:
        #
@@ -2684,7 +2694,7 @@ class ShapeEnv:
        elif key == "name_to_node":
            # Compare just the set of keys is the same.
            return set(value.keys())
-        elif key in ["symbol_guard_counter", "pending_fresh_unbacked_symbols"]:
+        elif key in ("symbol_guard_counter", "pending_fresh_unbacked_symbols", "fake_tensor_cache"):
            # Skip this for comparisons
            return None
        return value
@@ -17,7 +17,7 @@ from builtins import (  # noqa: F401
from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import torch
-from torch import SymInt
+from torch import SymBool, SymFloat, SymInt


if TYPE_CHECKING:
@@ -47,6 +47,9 @@ _dispatchkey = Union[builtins.str, torch._C.DispatchKey]
# int or SymInt
IntLikeType = Union[_int, torch.SymInt]

+py_sym_types = (SymInt, SymFloat, SymBool)
+PySymType = Union[SymInt, SymFloat, SymBool]
+
# Meta-type for "numeric" things; matches our docs
Number = Union[builtins.int, builtins.float, builtins.bool]