FakeTensorMode shouldn't cache syms when tracing (#164718)

Improve the FakeTensor cache to handle SymNodes and tracing properly.

For now, when we're proxy tracing, just don't bother caching operations that have SymNodes in their output. The problem is that the proxy tracer relies on SymNode identity, and our cache doesn't preserve it. This can be fixed (I left some notes in _validate_symbolic_output_for_caching() on how), but it isn't worth it for now.

If we aren't proxy tracing then caching is fine.
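Here "proxy tracing" means a ProxyTorchDispatchMode is currently installed. A minimal sketch of that check (the helper name below is made up; the get_proxy_mode() call is the same one the new cache-key code uses):

from torch.fx.experimental.proxy_tensor import get_proxy_mode


def currently_proxy_tracing() -> bool:
    # Hypothetical helper: proxy tracing is active exactly when a proxy
    # dispatch mode is installed, i.e. get_proxy_mode() returns non-None.
    return get_proxy_mode() is not None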

Thus these changes (a simplified sketch of the resulting flow follows the list):

1. Our cache key needs to include whether we were actively tracing or not - that way, if we create a cache entry while not tracing and then hit it while we ARE tracing, the op gets rerun instead of being served from the stale entry.

2. If we're tracing and there's a SymNode in the output, bypass the cache.

3. Some general cleanup of the output validation - it was unnecessarily a two-step process when it could be a single step (it's still two parts internally, but there's now only a single outer try/except).
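Taken together, a minimal sketch of the resulting policy - not the real FakeTensorMode code; BypassCache, DispatchCache, make_key, and validate_output are hypothetical stand-ins for the helpers in fake_tensor.py, and "symbols" stands in for the SymNodes reachable from an output:

from typing import Any, Callable


class BypassCache(Exception):
    """Raised when an op's output can't be cached safely."""


class DispatchCache:
    def __init__(self) -> None:
        self.entries: dict = {}

    def make_key(self, func: Callable, args: tuple, is_tracing: bool) -> tuple:
        # (1) The key records whether we were proxy tracing, so an entry
        # created outside tracing is never reused while tracing (and vice
        # versa) - the op simply reruns.
        return (func, args, is_tracing)

    def validate_output(self, output: Any, is_tracing: bool, known_symbols: set) -> None:
        # `symbols` stands in for the SymNodes/symbols reachable from the output.
        symbols = getattr(output, "symbols", set())
        if is_tracing and symbols:
            # (2) While proxy tracing, any symbolic output bypasses the cache:
            # the tracer relies on SymNode identity, which a cache hit would
            # not preserve.
            raise BypassCache("proxy mode with symbolic output")
        if symbols - known_symbols:
            # Symbols that never appeared in the input can't be reconstructed
            # from a cache entry.
            raise BypassCache("unrepresented symbol in output")

    def dispatch(self, func: Callable, args: tuple, *, is_tracing: bool, known_symbols: set) -> Any:
        key = self.make_key(func, args, is_tracing)
        if key in self.entries:
            return self.entries[key]
        output = func(*args)
        try:
            # (3) Output validation is one step under a single try/except.
            self.validate_output(output, is_tracing, known_symbols)
        except BypassCache:
            return output
        self.entries[key] = output
        return output

The real implementation keys on much more state (default dtype, ShapeEnv settings, etc.); the sketch only shows the pieces this PR touches.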

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164718
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #165266, #164717
Author:    Aaron Orenstein
Date:      2025-10-15 14:20:07 -07:00
Committer: PyTorch MergeBot
Parent:    5f21cc786a
Commit:    4c1c341fa0
3 changed files with 96 additions and 27 deletions


@@ -22,8 +22,7 @@ from torch.testing._internal.triton_utils import requires_gpu
 class TestDynamic(DTensorTestBase):
     @requires_gpu
     @with_comms
-    # FIXME: Currently broken for fake tensor cache
-    @parametrize("fake_tensor_cache_enabled", [False])
+    @parametrize("fake_tensor_cache_enabled", [False, True])
     def test_embedding(self, fake_tensor_cache_enabled):
         with patch.object(
             torch._dynamo.config, "fake_tensor_cache_enabled", fake_tensor_cache_enabled


@@ -1538,7 +1538,7 @@ class FakeTensorMode(TorchDispatchMode):
         try:
             # pyrefly: ignore # bad-argument-type
-            self._validate_cache_key(func, args, kwargs)
+            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
         except _BypassDispatchCache as e:
             # We ran "extra" checks on the cache key and determined that it's no
             # good. Record the reason and mark it so we don't bother validating
@@ -1556,16 +1556,6 @@
             set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
             return output
 
-        try:
-            # pyrefly: ignore # bad-argument-type
-            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
-        except _BypassDispatchCache as e:
-            # We had trouble making the cache entry. Record the reason and mark
-            # it.
-            FakeTensorMode.cache_bypasses[e.reason] += 1
-            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
-            return output
-
         set_cache_key(cache, key, entry)
         FakeTensorMode.cache_misses += 1
         return output
@@ -1581,6 +1571,7 @@
         Create a cache key given the dispatch args. Raises _BypassDispatchCache
         for any situation that precludes caching.
         """
+        is_tracing = torch.fx.experimental.proxy_tensor.get_proxy_mode() is not None
         key_values = [
             func,
             # Capture the default_dtype mode since that can affect the output tensor,
@@ -1596,6 +1587,10 @@
             # Disallowing dynamic shapes can introduce a DynamicOutputShapeException
             # where it wasn't seen on a previous instance of the same op.
             self.shape_env.settings if self.shape_env else None,
+            # ProxyTorchDispatchMode needs to track how SymNodes are constructed
+            # so we need to handle things a little different depending on
+            # whether we're tracing or not.
+            is_tracing,
         ]
         if state.known_symbols:
             # If there are symbols then include the epoch - this is really more
@@ -1776,11 +1771,9 @@
         if isinstance(output, (int, type(None))):
             return
 
-        if _has_unrepresented_symbols(state, output):
-            # Unbacked symbols are fine - but only if they're also represented
-            # in the input. If there are any new unbacked symbols then we can't
-            # cache this output.
-            raise _BypassDispatchCache("unrepresented symbol in output")
+        # Check for symbolic content that should bypass caching - raises
+        # _BypassDispatchCache if necessary.
+        _validate_symbolic_output_for_caching(state, output)
 
         # Some ops return tuples of Tensors, but it's rare, so avoid
         # the complexity of caching other types.
@@ -1896,6 +1889,8 @@
         from torch._higher_order_ops.utils import registered_hop_fake_fns
         from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
 
+        self._validate_cache_key(func, args, kwargs)
+
         # For hops, lets look at the output tensor to find any unbacked symints.
         # If there are none, then we rely on the existing checks to validate
         # caching.
@@ -3072,17 +3067,65 @@ class FakeTensorMode(TorchDispatchMode):
 _StoragePointer = object
 
 
-def _has_unrepresented_symbols(
-    state: _CacheKeyState, output: Optional[FakeTensor]
-) -> bool:
-    from torch.fx.experimental.symbolic_shapes import _iterate_exprs
-
-    for s in _iterate_exprs(output):
-        for symbol in s.free_symbols:
-            if symbol not in state.known_symbols:
-                return True
-    return False
+def _validate_symbolic_output_for_caching(
+    state: _CacheKeyState, output: FakeTensor
+) -> None:
+    """
+    Validate symbolic content in output and raise _BypassDispatchCache if
+    caching should be bypassed.
+
+    Args:
+        state: Cache key state containing known symbols
+        output: Output to validate
+        proxy_mode_active: Whether PROXY dispatch mode is currently active
+
+    Raises:
+        _BypassDispatchCache: If output contains symbolic content that
+        prevents caching
+
+    Details:
+
+    If our output contains any symbols that didn't appear in the input then we
+    need to bypass. Usually this will be unbacked symbols which can't be
+    properly reconstructed but there could be "weird" cases where backed symbols
+    spontaneously appear (from non-input state)?
+
+    If we're proxy (symbol) tracing and the output contains ANY symbols then we
+    need to bypass. The problem is that ProxyTorchDispatchMode relies on SymNode
+    object identity and being able to see the construction of SymNodes.
+
+    We could improve the proxy tracing case in a few ways:
+
+    1. If the output SymNodes are directly copied from inputs then this is
+       actually fine - they're already tracked. This would probably be the
+       biggest bang/buck.
+
+    2. If the output (tensors) are all direct copies of the inputs then this is
+       also fine - since they're inputs they must be tracked. We already compute
+       this we just don't plumb it around enough.
+
+    3. If the output SymNodes are already tracked by the proxy then this is also
+       actually fine - they're properly tracked. This probably wouldn't be
+       common since for most outputs we use torch.empty_strided() and recompute
+       strides.
+
+    4. We could use the proxy to track "how" the SymNodes were computed and when
+       using the cache we could "replay" them properly to teach the proxy how to
+       build them.
+    """
+    from torch.fx.experimental.symbolic_shapes import _iterate_exprs, _iterate_nodes
+
+    is_tracing = torch.fx.experimental.proxy_tensor.get_proxy_mode() is not None
+    if is_tracing:
+        # Check for SymNode types in PROXY mode - this should bypass caching
+        # regardless of whether symbols are known or not
+        for node in _iterate_nodes(output):
+            raise _BypassDispatchCache("Proxy mode with SymNode output")
+    else:
+        # Check for unrepresented symbols in tensor expressions
+        for s in _iterate_exprs(output):
+            for symbol in s.free_symbols:
+                if symbol not in state.known_symbols:
+                    raise _BypassDispatchCache("unrepresented symbol in output")
 
 
 # NB: returns fake tensors
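When the new bypass fires it is recorded like any other bypass reason. A hedged debugging snippet - the cache_bypasses attribute and the reason string come from the hunks above; the printing loop is just one way to surface them:

from torch._subclasses.fake_tensor import FakeTensorMode

# FakeTensorMode.cache_bypasses is the per-reason counter incremented in the
# bypass path above; while proxy tracing ops with SymNode outputs you should
# see "Proxy mode with SymNode output" accumulate here.
for reason, count in sorted(FakeTensorMode.cache_bypasses.items()):
    print(f"{reason}: {count}")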


@@ -883,11 +883,16 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
     Raises:
         AssertionError: If the value is of an unsupported type.
     """
+    # This is almost close enough to implement in terms of _iterate_nodes()
+    # except that it needs to handle `list[sympy.Basic]` which _iterate_nodes()
+    # can't handle.
     if isinstance(val, SymTypes):
         # This allow applies to the jagged layout NestedTensor case as
         # nested ints are not symbolic
         if is_symbolic(val):
             yield val.node.expr
+    elif isinstance(val, SymNode):
+        yield val.expr
     elif isinstance(val, sympy.Basic):
         yield val
     elif isinstance(val, (int, float, bool)):
@@ -910,6 +915,28 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
     raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")
 
 
+def _iterate_nodes(val: Any) -> Iterator[SymNode]:
+    """
+    Recursively iterate through a value and yield all SymNodes contained
+    within it.
+    """
+    if isinstance(val, SymNode):
+        yield val
+    elif isinstance(val, py_sym_types):
+        # This allow applies to the jagged layout NestedTensor case as
+        # nested ints are not symbolic
+        if is_symbolic(val):
+            yield val.node
+    elif isinstance(val, (tuple, list, torch.Size)):
+        for s in val:
+            yield from _iterate_nodes(s)
+    elif isinstance(val, torch.Tensor):
+        yield from _iterate_nodes(val.size())
+        if not is_sparse_any(val):
+            yield from _iterate_nodes(val.stride())
+            yield from _iterate_nodes(val.storage_offset())
+
+
 def free_symbols(val: IterateExprs) -> OrderedSet[sympy.Symbol]:
     """
     Recursively collect all free symbols from a value.
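A hedged usage sketch of the new helper, assuming it is importable from torch.fx.experimental.symbolic_shapes as added above (the expected result is inferred from the traversal, not from documented behavior):

import torch
from torch.fx.experimental.symbolic_shapes import _iterate_nodes

# A plain tensor has concrete (int) sizes, strides and storage offset, so the
# traversal finds nothing symbolic to yield.
print(list(_iterate_nodes(torch.empty(2, 3))))  # expected: []

# For a FakeTensor whose sizes/strides are SymInts (e.g. produced under a
# FakeTensorMode with a ShapeEnv and dynamic shapes), the same call yields the
# SymNode behind each symbolic size, stride and storage offset - exactly what
# _validate_symbolic_output_for_caching() needs to detect "any SymNode in the
# output" while proxy tracing.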