Re-enable FakeTensor caching for SymInts (#152662)

Summary:

This backs out D60320595 which itself turned off FakeTensor caching when a SymInt was present.

There has been a lot of dynamic shape fixes done this year and tests pass so I'm assuming some of that work fixed what was breaking previously.

Test Plan: Reran the tests listed in T196779132 and they pass.

## Perf
### Instruction Counter Benchmark:
- 26% win on add_loop_eager_dynamic
- 13% win on add_loop_inductor_dynamic_gpu
### Perf Dashboard
Compilation Latency wins across the board but especially strong on the dynamic tests (like cudagraphs_dynamic) - for example MobileBertForMaskedLM went from 66s -> 50s.

Differential Revision: [D75467694](https://our.internmc.facebook.com/intern/diff/D75467694)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152662
Approved by: https://github.com/anijain2305
This commit is contained in:
Aaron Orenstein
2025-05-28 20:11:30 -07:00
committed by PyTorch MergeBot
parent 3027051590
commit fc0135ca11
9 changed files with 72 additions and 53 deletions

View File

@ -1521,6 +1521,10 @@ class FakeTensorMode(TorchDispatchMode):
# where it wasn't seen on a previous instance of the same op.
self.shape_env.settings if self.shape_env else None,
]
if state.known_symbols:
# If there are symbols then include the epoch - this is really more
# of a Shape env var which lives on the FakeTensorMode.
key_values.append(self.epoch)
# Collect the id_hashed objects to attach a weakref finalize later
id_hashed_objects: list[object] = []
# Translate any FakeTensor args to metadata.
@ -1632,10 +1636,6 @@ class FakeTensorMode(TorchDispatchMode):
raise _BypassDispatchCache("constant attribute")
if is_sparse_any(arg):
raise _BypassDispatchCache(f"{arg.layout} tensor")
# FIXME: For now back out caching when there are symbolic nbytes
# - this doesn't seem to play nice with set(). See T196779132 for examples.
if isinstance(arg.untyped_storage().nbytes(), SymInt):
raise _BypassDispatchCache("symbolic nbytes")
metadata = extract_tensor_metadata(arg)
metadata._flatten_into(result, self, state)
elif isinstance(arg, Tensor):
@ -1764,9 +1764,18 @@ class FakeTensorMode(TorchDispatchMode):
entry_for_synth_output = _DispatchCacheValidEntry(
output_infos=(entry,), is_output_tuple=False
)
synth_output = self._output_from_cache_entry(
state, entry_for_synth_output, key, func, args
)
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
try:
synth_output = self._output_from_cache_entry(
state, entry_for_synth_output, key, func, args
)
except GuardOnDataDependentSymNode:
# This should probably never really happen. If it does it means that
# although the original call didn't get a data-dependent error when
# we tried to reconstruct the output we did - that's almost
# certainly a bug.
raise _BypassDispatchCache("data dependent symnode") from None
# Make sure the dispatch_key_set from the synthesized output tensor will
# be the same.
@ -1937,11 +1946,7 @@ class FakeTensorMode(TorchDispatchMode):
if entry.is_output_tuple:
outputs = [
self._get_output_tensor_from_cache_entry(
state,
output_info,
key,
func,
args,
state, output_info, key, func, args
)
for output_info in entry.output_infos
]
@ -1974,8 +1979,8 @@ class FakeTensorMode(TorchDispatchMode):
assert isinstance(b, int) and a == b
elif a is None:
assert b is None
elif isinstance(a, torch.SymInt):
assert a is b
elif isinstance(a, py_sym_types):
assert type(a) == type(b) and a.node is b.node
elif isinstance(a, torch.Tensor):
assert isinstance(b, torch.Tensor)
assert_metadata_eq(assert_eq, a, b)