Revert "Fix fake tensor caching when output has unbacked (#153034)"

This reverts commit cb5f31a4a164a4fa1eaa627f9b15cdc18aa95ef1.

Reverted https://github.com/pytorch/pytorch/pull/153034 on behalf of https://github.com/malfet due to Seems to have introduced flakiness in MacOS inductor tests, see https://github.com/pytorch/pytorch/issues/153891 ([comment](https://github.com/pytorch/pytorch/pull/153034#issuecomment-2893059329))
Author: PyTorch MergeBot
Date:   2025-05-20 06:02:38 +00:00
Parent: 9849c79fa2
Commit: 1075bb37d3

3 changed files with 48 additions and 144 deletions


@@ -2265,10 +2265,13 @@ class FakeTensorDispatchCache(TestCase):
gc.collect()
self.assertTrue(count_invoke_subgraph_keys() == 0)
@skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
def test_invoke_subgraph_cacheable_inplace(self):
invoke_subgraph = torch._higher_order_ops.invoke_subgraph
def fn(x, y):
# aten ops are used so that eager backend graph is suitable for fake
# tensor testing
@@ -2314,32 +2317,5 @@ class FakeTensorDispatchCache(TestCase):
extract_tensor_metadata(b),
)
@skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
def test_unbacked_output(self):
# The point of this test is to have an op which has no symbols as input
# but a symbol as an output and make sure that we skip caching it.
class LengthsGather(torch.nn.Module):
def forward(
self,
input: torch.Tensor,
lengths: torch.Tensor,
indices: torch.Tensor,
offsets: torch.Tensor,
) -> torch.Tensor:
bias = torch.gather(offsets, 0, indices)
lengths_selected = torch.gather(lengths, 0, indices)
index = torch.repeat_interleave(bias, lengths_selected, dim=0)
return index
input = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
lengths = torch.tensor([0, 2, 3, 1, 4])
indices = torch.tensor([2, 3, 4, 6, 7, 8, 9])
offsets = torch.cumsum(lengths, 0)
ep = torch.export.export(LengthsGather(), (input, lengths, indices, offsets), strict=False)
FakeTensorMode.cache_clear()
ep.run_decompositions({})
self.assertBypasses("unrepresented symbol in output", 2)
if __name__ == "__main__":
run_tests()
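
For context, the removed test_unbacked_output hinges on an op whose inputs carry no symbols but whose output shape is data-dependent, so the fake output picks up a fresh unbacked symbol; the test asserted that run_decompositions then bypasses the dispatch cache twice with reason "unrepresented symbol in output". Below is a minimal standalone sketch of that situation (illustration only, not part of this diff), assuming a recent PyTorch build and default ShapeEnv settings, which allow data-dependent output shapes:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# Editorial sketch, not code from this commit.
with FakeTensorMode(shape_env=ShapeEnv()):
    x = torch.randn(8)           # fake tensor with a fully static shape
    idx = (x > 0).nonzero()      # the output length depends on data the fake
                                 # tensor does not have, so dim 0 becomes a
                                 # fresh unbacked SymInt (e.g. u0)
    print(idx.shape)             # something like torch.Size([u0, 1])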


@@ -218,11 +218,6 @@ class _CacheKeyState:
# matches one of the inputs so we can uncache it properly.
sym_node_lookup: dict[int, int] # id(SymNode) -> index
# This is a list of all seen input sympy.Symbols. We use it when building
# the cache entry to see if the output value has any symbols that we didn't
# see on input. See _has_unrepresented_symbols().
known_symbols: set[sympy.Symbol]
# There are cases where we're asked to perform an op when we have no
# ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
# ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it
@@ -231,7 +226,6 @@ class _CacheKeyState:
def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
self.sym_node_lookup = {}
self.known_symbols = set()
self.shape_env = shape_env
def cache_on_shape_env(self) -> bool:
@@ -253,7 +247,6 @@ class _CacheKeyState:
result.append(_InputBackref(self.sym_node_lookup[node_id]))
else:
self.sym_node_lookup[node_id] = len(result)
self.known_symbols.update(arg.node.expr.free_symbols)
if self.shape_env is None:
self.shape_env = arg.node.shape_env
result.append(_PySymInputStub(arg))
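
For context, the known_symbols bookkeeping removed in this hunk fed the _has_unrepresented_symbols helper (removed further below): gather every sympy.Symbol seen while hashing the inputs, then refuse to cache an output whose expressions mention a symbol that never appeared on the input side. A simplified standalone sketch of that idea (illustration only, not the actual fake tensor code):

import sympy

# Editorial sketch, not code from this commit.
def has_unrepresented_symbols(known_symbols: set, output_exprs: list) -> bool:
    # True if any output expression references a symbol not seen on input.
    return any(sym not in known_symbols
               for expr in output_exprs
               for sym in expr.free_symbols)

s0, u0 = sympy.symbols("s0 u0")
known = {s0}                                       # symbols gathered from inputs
print(has_unrepresented_symbols(known, [s0 + 1]))  # False -> safe to cache
print(has_unrepresented_symbols(known, [u0]))      # True  -> bypass the cache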


@@ -74,6 +74,12 @@ except ValueError as e:
raise e
class _Unassigned:
pass
_UNASSIGNED = _Unassigned()
DimList = list
pytree = torch.utils._pytree
@@ -1112,7 +1118,7 @@ class _DispatchCacheEntryOutputInfo:
@dataclass_slots
@dataclass(frozen=True)
class _DispatchCacheValidEntry:
class _DispatchCacheEntry:
"""
Entry type for the FakeTensor dispatch cache. It supports two types of outputs
1) tensor
@@ -1125,20 +1131,6 @@ class _DispatchCacheValidEntry:
is_output_tuple: bool = False
@dataclass_slots
@dataclass(frozen=True)
class _DispatchCacheBypassEntry:
"""
Entry type for a negative cache entry.
"""
reason: str
if TYPE_CHECKING:
_DispatchCacheEntry = Union[_DispatchCacheValidEntry, _DispatchCacheBypassEntry]
@dataclass_slots
@dataclass(frozen=True)
class _BypassDispatchCache(Exception):
@@ -1426,64 +1418,37 @@ class FakeTensorMode(TorchDispatchMode):
Lookup a cache entry for the given arguments. If none exists, dispatch
and cache the result (if the result is eligible for caching).
"""
output: object = _UNASSIGNED
try:
state = _CacheKeyState(self.shape_env)
key = self._cache_key(state, func, args, kwargs)
if state.cache_on_shape_env():
assert state.shape_env is not None
cache = state.shape_env.fake_tensor_cache
else:
cache = FakeTensorMode.cache
entry = cache.get(key, None)
if entry is not None:
output = self._output_from_cache_entry(state, entry, key, func, args)
FakeTensorMode.cache_hits += 1
if self.cache_crosscheck_enabled:
# For debugging / testing: Validate that the output synthesized
# from the cache matches the output created by normal dispatch.
with disable_fake_tensor_cache(self):
self._crosscheck_cache_output(output, func, types, args, kwargs)
else:
self._validate_cache_key(func, args, kwargs)
output = self._dispatch_impl(func, types, args, kwargs)
entry = self._make_cache_entry(state, key, func, args, kwargs, output)
key.strip_shape_env()
cache[key] = entry
FakeTensorMode.cache_misses += 1
except _BypassDispatchCache as e:
# We couldn't create the cache key at all
FakeTensorMode.cache_bypasses[e.reason] += 1
return self._dispatch_impl(func, types, args, kwargs)
if state.cache_on_shape_env():
assert state.shape_env is not None
cache = state.shape_env.fake_tensor_cache
set_cache_key = _set_cache_key_for_shape_env
else:
cache = FakeTensorMode.cache
set_cache_key = _set_cache_key
entry = cache.get(key, None)
if output is _UNASSIGNED:
output = self._dispatch_impl(func, types, args, kwargs)
if entry is not None:
if isinstance(entry, _DispatchCacheBypassEntry):
# This represents a negative cache entry - we already saw that the
# output is uncachable. Compute it from first principals.
FakeTensorMode.cache_bypasses[entry.reason] += 1
return self._dispatch_impl(func, types, args, kwargs)
# We have a cache entry.
output = self._output_from_cache_entry(state, entry, key, func, args)
FakeTensorMode.cache_hits += 1
if self.cache_crosscheck_enabled:
# For debugging / testing: Validate that the output synthesized
# from the cache matches the output created by normal dispatch.
with disable_fake_tensor_cache(self):
self._crosscheck_cache_output(output, func, types, args, kwargs)
return output
# We don't have a cache entry.
output = self._dispatch_impl(func, types, args, kwargs)
try:
self._validate_cache_key(func, args, kwargs)
except _BypassDispatchCache as e:
# We ran "extra" checks on the cache key and determined that it's no
# good. Record the reason and mark it so we don't bother validating
# again.
FakeTensorMode.cache_bypasses[e.reason] += 1
set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
return output
try:
entry = self._make_cache_entry(state, key, func, args, kwargs, output)
except _BypassDispatchCache as e:
# We had trouble making the cache entry. Record the reason and mark
# it.
FakeTensorMode.cache_bypasses[e.reason] += 1
set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
return output
set_cache_key(cache, key, entry)
FakeTensorMode.cache_misses += 1
return output
def _cache_key(
@@ -1669,17 +1634,17 @@ class FakeTensorMode(TorchDispatchMode):
kwargs: Mapping[str, object],
output: Optional[FakeTensor],
) -> None:
# Is this even possible? According to the signature this can be None but
# not `int`. So either the signature is a lie or (part of) this line is
# unnecessary...
from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
if isinstance(output, (int, type(None))):
return
if _has_unrepresented_symbols(state, output):
# Unbacked symbols are fine - but only if they're also represented
# in the input. If there are any new unbacked symbols then we can't
# cache this output.
raise _BypassDispatchCache("unrepresented symbol in output")
if isinstance(output, torch.SymInt):
if has_free_unbacked_symbols(output):
# This is unreachable but adding the check for extra safety in
# case we change code path in future.
raise _BypassDispatchCache("unbacked symbol in output")
return
# Some ops return tuples of Tensors, but it's rare, so avoid
# the complexity of caching other types.
@@ -1753,7 +1718,7 @@ class FakeTensorMode(TorchDispatchMode):
# we can synthesize a tensor here and do the checks on that instance.
# This approach keeps the (more frequent) cache-hit path as lightweight
# as possible.
entry_for_synth_output = _DispatchCacheValidEntry(
entry_for_synth_output = _DispatchCacheEntry(
output_infos=(entry,), is_output_tuple=False
)
synth_output = self._output_from_cache_entry(
@@ -1777,7 +1742,7 @@ class FakeTensorMode(TorchDispatchMode):
args: Sequence[object],
kwargs: Mapping[str, object],
output: Optional[FakeTensor],
) -> _DispatchCacheValidEntry:
) -> _DispatchCacheEntry:
"""
Make a cache entry object for the given 'output' Tensor. Raises
_BypassDispatchCache if the output tensor has characteristics that
@@ -1808,7 +1773,7 @@ class FakeTensorMode(TorchDispatchMode):
output_info = _DispatchCacheEntryOutputInfo(
inplace_idx=None, metadata=None, view_idx=None, constant_value=output
)
return _DispatchCacheValidEntry(
return _DispatchCacheEntry(
output_infos=(output_info,), is_output_tuple=False
)
@@ -1829,7 +1794,7 @@ class FakeTensorMode(TorchDispatchMode):
)
for out_elem in output
]
return _DispatchCacheValidEntry(
return _DispatchCacheEntry(
output_infos=tuple(output_infos), is_output_tuple=True
)
@@ -1837,7 +1802,7 @@ class FakeTensorMode(TorchDispatchMode):
output_info = self._get_output_info_for_cache_entry(
state, key, func, args, kwargs, output
)
return _DispatchCacheValidEntry(
return _DispatchCacheEntry(
output_infos=(output_info,), is_output_tuple=False
)
@@ -1917,7 +1882,7 @@ class FakeTensorMode(TorchDispatchMode):
def _output_from_cache_entry(
self,
state: _CacheKeyState,
entry: _DispatchCacheValidEntry,
entry: _DispatchCacheEntry,
key: _DispatchCacheKey,
func: OpOverload,
args: Sequence[object],
@@ -2921,19 +2886,6 @@ class FakeTensorMode(TorchDispatchMode):
_StoragePointer = object
def _has_unrepresented_symbols(
state: _CacheKeyState, output: Optional[FakeTensor]
) -> bool:
from torch.fx.experimental.symbolic_shapes import _iterate_exprs
for s in _iterate_exprs(output):
for symbol in s.free_symbols:
if symbol not in state.known_symbols:
return True
return False
# NB: returns fake tensors
def run_fallback_kernel(
fake_mode: FakeTensorMode,
@@ -2999,23 +2951,6 @@ def run_fallback_kernel(
return pytree.tree_map(map_out, r)
def _set_cache_key_for_shape_env(
cache: dict[_DispatchCacheKey, _DispatchCacheEntry],
key: _DispatchCacheKey,
entry: _DispatchCacheEntry,
) -> None:
key.strip_shape_env()
cache[key] = entry
def _set_cache_key(
cache: dict[_DispatchCacheKey, _DispatchCacheEntry],
key: _DispatchCacheKey,
entry: _DispatchCacheEntry,
) -> None:
cache[key] = entry
# Just for use to allow copying a module to fake tensors,
# does not apply elsewhere
class FakeCopyMode(TorchFunctionMode):
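
Taken together, the FakeTensorMode hunks above restore the pre-#153034 shape of the dispatch path: a single cache entry type, no negative (bypass) entries, and one try/except wrapping key construction, lookup, dispatch, and entry creation. A rough, self-contained illustration of that control-flow pattern (toy code, not the FakeTensorMode implementation):

from collections import Counter

# Editorial sketch, not code from this commit.
class BypassCache(Exception):
    def __init__(self, reason: str) -> None:
        super().__init__(reason)
        self.reason = reason

stats: Counter = Counter()
cache: dict = {}

def dispatch(func, *args):
    # Stand-in for the real, uncached dispatch path.
    return func(*args)

def cached_dispatch(func, *args):
    try:
        if any(isinstance(a, float) for a in args):
            # Toy bypass rule standing in for the real cache-key validation.
            raise BypassCache("float args not cached")
        key = (func.__name__, args)
        if key in cache:
            stats["hits"] += 1
            return cache[key]
        out = dispatch(func, *args)
        cache[key] = out
        stats["misses"] += 1
        return out
    except BypassCache as e:
        # Anything uncachable is counted and falls back to plain dispatch.
        stats["bypass: " + e.reason] += 1
        return dispatch(func, *args)

print(cached_dispatch(pow, 2, 10))    # miss   -> 1024
print(cached_dispatch(pow, 2, 10))    # hit    -> 1024
print(cached_dispatch(pow, 2.0, 10))  # bypass -> 1024.0
print(stats)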