Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 12:54:11 +08:00)
Revert "Fix fake tensor caching when output has unbacked (#153034)"
This reverts commit cb5f31a4a164a4fa1eaa627f9b15cdc18aa95ef1. Reverted https://github.com/pytorch/pytorch/pull/153034 on behalf of https://github.com/malfet due to Seems to have introduced flakiness in MacOS inductor tests, see https://github.com/pytorch/pytorch/issues/153891 ([comment](https://github.com/pytorch/pytorch/pull/153034#issuecomment-2893059329))
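For context, the reverted change made the FakeTensor dispatch cache refuse to cache ops whose outputs contain symbols that never appeared in the inputs (typically fresh unbacked SymInts from data-dependent ops). A minimal sketch of that situation, not taken from this commit, using torch.nonzero as the data-dependent op:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# A ShapeEnv lets data-dependent ops allocate fresh unbacked symbols (u0, u1, ...)
# instead of raising a dynamic-output-shape error.
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
x = fake_mode.from_tensor(torch.tensor([1, 0, 2, 0, 3]))

with fake_mode:
    out = torch.nonzero(x)

# out's first dimension is a brand-new unbacked symbol that appears in no input,
# so a cache key built from the inputs alone cannot reproduce it; the reverted
# fix bypassed the dispatch cache ("unrepresented symbol in output") instead.
print(out.shape)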
@@ -2265,10 +2265,13 @@ class FakeTensorDispatchCache(TestCase):
        gc.collect()
        self.assertTrue(count_invoke_subgraph_keys() == 0)



    @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
    def test_invoke_subgraph_cacheable_inplace(self):
        invoke_subgraph = torch._higher_order_ops.invoke_subgraph


        def fn(x, y):
            # aten ops are used so that eager backend graph is suitable for fake
            # tensor testing
@@ -2314,32 +2317,5 @@ class FakeTensorDispatchCache(TestCase):
            extract_tensor_metadata(b),
        )

    @skipIfTorchDynamo("cache hit/miss changes with invoke_subgraph caching")
    def test_unbacked_output(self):
        # The point of this test is to have an op which has no symbols as input
        # but a symbol as an output and make sure that we skip caching it.
        class LengthsGather(torch.nn.Module):
            def forward(
                self,
                input: torch.Tensor,
                lengths: torch.Tensor,
                indices: torch.Tensor,
                offsets: torch.Tensor,
            ) -> torch.Tensor:
                bias = torch.gather(offsets, 0, indices)
                lengths_selected = torch.gather(lengths, 0, indices)
                index = torch.repeat_interleave(bias, lengths_selected, dim=0)
                return index

        input = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
        lengths = torch.tensor([0, 2, 3, 1, 4])
        indices = torch.tensor([2, 3, 4, 6, 7, 8, 9])
        offsets = torch.cumsum(lengths, 0)
        ep = torch.export.export(LengthsGather(), (input, lengths, indices, offsets), strict=False)

        FakeTensorMode.cache_clear()
        ep.run_decompositions({})
        self.assertBypasses("unrepresented symbol in output", 2)


if __name__ == "__main__":
    run_tests()
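The test_unbacked_output test dropped by this revert counts cache bypasses through FakeTensorMode's class-level counters, which the assertBypasses helper appears to consult. A rough, hypothetical way to inspect those counters directly (not part of this commit):

from torch._subclasses.fake_tensor import FakeTensorMode

# cache_bypasses maps a bypass-reason string to how many times the dispatch
# cache was skipped for that reason; cache_hits/cache_misses are plain ints.
FakeTensorMode.cache_clear()
# ... run ops under a FakeTensorMode here ...
print(FakeTensorMode.cache_hits,
      FakeTensorMode.cache_misses,
      FakeTensorMode.cache_bypasses["unrepresented symbol in output"])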
@@ -218,11 +218,6 @@ class _CacheKeyState:
    # matches one of the inputs so we can uncache it properly.
    sym_node_lookup: dict[int, int]  # id(SymNode) -> index

    # This is a list of all seen input sympy.Symbols. We use it when building
    # the cache entry to see if the output value has any symbols that we didn't
    # see on input. See _has_unrepresented_symbols().
    known_symbols: set[sympy.Symbol]

    # There are cases where we're asked to perform an op when we have no
    # ShapeEnv on the FakeTensorMode - but for SymNodes we MUST have a
    # ShapeEnv. So as we scan if we see a SymNode (with a ShapeEnv) we record it
@@ -231,7 +226,6 @@

    def __init__(self, shape_env: Optional[ShapeEnv] = None) -> None:
        self.sym_node_lookup = {}
        self.known_symbols = set()
        self.shape_env = shape_env

    def cache_on_shape_env(self) -> bool:
@@ -253,7 +247,6 @@
            result.append(_InputBackref(self.sym_node_lookup[node_id]))
        else:
            self.sym_node_lookup[node_id] = len(result)
            self.known_symbols.update(arg.node.expr.free_symbols)
            if self.shape_env is None:
                self.shape_env = arg.node.shape_env
            result.append(_PySymInputStub(arg))
@@ -74,6 +74,12 @@ except ValueError as e:
    raise e


class _Unassigned:
    pass


_UNASSIGNED = _Unassigned()

DimList = list

pytree = torch.utils._pytree
@@ -1112,7 +1118,7 @@ class _DispatchCacheEntryOutputInfo:

@dataclass_slots
@dataclass(frozen=True)
class _DispatchCacheValidEntry:
class _DispatchCacheEntry:
    """
    Entry type for the FakeTensor dispatch cache. It supports two types of outputs
    1) tensor
@@ -1125,20 +1131,6 @@ class _DispatchCacheValidEntry:
    is_output_tuple: bool = False


@dataclass_slots
@dataclass(frozen=True)
class _DispatchCacheBypassEntry:
    """
    Entry type for a negative cache entry.
    """

    reason: str


if TYPE_CHECKING:
    _DispatchCacheEntry = Union[_DispatchCacheValidEntry, _DispatchCacheBypassEntry]


@dataclass_slots
@dataclass(frozen=True)
class _BypassDispatchCache(Exception):
@@ -1426,64 +1418,37 @@ class FakeTensorMode(TorchDispatchMode):
        Lookup a cache entry for the given arguments. If none exists, dispatch
        and cache the result (if the result is eligible for caching).
        """
        output: object = _UNASSIGNED
        try:
            state = _CacheKeyState(self.shape_env)
            key = self._cache_key(state, func, args, kwargs)
            if state.cache_on_shape_env():
                assert state.shape_env is not None
                cache = state.shape_env.fake_tensor_cache
            else:
                cache = FakeTensorMode.cache
            entry = cache.get(key, None)
            if entry is not None:
                output = self._output_from_cache_entry(state, entry, key, func, args)
                FakeTensorMode.cache_hits += 1
                if self.cache_crosscheck_enabled:
                    # For debugging / testing: Validate that the output synthesized
                    # from the cache matches the output created by normal dispatch.
                    with disable_fake_tensor_cache(self):
                        self._crosscheck_cache_output(output, func, types, args, kwargs)
            else:
                self._validate_cache_key(func, args, kwargs)
                output = self._dispatch_impl(func, types, args, kwargs)
                entry = self._make_cache_entry(state, key, func, args, kwargs, output)
                key.strip_shape_env()
                cache[key] = entry
                FakeTensorMode.cache_misses += 1
        except _BypassDispatchCache as e:
            # We couldn't create the cache key at all
            FakeTensorMode.cache_bypasses[e.reason] += 1
            return self._dispatch_impl(func, types, args, kwargs)

        if state.cache_on_shape_env():
            assert state.shape_env is not None
            cache = state.shape_env.fake_tensor_cache
            set_cache_key = _set_cache_key_for_shape_env
        else:
            cache = FakeTensorMode.cache
            set_cache_key = _set_cache_key
        entry = cache.get(key, None)
        if output is _UNASSIGNED:
            output = self._dispatch_impl(func, types, args, kwargs)

        if entry is not None:
            if isinstance(entry, _DispatchCacheBypassEntry):
                # This represents a negative cache entry - we already saw that the
                # output is uncachable. Compute it from first principals.
                FakeTensorMode.cache_bypasses[entry.reason] += 1
                return self._dispatch_impl(func, types, args, kwargs)

            # We have a cache entry.
            output = self._output_from_cache_entry(state, entry, key, func, args)
            FakeTensorMode.cache_hits += 1
            if self.cache_crosscheck_enabled:
                # For debugging / testing: Validate that the output synthesized
                # from the cache matches the output created by normal dispatch.
                with disable_fake_tensor_cache(self):
                    self._crosscheck_cache_output(output, func, types, args, kwargs)
            return output

        # We don't have a cache entry.
        output = self._dispatch_impl(func, types, args, kwargs)

        try:
            self._validate_cache_key(func, args, kwargs)
        except _BypassDispatchCache as e:
            # We ran "extra" checks on the cache key and determined that it's no
            # good. Record the reason and mark it so we don't bother validating
            # again.
            FakeTensorMode.cache_bypasses[e.reason] += 1
            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
            return output

        try:
            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
        except _BypassDispatchCache as e:
            # We had trouble making the cache entry. Record the reason and mark
            # it.
            FakeTensorMode.cache_bypasses[e.reason] += 1
            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
            return output

        set_cache_key(cache, key, entry)
        FakeTensorMode.cache_misses += 1
        return output

    def _cache_key(
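Among other things, the hunk above removes the negative-cache ("bypass") entry handling added by #153034: when key validation or entry construction fails, the reason is recorded in the cache so later calls skip the failing work. A generic sketch of that pattern with made-up names, not the FakeTensor implementation:

from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class _Bypass:
    reason: str

_cache: dict[Any, Any] = {}

def cached(key: Any, validate: Callable[[], None], compute: Callable[[], Any]) -> Any:
    entry = _cache.get(key)
    if isinstance(entry, _Bypass):
        # Negative hit: we already know this key's result must not be cached,
        # so recompute without re-running the validation.
        return compute()
    if entry is not None:
        return entry                       # positive hit
    result = compute()
    try:
        validate()                         # raises ValueError(reason) if uncacheable
    except ValueError as e:
        _cache[key] = _Bypass(str(e))      # remember *why*, not the result
        return result
    _cache[key] = result
    return result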
@@ -1669,17 +1634,17 @@ class FakeTensorMode(TorchDispatchMode):
        kwargs: Mapping[str, object],
        output: Optional[FakeTensor],
    ) -> None:
        # Is this even possible? According to the signature this can be None but
        # not `int`. So either the signature is a lie or (part of) this line is
        # unnecessary...
        from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols

        if isinstance(output, (int, type(None))):
            return

        if _has_unrepresented_symbols(state, output):
            # Unbacked symbols are fine - but only if they're also represented
            # in the input. If there are any new unbacked symbols then we can't
            # cache this output.
            raise _BypassDispatchCache("unrepresented symbol in output")
        if isinstance(output, torch.SymInt):
            if has_free_unbacked_symbols(output):
                # This is unreachable but adding the check for extra safety in
                # case we change code path in future.
                raise _BypassDispatchCache("unbacked symbol in output")
            return

        # Some ops return tuples of Tensors, but it's rare, so avoid
        # the complexity of caching other types.
@@ -1753,7 +1718,7 @@ class FakeTensorMode(TorchDispatchMode):
        # we can synthesize a tensor here and do the checks on that instance.
        # This approach keeps the (more frequent) cache-hit path as lightweight
        # as possible.
        entry_for_synth_output = _DispatchCacheValidEntry(
        entry_for_synth_output = _DispatchCacheEntry(
            output_infos=(entry,), is_output_tuple=False
        )
        synth_output = self._output_from_cache_entry(
@@ -1777,7 +1742,7 @@ class FakeTensorMode(TorchDispatchMode):
        args: Sequence[object],
        kwargs: Mapping[str, object],
        output: Optional[FakeTensor],
    ) -> _DispatchCacheValidEntry:
    ) -> _DispatchCacheEntry:
        """
        Make a cache entry object for the given 'output' Tensor. Raises
        _BypassDispatchCache if the output tensor has characteristics that
@@ -1808,7 +1773,7 @@ class FakeTensorMode(TorchDispatchMode):
            output_info = _DispatchCacheEntryOutputInfo(
                inplace_idx=None, metadata=None, view_idx=None, constant_value=output
            )
            return _DispatchCacheValidEntry(
            return _DispatchCacheEntry(
                output_infos=(output_info,), is_output_tuple=False
            )

@@ -1829,7 +1794,7 @@ class FakeTensorMode(TorchDispatchMode):
            )
            for out_elem in output
        ]
        return _DispatchCacheValidEntry(
        return _DispatchCacheEntry(
            output_infos=tuple(output_infos), is_output_tuple=True
        )

@@ -1837,7 +1802,7 @@ class FakeTensorMode(TorchDispatchMode):
        output_info = self._get_output_info_for_cache_entry(
            state, key, func, args, kwargs, output
        )
        return _DispatchCacheValidEntry(
        return _DispatchCacheEntry(
            output_infos=(output_info,), is_output_tuple=False
        )

@@ -1917,7 +1882,7 @@ class FakeTensorMode(TorchDispatchMode):
    def _output_from_cache_entry(
        self,
        state: _CacheKeyState,
        entry: _DispatchCacheValidEntry,
        entry: _DispatchCacheEntry,
        key: _DispatchCacheKey,
        func: OpOverload,
        args: Sequence[object],
@@ -2921,19 +2886,6 @@ class FakeTensorMode(TorchDispatchMode):
    _StoragePointer = object


def _has_unrepresented_symbols(
    state: _CacheKeyState, output: Optional[FakeTensor]
) -> bool:
    from torch.fx.experimental.symbolic_shapes import _iterate_exprs

    for s in _iterate_exprs(output):
        for symbol in s.free_symbols:
            if symbol not in state.known_symbols:
                return True

    return False


# NB: returns fake tensors
def run_fallback_kernel(
    fake_mode: FakeTensorMode,
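The removed _has_unrepresented_symbols helper above boils down to a free-symbol set comparison. A standalone sketch of the same check with plain sympy, assuming a known_symbols set collected from the inputs:

import sympy

u0, s1 = sympy.symbols("u0 s1")
known_symbols = {s1}                 # symbols observed while hashing the inputs
output_size = u0 + 2 * s1            # a size expression taken from the output

has_unrepresented = any(
    sym not in known_symbols for sym in output_size.free_symbols
)
print(has_unrepresented)             # True: u0 never appeared in the inputs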
@@ -2999,23 +2951,6 @@ def run_fallback_kernel(
    return pytree.tree_map(map_out, r)


def _set_cache_key_for_shape_env(
    cache: dict[_DispatchCacheKey, _DispatchCacheEntry],
    key: _DispatchCacheKey,
    entry: _DispatchCacheEntry,
) -> None:
    key.strip_shape_env()
    cache[key] = entry


def _set_cache_key(
    cache: dict[_DispatchCacheKey, _DispatchCacheEntry],
    key: _DispatchCacheKey,
    entry: _DispatchCacheEntry,
) -> None:
    cache[key] = entry


# Just for use to allow copying a module to fake tensors,
# does not apply elsewhere
class FakeCopyMode(TorchFunctionMode):