Re-enable FakeTensor caching for SymInts (#152662)

Summary:

This backs out D60320595 which itself turned off FakeTensor caching when a SymInt was present.

There has been a lot of dynamic shape fixes done this year and tests pass so I'm assuming some of that work fixed what was breaking previously.

Test Plan: Reran the tests listed in T196779132 and they pass.

## Perf
### Instruction Counter Benchmark:
- 26% win on add_loop_eager_dynamic
- 13% win on add_loop_inductor_dynamic_gpu
### Perf Dashboard
Compilation Latency wins across the board but especially strong on the dynamic tests (like cudagraphs_dynamic) - for example MobileBertForMaskedLM went from 66s -> 50s.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152662
Approved by: https://github.com/anijain2305
This commit is contained in:
Aaron Orenstein
2025-05-25 12:28:17 -07:00
committed by PyTorch MergeBot
parent 062387fb53
commit 7d11c61c26
6 changed files with 38 additions and 34 deletions

View File

@ -1521,6 +1521,10 @@ class FakeTensorMode(TorchDispatchMode):
# where it wasn't seen on a previous instance of the same op.
self.shape_env.settings if self.shape_env else None,
]
if state.known_symbols:
# If there are symbols then include the epoch - this is really more
# of a Shape env var which lives on the FakeTensorMode.
key_values.append(self.epoch)
# Collect the id_hashed objects to attach a weakref finalize later
id_hashed_objects: list[object] = []
# Translate any FakeTensor args to metadata.
@ -1632,10 +1636,6 @@ class FakeTensorMode(TorchDispatchMode):
raise _BypassDispatchCache("constant attribute")
if is_sparse_any(arg):
raise _BypassDispatchCache(f"{arg.layout} tensor")
# FIXME: For now back out caching when there are symbolic nbytes
# - this doesn't seem to play nice with set(). See T196779132 for examples.
if isinstance(arg.untyped_storage().nbytes(), SymInt):
raise _BypassDispatchCache("symbolic nbytes")
metadata = extract_tensor_metadata(arg)
metadata._flatten_into(result, self, state)
elif isinstance(arg, Tensor):
@ -1937,11 +1937,7 @@ class FakeTensorMode(TorchDispatchMode):
if entry.is_output_tuple:
outputs = [
self._get_output_tensor_from_cache_entry(
state,
output_info,
key,
func,
args,
state, output_info, key, func, args
)
for output_info in entry.output_infos
]