Re-enable FakeTensor caching for SymInts (#152662)
Summary: This backs out D60320595, which itself turned off FakeTensor caching when a SymInt was present. There have been a lot of dynamic shape fixes this year and tests pass, so I'm assuming some of that work fixed what was breaking previously.

Test Plan: Reran the tests listed in T196779132 and they pass.

## Perf

### Instruction Counter Benchmark:
- 26% win on add_loop_eager_dynamic
- 13% win on add_loop_inductor_dynamic_gpu

### Perf Dashboard
Compilation latency wins across the board, especially strong on the dynamic tests (like cudagraphs_dynamic); for example, MobileBertForMaskedLM went from 66s -> 50s.

Differential Revision: [D75467694](https://our.internmc.facebook.com/intern/diff/D75467694)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152662
Approved by: https://github.com/anijain2305
Committed by: PyTorch MergeBot
Parent: 3027051590
Commit: fc0135ca11
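For a sense of what the dynamic benchmarks above exercise, here is a minimal sketch (the `add_loop` function, iteration count, and sizes are illustrative, not the actual benchmark harness) of a dynamic-shape `torch.compile` workload. With `dynamic=True`, tracing sees SymInt-backed sizes, so every traced op dispatches through FakeTensorMode on the path whose dispatch cache this PR re-enables.

```python
# Minimal sketch, not the benchmark harness: names and sizes are illustrative.
import torch

def add_loop(x, y):
    # Many small elementwise ops: each one is faked during tracing,
    # which is where FakeTensor dispatch caching saves compile time.
    for _ in range(50):
        x = x + y
    return x

compiled = torch.compile(add_loop, dynamic=True)

# Varying the input length keeps the shapes symbolic (SymInt-backed).
for n in (8, 16, 32):
    out = compiled(torch.randn(n), torch.randn(n))
    print(n, out.shape)
```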
torch/_subclasses/fake_tensor.py

@@ -1521,6 +1521,10 @@ class FakeTensorMode(TorchDispatchMode):
             # where it wasn't seen on a previous instance of the same op.
             self.shape_env.settings if self.shape_env else None,
         ]
+        if state.known_symbols:
+            # If there are symbols then include the epoch - this is really more
+            # of a Shape env var which lives on the FakeTensorMode.
+            key_values.append(self.epoch)
         # Collect the id_hashed objects to attach a weakref finalize later
         id_hashed_objects: list[object] = []
         # Translate any FakeTensor args to metadata.
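A simplified, self-contained sketch of the keying idea in the hunk above (the names `make_cache_key` and `KeyState` and the argument layout are hypothetical, not the FakeTensorMode implementation): the epoch counter is folded into the cache key only when symbolic values participate, so symbol-free ops keep stable cache keys across ShapeEnv epoch bumps.

```python
# Hypothetical, simplified illustration of the conditional-epoch keying above;
# not the real FakeTensorMode cache-key code.
from dataclasses import dataclass, field

@dataclass
class KeyState:
    # Plays the role of state.known_symbols in the diff.
    known_symbols: set = field(default_factory=set)

def make_cache_key(op_name, args, state, shape_env_settings, epoch):
    key_values = [op_name, tuple(args), shape_env_settings]
    if state.known_symbols:
        # Only symbol-bearing keys need to be invalidated when the
        # symbolic environment moves to a new epoch.
        key_values.append(epoch)
    return tuple(key_values)

# A static op keeps the same key even after the epoch changes ...
assert make_cache_key("aten.add", (2, 3), KeyState(), None, epoch=1) == \
       make_cache_key("aten.add", (2, 3), KeyState(), None, epoch=2)
# ... while a symbolic op does not.
assert make_cache_key("aten.add", ("s0", 3), KeyState({"s0"}), None, epoch=1) != \
       make_cache_key("aten.add", ("s0", 3), KeyState({"s0"}), None, epoch=2)
```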
@@ -1632,10 +1636,6 @@ class FakeTensorMode(TorchDispatchMode):
                     raise _BypassDispatchCache("constant attribute")
                 if is_sparse_any(arg):
                     raise _BypassDispatchCache(f"{arg.layout} tensor")
-                # FIXME: For now back out caching when there are symbolic nbytes
-                # - this doesn't seem to play nice with set(). See T196779132 for examples.
-                if isinstance(arg.untyped_storage().nbytes(), SymInt):
-                    raise _BypassDispatchCache("symbolic nbytes")
                 metadata = extract_tensor_metadata(arg)
                 metadata._flatten_into(result, self, state)
             elif isinstance(arg, Tensor):
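The deleted FIXME block above was guarding against storage sizes that are SymInts rather than plain ints. A minimal illustration of the distinction (the tensor and byte count are just an example; only the eager-tensor half is shown, since constructing a symbolically sized FakeTensor takes more setup):

```python
# For an ordinary eager tensor, untyped_storage().nbytes() is a plain int.
# Under FakeTensorMode with a ShapeEnv and dynamic dims it can instead be a
# torch.SymInt, which is the case the removed bypass used to punt on.
import torch

t = torch.randn(8)
nbytes = t.untyped_storage().nbytes()
print(type(nbytes), nbytes)                  # <class 'int'> 32
assert not isinstance(nbytes, torch.SymInt)  # the isinstance check the removed code performed
```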
@@ -1764,9 +1764,18 @@ class FakeTensorMode(TorchDispatchMode):
         entry_for_synth_output = _DispatchCacheValidEntry(
             output_infos=(entry,), is_output_tuple=False
         )
-        synth_output = self._output_from_cache_entry(
-            state, entry_for_synth_output, key, func, args
-        )
+        from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode
+
+        try:
+            synth_output = self._output_from_cache_entry(
+                state, entry_for_synth_output, key, func, args
+            )
+        except GuardOnDataDependentSymNode:
+            # This should probably never really happen. If it does it means that
+            # although the original call didn't get a data-dependent error when
+            # we tried to reconstruct the output we did - that's almost
+            # certainly a bug.
+            raise _BypassDispatchCache("data dependent symnode") from None

         # Make sure the dispatch_key_set from the synthesized output tensor will
         # be the same.
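A simplified stand-in for the control flow added above (the exception classes and helper names here are hypothetical; in the diff they are GuardOnDataDependentSymNode and _BypassDispatchCache): an error raised while rebuilding an output from a cache entry is converted into a cache bypass that the caller handles by falling back to a real dispatch, rather than propagating to the user.

```python
# Hypothetical, simplified sketch of the added try/except; the real code uses
# GuardOnDataDependentSymNode and _BypassDispatchCache from the diff above.
class DataDependentGuardError(Exception):
    pass

class BypassCache(Exception):
    pass

def output_from_cache_entry(entry):
    if entry.get("data_dependent"):
        raise DataDependentGuardError("size cannot be decided without real data")
    return entry["value"]

def validate_cache_entry(entry):
    try:
        return output_from_cache_entry(entry)
    except DataDependentGuardError:
        # Mirrors `raise _BypassDispatchCache(...) from None`: the caller
        # catches the bypass and recomputes instead of surfacing the error.
        raise BypassCache("data dependent symnode") from None

try:
    validate_cache_entry({"data_dependent": True})
except BypassCache as e:
    print("cache bypassed:", e)
```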
@@ -1937,11 +1946,7 @@ class FakeTensorMode(TorchDispatchMode):
         if entry.is_output_tuple:
             outputs = [
                 self._get_output_tensor_from_cache_entry(
-                    state,
-                    output_info,
-                    key,
-                    func,
-                    args,
+                    state, output_info, key, func, args
                 )
                 for output_info in entry.output_infos
             ]
@@ -1974,8 +1979,8 @@ class FakeTensorMode(TorchDispatchMode):
             assert isinstance(b, int) and a == b
         elif a is None:
             assert b is None
-        elif isinstance(a, torch.SymInt):
-            assert a is b
+        elif isinstance(a, py_sym_types):
+            assert type(a) == type(b) and a.node is b.node
         elif isinstance(a, torch.Tensor):
             assert isinstance(b, torch.Tensor)
             assert_metadata_eq(assert_eq, a, b)
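The last hunk widens the old `torch.SymInt` branch to all of `py_sym_types` (SymInt, SymFloat, SymBool in current PyTorch, though that mapping is my reading rather than something stated in the diff) and compares the underlying nodes instead of the wrapper objects. A pure-Python stand-in for why node identity is the right check (the classes below are hypothetical, not torch's types):

```python
# Hypothetical stand-ins for torch's symbolic wrapper types: two distinct
# wrapper objects can share the same underlying node, so identity is checked
# on the node (a.node is b.node), not on the wrappers (a is b).
class Node:
    def __init__(self, name):
        self.name = name

class SymValue:
    def __init__(self, node):
        self.node = node

def assert_same_symbol(a, b):
    assert type(a) == type(b) and a.node is b.node

s0 = Node("s0")
a, b = SymValue(s0), SymValue(s0)   # distinct wrappers, same node
assert a is not b
assert_same_symbol(a, b)            # passes: the nodes are identical
print("ok")
```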