Consolidate SymDispatchMode into ProxyTensorMode (#132674)

Instead of having a separate context variable for SymDispatchMode, we
now simply delegate to the currently active proxy tensor mode when we
need to trace a SymInt.  We maintain a separate `__sym_dispatch__` magic
method because its calling convention differs from `__torch_dispatch__`.
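
For illustration, here are the two hooks side by side on a mode (a minimal
sketch; `SketchProxyMode` is a made-up name and the bodies are placeholders,
not the real ProxyTorchDispatchMode implementation):

    from torch.utils._python_dispatch import TorchDispatchMode

    class SketchProxyMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # Called for tensor ops with exactly the arguments the caller
            # passed, e.g. (a, b) for torch.ops.aten.add.Tensor(a, b).
            return func(*args, **(kwargs or {}))

        def __sym_dispatch__(self, func, types, args, kwargs):
            # Called for SymInt/SymFloat/SymBool ops, but only when this
            # mode is installed as the proxy infra mode; see the removed
            # _sym_dispatch_mode module's notes below for the argument
            # convention (a + b does not arrive as the user-level (a, b)).
            return func(*args, **kwargs)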

Consolidating the modes in this way means that we can consistently
disable both of them in tandem simply by removing the mode from the
proxy mode infra slot.
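
The practical effect on code that needs to run untraced looks roughly like
this (a sketch, not the exact dynamo call site; `run_untraced` is a made-up
helper):

    from torch.utils._python_dispatch import _disable_current_modes

    def run_untraced(fn, *args):
        # Before this change, callers also had to wrap this in
        # torch.fx.experimental._sym_dispatch_mode.disable_sym_dispatch().
        # Now, clearing the dispatch-mode slots is enough: sym dispatch is
        # just a method on the proxy mode living in the PROXY infra slot.
        with _disable_current_modes():
            return fn(*args)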

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132674
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
Author:    Edward Z. Yang
Date:      2024-08-08 04:59:11 -07:00
Committer: PyTorch MergeBot
Parent:    0f19d4150b
Commit:    361db32d47
9 changed files with 90 additions and 136 deletions

.github/labeler.yml

@@ -29,7 +29,6 @@
 - torch/fx/experimental/recording.py
 - torch/fx/experimental/sym_node.py
 - torch/fx/experimental/validator.py
-- torch/fx/experimental/_sym_dispatch_mode.py
 - torch/fx/experimental/proxy_tensor.py
 - test/distributed/_tensor/test_dtensor_compile.py
 - test/distributed/tensor/parallel/test_fsdp_2d_parallel.py


@@ -850,7 +850,6 @@ coverage_ignore_functions = [
     "get_torch_dispatch_modes",
     "has_proxy_slot",
     "is_sym_node",
-    "make_fx",
     "maybe_disable_fake_tensor_mode",
     "maybe_handle_decomp",
     "proxy_call",


@@ -51,3 +51,17 @@ torch.fx.experimental.symbolic_shapes
     compute_unbacked_bindings
     rebind_unbacked
     resolve_unbacked_bindings
+
+torch.fx.experimental.proxy_tensor
+-------------------------------------
+
+.. currentmodule:: torch.fx.experimental.proxy_tensor
+.. automodule:: torch.fx.experimental.proxy_tensor
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    make_fx
+    handle_sym_dispatch
+    get_proxy_mode


@@ -1143,7 +1143,6 @@ API Reference
 .. py:module:: torch.fx.experimental.normalize
 .. py:module:: torch.fx.experimental.optimization
 .. py:module:: torch.fx.experimental.partitioner_utils
-.. py:module:: torch.fx.experimental.proxy_tensor
 .. py:module:: torch.fx.experimental.recording
 .. py:module:: torch.fx.experimental.refinement_types
 .. py:module:: torch.fx.experimental.rewriter


@@ -112,7 +112,6 @@ class AutogradCompilerInstance:
         # TODO(jansel): are all these modes needed?
         self.stack.enter_context(decompose({}))
         self.stack.enter_context(self.fake_tensor_mode)
-        self.stack.enter_context(self.proxy_mode.sym_mode)
         self.stack.enter_context(self.proxy_mode)
         self.stack.enter_context(disable_autocast_cache())
         self.stack.enter_context(preserve_node_meta())


@@ -25,7 +25,6 @@ from weakref import ReferenceType
 import torch
 import torch._logging
-import torch.fx.experimental._sym_dispatch_mode
 from torch._C._dynamo.guards import GlobalStateGuard
 from torch._dynamo.distributed import get_compile_pg
 from torch._guards import compile_context, CompileContext, CompileId, tracing
@@ -1234,9 +1233,7 @@ class CatchErrorsWrapper:
                 frame, cache_entry, self.hooks, frame_state
             )

-        with (
-            compile_lock
-        ), _disable_current_modes(), torch.fx.experimental._sym_dispatch_mode.disable_sym_dispatch():
+        with compile_lock, _disable_current_modes():
             # skip=1: skip this frame
             return self._torchdynamo_orig_callable(
                 frame, cache_entry, self.hooks, frame_state, skip=1


@@ -1,72 +0,0 @@
-# mypy: allow-untyped-defs
-import contextlib
-from typing import List, Optional, Type
-
-__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"]
-
-SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
-
-
-# SymDispatchMode gets invoked whenever an operation is processed on
-# a PySymInt. When this occurs, you get called at __sym_dispatch__
-# with the operation in question. This is symmetric to TorchDispatchMode
-# but with some caveats:
-#
-# - In TorchDispatchMode, you get the same arguments as what a user
-#   invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
-#   you get (a, b) as args to your call. In SymDispatchMode, if
-#   you call a + b (where a and b are SymInts), you will get
-#   (a.node, b.node) as your args (these are PySymInts)
-#
-# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
-#   So you have to manually call Tracer/create_node to write into
-#   the graph. See ProxySymDispatchMode for an example
-#
-class SymDispatchMode:
-    def __sym_dispatch__(self, func, types, args, kwargs):
-        raise NotImplementedError
-
-    def __enter__(self):
-        global SYM_FUNCTION_MODE
-        old = SYM_FUNCTION_MODE
-        if hasattr(self, "inner"):
-            raise RuntimeError(
-                f"{self} has already been used as a mode. Please use a fresh version"
-            )
-        else:
-            self.inner = old
-            SYM_FUNCTION_MODE = self
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        global SYM_FUNCTION_MODE
-        SYM_FUNCTION_MODE = self.inner
-
-
-def handle_sym_dispatch(func, args, kwargs):
-    global SYM_FUNCTION_MODE
-    mode = sym_function_mode()
-    assert mode
-    SYM_FUNCTION_MODE = mode.inner
-    try:
-        # TODO: properly compute types
-        types: List[Type] = []
-        return mode.__sym_dispatch__(func, types, args, kwargs)
-    finally:
-        SYM_FUNCTION_MODE = mode
-
-
-def sym_function_mode():
-    return SYM_FUNCTION_MODE
-
-
-@contextlib.contextmanager
-def disable_sym_dispatch():
-    global SYM_FUNCTION_MODE
-    old = SYM_FUNCTION_MODE
-    SYM_FUNCTION_MODE = None
-    try:
-        yield
-    finally:
-        SYM_FUNCTION_MODE = old
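
For reference, the lookup that replaces the removed SYM_FUNCTION_MODE global
now goes through the proxy mode slot (a sketch using the proxy_tensor exports
added later in this diff; `trace_sym_op` is a made-up stand-in for the call
sites in sym_node.py):

    from torch.fx.experimental.proxy_tensor import get_proxy_mode, handle_sym_dispatch

    def trace_sym_op(op, args):
        # Old: `if sym_function_mode():` consulted a module-level global.
        # New: look up the active ProxyTorchDispatchMode in the PROXY
        # infra-mode slot; None means we are not tracing.
        if get_proxy_mode() is None:
            return op(*args)
        return handle_sym_dispatch(op, args, {})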


@@ -22,13 +22,13 @@ import warnings
 import weakref

 from ._backward_state import BackwardState
-from ._sym_dispatch_mode import SymDispatchMode
 from .sym_node import SymNode
 from torch.utils._thunk import Thunk
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext, AbstractContextManager, ExitStack
 from dataclasses import dataclass
 from torch import SymInt, SymBool, Tensor
+import torch._ops
 from torch._dispatch.python import enable_python_dispatcher
 from torch._library.fake_class_registry import FakeScriptObject
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
@@ -59,7 +59,10 @@ if TYPE_CHECKING:
     from torch.fx._symbolic_trace import PHBase
     from torch.types import IntLikeType

-__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
+__all__ = [
+    "PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter",
+    "py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch"
+]

 _ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
@@ -1006,7 +1009,10 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):

 class ProxyTorchDispatchMode(TorchDispatchMode):
-    _managers: List[AbstractContextManager]
+    # Ensure this is read-only; this exists only for legacy reasons
+    @property
+    def enable_tracing(self) -> bool:
+        return True

     def __init__(
         self,
@@ -1020,12 +1026,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         super().__init__(dk)
         self.tracer = tracer
         self.tracing_mode = tracing_mode
-        self.enable_tracing = True
         self.pre_dispatch = pre_dispatch
         self._allow_fake_constant = _allow_fake_constant
         self._error_on_data_dependent_ops = _error_on_data_dependent_ops
-        self.sym_mode = ProxySymDispatchMode(tracer)
-        self._managers = []
         # Indicates to our torch_dispatch dispatching infra that
         # this is an "infra" mode with lower dispatching precedence.
         self._mode_key = torch._C._TorchDispatchModeKey.PROXY
@@ -1045,14 +1048,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         args: Tuple[object, ...] = (),
         kwargs: Optional[Dict[str, object]] = None
     ) -> object:
-        with self.sym_mode.enable(False), set_original_aten_op(func):
+        with set_original_aten_op(func):
             return self.inner_torch_dispatch(func, types, args, kwargs)

     def __enter__(self) -> Self:
-        # sym mode first, then us...
-        m = self.sym_mode.enable(True)
-        self._managers.append(m)
-        m.__enter__()
         # Stash and store the previous proxy mode (there may or may not be one)
         maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
         self.enter_stack.append(maybe_prev_proxy_mode)
@@ -1064,8 +1063,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         exc_value: Optional[BaseException],
         traceback: Optional[types.TracebackType]
     ) -> Optional[bool]:
-        m = self._managers.pop()
-        # ...exit us first, then sym mode
         b = super().__exit__(exc_type, exc_value, traceback)

         # Re-enable the previous proxy mode, if there was one.
@@ -1073,11 +1070,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         if mb_previous_proxy_mode is not None:
             _push_mode(mb_previous_proxy_mode)

-        if not b:
-            return m.__exit__(exc_type, exc_value, traceback)
-        else:
-            return m.__exit__(None, None, None)
+        return b

     def inner_torch_dispatch(
         self,
@@ -1088,9 +1081,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
     ) -> object:
         kwargs = kwargs or {}

-        if not self.enable_tracing:
-            return func(*args, **kwargs)
-
         if func in (prim.device.default,):
             return func(*args, **kwargs)
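
With `enable_tracing` now a read-only property, code that used to flip the
flag instead removes the proxy mode from the slot (a sketch;
`run_without_tracing` is a made-up helper around the existing
`disable_proxy_modes_tracing` context manager):

    from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing

    def run_without_tracing(fn, *args):
        # Temporarily takes the proxy mode out of the PROXY infra slot, so
        # neither tensor ops nor sym ops are traced inside the block.
        with disable_proxy_modes_tracing():
            return fn(*args)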
@@ -1100,25 +1090,6 @@
     def is_infra_mode(cls) -> bool:
         return True

-class ProxySymDispatchMode(SymDispatchMode):
-    def __init__(self, tracer: _ProxyTracer) -> None:
-        super().__init__()
-        self.tracer = tracer
-        # When false, we don't trace operations. If you do this, you MUST
-        # call track_tensor/track_tensor_tree on all results of the operation
-        # to ensure we can adequately track the results
-        self.enable_tracing = True
-
-    @contextmanager
-    def enable(self, b: bool) -> Generator[None, None, None]:
-        old = self.enable_tracing
-        self.enable_tracing = b
-        try:
-            yield
-        finally:
-            self.enable_tracing = old
-
     def _compute_proxy(self, func: OpOverload, args: Tuple[object, ...], out: PySymType) -> Proxy:
         n_args = tuple(
             get_proxy_slot(a, self.tracer).force().node if isinstance(a, py_sym_types) else a
@@ -1139,9 +1110,6 @@ class ProxySymDispatchMode(SymDispatchMode):
         args: Tuple[object, ...],
         kwargs: Dict[str, object]
     ) -> object:
-        if not self.enable_tracing:
-            return func(*args, **kwargs)
-
         # Peephole optimize multiply by one
         # NB: be careful not to trigger guards here!
         if func == operator.mul:
@@ -1727,7 +1695,6 @@ class _MakefxTracer:
            stack.enter_context(self.fake_tensor_mode)
            stack.enter_context(self.python_dispatcher_mode)
            stack.enter_context(self.proxy_function_mode)
-           stack.enter_context(proxy_mode.sym_mode)
            stack.enter_context(self.torch_fn_metadata_mode)
            stack.enter_context(proxy_mode)
            stack.enter_context(disable_autocast_cache())
@@ -1787,8 +1754,13 @@ def make_fx(
         _allow_fake_constant: bool = False,
         _error_on_data_dependent_ops: bool = True) -> Callable[..., GraphModule]:
-    assert tracing_mode in ["real", "fake", "symbolic"]
+    """
+    Given a function f, return a new function which when executed with valid
+    arguments to f, returns an FX GraphModule representing the set of operations that
+    were executed during the course of execution.
+    """
+    assert tracing_mode in ["real", "fake", "symbolic"]

     make_fx_tracer = _MakefxTracer(
         decomposition_table,
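
For context, the newly documented make_fx entry point is used like this
(an illustrative example; the traced function and inputs are arbitrary):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x):
        return torch.sin(x) + x

    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
    print(gm.graph)  # the FX graph of the ops executed while tracing f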
@@ -1810,8 +1782,38 @@ def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
     return torch.utils._python_dispatch._get_current_dispatch_mode_stack()


-def get_innermost_proxy_mode() -> ProxyTorchDispatchMode:
-    return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
+# TODO: this is a legacy name, there is only ever one proxy mode as it's an
+# infra mode
+def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
+    return get_proxy_mode()
+
+
+def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
+    """
+    Return the currently active proxy tracing mode, or None if
+    we are not currently tracing.  This includes pre-dispatch proxy
+    tracing.
+    """
+    pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
+    mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
+    assert pre_dispatch_mode is None or mode is None, f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
+    return pre_dispatch_mode or mode
+
+
+def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R:
+    """
+    Call into the currently active proxy tracing mode to do a
+    SymInt/SymFloat/SymBool dispatch trace on a function that operates on
+    these arguments.
+    """
+    mode = get_proxy_mode()
+    assert mode
+    # Have to do it manually, because we're not doing the normal torch
+    # dispatch machinery which disables it for us
+    with disable_proxy_modes_tracing():
+        # TODO: properly compute types
+        types: List[Type] = []
+        return mode.__sym_dispatch__(func, types, args, kwargs)  # type: ignore[arg-type, return-value]


 @contextmanager


@@ -31,10 +31,6 @@ from torch import ( # noqa: F401
     SymFloat,
     SymInt,
 )
-from torch.fx.experimental._sym_dispatch_mode import (
-    handle_sym_dispatch,
-    sym_function_mode,
-)

 if TYPE_CHECKING:
@@ -1055,6 +1051,10 @@ def _make_node_magic(method, func):
         method_attr = method

     def binary_magic_impl(self, other):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
         from torch.fx.experimental.symbolic_shapes import safe_expand

         op = method_to_operator(method)
@@ -1067,7 +1067,7 @@ def _make_node_magic(method, func):
         if alternate_impl and out_hint is not None:
             return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))

-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(
                 self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
             )
@@ -1129,10 +1129,14 @@ def _make_node_magic(method, func):
         return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

     def unary_magic_impl(self):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
         from torch.fx.experimental.symbolic_shapes import safe_expand

         op = method_to_operator(method)
-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
         # TODO: consider constant prop here
         expr = self.expr
@@ -1167,10 +1171,14 @@ def _make_node_magic(method, func):
     elif method == "sym_ite":

         def sym_ite_impl(pred_node, then_node, else_node):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
             from torch.fx.experimental.symbolic_shapes import safe_expand

             out_hint = then_node.hint if pred_node.hint else else_node.hint
-            if sym_function_mode():
+            if get_proxy_mode():
                 return to_node(
                     pred_node,
                     handle_sym_dispatch(
@@ -1208,10 +1216,14 @@ def _make_node_magic(method, func):
     elif method == "round":

         def round_impl(self, ndigits=None):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
             from torch.fx.experimental.symbolic_shapes import safe_expand

             op = builtins.round
-            if sym_function_mode():
+            if get_proxy_mode():
                 return to_node(
                     self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                 )
@@ -1256,8 +1268,13 @@ def _make_node_sizes_strides(method, func):
     # NB: don't LRU cache, lots of arguments

     def sizes_strides_impl(self, sizes, strides):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
+
         op = getattr(sys.modules[__name__], method)
-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(
                 self,
                 handle_sym_dispatch(