Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
FakeTensorMode shouldn't cache syms when tracing (#164718)
Improve the FakeTensor cache to handle SymNode outputs and tracing properly. For now, when we're proxy tracing, just don't bother caching operations that contain SymNodes in the output. The problem is that the proxy tracer relies on SymNode identity, which our cache doesn't preserve. It can be fixed (I left some notes in _validate_symbolic_output_for_caching() on how), but it's not worth it for now. If we aren't proxy tracing, then caching is fine. Thus these changes:

1. Our cache key needs to include whether we were actively tracing or not - this way, if we create a cache entry when we weren't tracing and then try to use it when we ARE tracing, the op gets rerun.
2. If there's a SymNode in the output, then bypass caching.
3. Some general cleanup of the output validation - we were unnecessarily doing it as a two-step process when it could be a single step (it's still two parts internally, but with only a single outer try/except).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164718
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #165266, #164717
Committed by: PyTorch MergeBot
Parent: 5f21cc786a
Commit: 4c1c341fa0
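
As a rough, self-contained illustration of points (1) and (2) in the commit message above, the sketch below models the keying and bypass behavior with toy types. All names here (ToySymValue, ToyDispatchCache, BypassCache) are illustrative stand-ins, not PyTorch's actual FakeTensor cache:

from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToySymValue:
    """Stand-in for a SymInt whose SymNode identity the tracer cares about."""
    name: str


class BypassCache(Exception):
    """Raised when an output should not be cached (mirrors _BypassDispatchCache)."""


def _flatten(out: Any) -> tuple:
    return out if isinstance(out, tuple) else (out,)


class ToyDispatchCache:
    def __init__(self) -> None:
        self._entries: dict[tuple, Any] = {}
        self.bypasses: dict[str, int] = {}

    def dispatch(self, func: Callable, args: tuple, *, is_tracing: bool) -> Any:
        # (1) The tracing flag is part of the key, so an entry created while
        #     not tracing is never reused while tracing (and vice versa).
        key = (func, args, is_tracing)
        if key in self._entries:
            return self._entries[key]
        out = func(*args)
        try:
            # (2) While tracing, a symbolic value in the output bypasses the
            #     cache: replaying a cached result would break the SymNode
            #     identity that the proxy tracer relies on.
            if is_tracing and any(isinstance(v, ToySymValue) for v in _flatten(out)):
                raise BypassCache("symbolic output while tracing")
            self._entries[key] = out
        except BypassCache as e:
            self.bypasses[str(e)] = self.bypasses.get(str(e), 0) + 1
        return out


def identity(x):
    return x


cache = ToyDispatchCache()
sym = ToySymValue("s0")
cache.dispatch(identity, (sym,), is_tracing=False)  # cached under is_tracing=False
cache.dispatch(identity, (sym,), is_tracing=True)   # different key: rerun, then bypassed
print(cache.bypasses)  # {'symbolic output while tracing': 1}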
@@ -22,8 +22,7 @@ from torch.testing._internal.triton_utils import requires_gpu
 class TestDynamic(DTensorTestBase):
     @requires_gpu
     @with_comms
-    # FIXME: Currently broken for fake tensor cache
-    @parametrize("fake_tensor_cache_enabled", [False])
+    @parametrize("fake_tensor_cache_enabled", [False, True])
     def test_embedding(self, fake_tensor_cache_enabled):
         with patch.object(
             torch._dynamo.config, "fake_tensor_cache_enabled", fake_tensor_cache_enabled
@@ -1538,7 +1538,7 @@ class FakeTensorMode(TorchDispatchMode):

         try:
             # pyrefly: ignore # bad-argument-type
-            self._validate_cache_key(func, args, kwargs)
+            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
         except _BypassDispatchCache as e:
             # We ran "extra" checks on the cache key and determined that it's no
             # good. Record the reason and mark it so we don't bother validating
@@ -1556,16 +1556,6 @@ class FakeTensorMode(TorchDispatchMode):
             set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
             return output

-        try:
-            # pyrefly: ignore # bad-argument-type
-            entry = self._make_cache_entry(state, key, func, args, kwargs, output)
-        except _BypassDispatchCache as e:
-            # We had trouble making the cache entry. Record the reason and mark
-            # it.
-            FakeTensorMode.cache_bypasses[e.reason] += 1
-            set_cache_key(cache, key, _DispatchCacheBypassEntry(e.reason))
-            return output
-
         set_cache_key(cache, key, entry)
         FakeTensorMode.cache_misses += 1
         return output
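
For orientation, cache_misses and cache_bypasses in the hunk above are class-level counters on FakeTensorMode, so the effect of a dispatch can be inspected directly. A minimal sketch, assuming the fake_tensor_cache_enabled config flag exercised by the test above; exact counts depend on the build and on which ops actually go through the cache:

from unittest.mock import patch

import torch
from torch._dynamo import config as dynamo_config
from torch._subclasses.fake_tensor import FakeTensorMode

with patch.object(dynamo_config, "fake_tensor_cache_enabled", True):
    with FakeTensorMode():
        a = torch.empty(4, 8)  # a fake tensor, created under the mode
        _ = a + a              # first dispatch of the op: cache miss, entry stored
        _ = a + a              # second dispatch: may be served from the cache

print("misses:", FakeTensorMode.cache_misses)
print("bypasses:", dict(FakeTensorMode.cache_bypasses))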
@@ -1581,6 +1571,7 @@ class FakeTensorMode(TorchDispatchMode):
         Create a cache key given the dispatch args. Raises _BypassDispatchCache
         for any situation that precludes caching.
         """
+        is_tracing = torch.fx.experimental.proxy_tensor.get_proxy_mode() is not None
         key_values = [
             func,
             # Capture the default_dtype mode since that can affect the output tensor,
@@ -1596,6 +1587,10 @@ class FakeTensorMode(TorchDispatchMode):
             # Disallowing dynamic shapes can introduce a DynamicOutputShapeException
             # where it wasn't seen on a previous instance of the same op.
             self.shape_env.settings if self.shape_env else None,
+            # ProxyTorchDispatchMode needs to track how SymNodes are constructed
+            # so we need to handle things a little different depending on
+            # whether we're tracing or not.
+            is_tracing,
         ]
         if state.known_symbols:
             # If there are symbols then include the epoch - this is really more
@@ -1776,11 +1771,9 @@ class FakeTensorMode(TorchDispatchMode):
         if isinstance(output, (int, type(None))):
             return

-        if _has_unrepresented_symbols(state, output):
-            # Unbacked symbols are fine - but only if they're also represented
-            # in the input. If there are any new unbacked symbols then we can't
-            # cache this output.
-            raise _BypassDispatchCache("unrepresented symbol in output")
+        # Check for symbolic content that should bypass caching - raises
+        # _BypassDispatchCache if necessary.
+        _validate_symbolic_output_for_caching(state, output)

         # Some ops return tuples of Tensors, but it's rare, so avoid
         # the complexity of caching other types.
@@ -1896,6 +1889,8 @@ class FakeTensorMode(TorchDispatchMode):
         from torch._higher_order_ops.utils import registered_hop_fake_fns
         from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols

+        self._validate_cache_key(func, args, kwargs)
+
         # For hops, lets look at the output tensor to find any unbacked symints.
         # If there are none, then we rely on the existing checks to validate
         # caching.
@@ -3072,17 +3067,65 @@ class FakeTensorMode(TorchDispatchMode):
 _StoragePointer = object


-def _has_unrepresented_symbols(
-    state: _CacheKeyState, output: Optional[FakeTensor]
-) -> bool:
-    from torch.fx.experimental.symbolic_shapes import _iterate_exprs
-
-    for s in _iterate_exprs(output):
-        for symbol in s.free_symbols:
-            if symbol not in state.known_symbols:
-                return True
-
-    return False
+def _validate_symbolic_output_for_caching(
+    state: _CacheKeyState, output: FakeTensor
+) -> None:
+    """
+    Validate symbolic content in output and raise _BypassDispatchCache if
+    caching should be bypassed.
+
+    Args:
+        state: Cache key state containing known symbols
+        output: Output to validate
+        proxy_mode_active: Whether PROXY dispatch mode is currently active
+
+    Raises: _BypassDispatchCache: If output contains symbolic content that
+        prevents caching
+
+    Details:
+
+    If our output contains any symbols that didn't appear in the input then we
+    need to bypass. Usually this will be unbacked symbols which can't be
+    properly reconstructed but there could be "weird" cases where backed symbols
+    spontaneously appear (from non-input state)?
+
+    If we're proxy (symbol) tracing and the output contains ANY symbols then we
+    need to bypass. The problem is that ProxyTorchDispatchMode relies on SymNode
+    object identity and being able to see the construction of SymNodes.
+
+    We could improve the proxy tracing case in a few ways:
+
+    1. If the output SymNodes are directly copied from inputs then this is
+       actually fine - they're already tracked. This would probably be the
+       biggest bang/buck.
+
+    2. If the output (tensors) are all direct copies of the inputs then this is
+       also fine - since they're inputs they must be tracked. We already compute
+       this we just don't plumb it around enough.
+
+    3. If the output SymNodes are already tracked by the proxy then this is also
+       actually fine - they're properly tracked. This probably wouldn't be
+       common since for most outputs we use torch.empty_strided() and recompute
+       strides.
+
+    4. We could use the proxy to track "how" the SymNodes were computed and when
+       using the cache we could "replay" them properly to teach the proxy how to
+       build them.
+    """
+    from torch.fx.experimental.symbolic_shapes import _iterate_exprs, _iterate_nodes
+
+    is_tracing = torch.fx.experimental.proxy_tensor.get_proxy_mode() is not None
+    if is_tracing:
+        # Check for SymNode types in PROXY mode - this should bypass caching
+        # regardless of whether symbols are known or not
+        for node in _iterate_nodes(output):
+            raise _BypassDispatchCache("Proxy mode with SymNode output")
+    else:
+        # Check for unrepresented symbols in tensor expressions
+        for s in _iterate_exprs(output):
+            for symbol in s.free_symbols:
+                if symbol not in state.known_symbols:
+                    raise _BypassDispatchCache("unrepresented symbol in output")


 # NB: returns fake tensors
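
The new helper above keys its behavior off get_proxy_mode(). A small hedged sketch of that check in isolation; make_fx is used here only as a convenient way to have a proxy mode active and is not something the diff itself depends on:

import torch
from torch.fx.experimental.proxy_tensor import get_proxy_mode, make_fx


def f(x):
    # While make_fx traces f, a ProxyTorchDispatchMode is active, which is the
    # condition under which SymNode-bearing outputs now bypass the cache.
    print("tracing inside f:", get_proxy_mode() is not None)  # expected: True
    return x * 2


print("tracing at top level:", get_proxy_mode() is not None)  # expected: False
graph_module = make_fx(f)(torch.randn(3))
print(graph_module.graph)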
@@ -883,11 +883,16 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
     Raises:
         AssertionError: If the value is of an unsupported type.
     """
+    # This is almost close enough to implement in terms of _iterate_nodes()
+    # except that it needs to handle `list[sympy.Basic]` which _iterate_nodes()
+    # can't handle.
     if isinstance(val, SymTypes):
         # This allow applies to the jagged layout NestedTensor case as
         # nested ints are not symbolic
         if is_symbolic(val):
             yield val.node.expr
+    elif isinstance(val, SymNode):
+        yield val.expr
     elif isinstance(val, sympy.Basic):
         yield val
     elif isinstance(val, (int, float, bool)):
@@ -910,6 +915,28 @@ def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]:
     raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}")


+def _iterate_nodes(val: Any) -> Iterator[SymNode]:
+    """
+    Recursively iterate through a value and yield all SymNodes contained
+    within it.
+    """
+    if isinstance(val, SymNode):
+        yield val
+    elif isinstance(val, py_sym_types):
+        # This allow applies to the jagged layout NestedTensor case as
+        # nested ints are not symbolic
+        if is_symbolic(val):
+            yield val.node
+    elif isinstance(val, (tuple, list, torch.Size)):
+        for s in val:
+            yield from _iterate_nodes(s)
+    elif isinstance(val, torch.Tensor):
+        yield from _iterate_nodes(val.size())
+        if not is_sparse_any(val):
+            yield from _iterate_nodes(val.stride())
+            yield from _iterate_nodes(val.storage_offset())
+
+
 def free_symbols(val: IterateExprs) -> OrderedSet[sympy.Symbol]:
     """
     Recursively collect all free symbols from a value.
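
A hedged sketch of the difference between the two iterators: _iterate_exprs yields sympy expressions (dropping SymNode identity), while the new _iterate_nodes yields the SymNode objects themselves. The fake tensor below has static int sizes, so _iterate_nodes yields nothing; under a dynamic-shape setup the SymInt sizes would contribute their SymNodes:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    ShapeEnv,
    _iterate_exprs,
    _iterate_nodes,
)

mode = FakeTensorMode(shape_env=ShapeEnv())
ft = mode.from_tensor(torch.randn(3, 5))  # static sizes; no symbolic context given

print(list(_iterate_exprs(ft)))  # sympy values for sizes/strides/storage offset
print(list(_iterate_nodes(ft)))  # SymNode objects; empty here because shapes are static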