diff --git a/.github/labeler.yml b/.github/labeler.yml
index f436ec684ffb..c6b6cc8118b4 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -29,7 +29,6 @@
 - torch/fx/experimental/recording.py
 - torch/fx/experimental/sym_node.py
 - torch/fx/experimental/validator.py
-- torch/fx/experimental/_sym_dispatch_mode.py
 - torch/fx/experimental/proxy_tensor.py
 - test/distributed/_tensor/test_dtensor_compile.py
 - test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 9305bf96a8c1..83930344daf8 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -850,7 +850,6 @@ coverage_ignore_functions = [
     "get_torch_dispatch_modes",
     "has_proxy_slot",
     "is_sym_node",
-    "make_fx",
     "maybe_disable_fake_tensor_mode",
     "maybe_handle_decomp",
     "proxy_call",
diff --git a/docs/source/fx.experimental.rst b/docs/source/fx.experimental.rst
index d6885eb41ca0..e3496df82716 100644
--- a/docs/source/fx.experimental.rst
+++ b/docs/source/fx.experimental.rst
@@ -51,3 +51,17 @@ torch.fx.experimental.symbolic_shapes
     compute_unbacked_bindings
     rebind_unbacked
     resolve_unbacked_bindings
+
+torch.fx.experimental.proxy_tensor
+----------------------------------
+
+.. currentmodule:: torch.fx.experimental.proxy_tensor
+.. automodule:: torch.fx.experimental.proxy_tensor
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    make_fx
+    handle_sym_dispatch
+    get_proxy_mode
diff --git a/docs/source/fx.rst b/docs/source/fx.rst
index ac1f2349a955..eb0e145c4bb6 100644
--- a/docs/source/fx.rst
+++ b/docs/source/fx.rst
@@ -1143,7 +1143,6 @@ API Reference
 .. py:module:: torch.fx.experimental.normalize
 .. py:module:: torch.fx.experimental.optimization
 .. py:module:: torch.fx.experimental.partitioner_utils
-.. py:module:: torch.fx.experimental.proxy_tensor
 .. py:module:: torch.fx.experimental.recording
 .. py:module:: torch.fx.experimental.refinement_types
 .. py:module:: torch.fx.experimental.rewriter
diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py
index 199aa5717d86..b1397d9a1c09 100644
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -112,7 +112,6 @@ class AutogradCompilerInstance:
         # TODO(jansel): are all these modes needed?
         self.stack.enter_context(decompose({}))
         self.stack.enter_context(self.fake_tensor_mode)
-        self.stack.enter_context(self.proxy_mode.sym_mode)
         self.stack.enter_context(self.proxy_mode)
         self.stack.enter_context(disable_autocast_cache())
         self.stack.enter_context(preserve_node_meta())
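
The compiled_autograd hunk above drops `proxy_mode.sym_mode` from the mode stack: with sym dispatch folded into `ProxyTorchDispatchMode`, a single `enter_context(self.proxy_mode)` now covers both. A minimal sketch of the `ExitStack` semantics this relies on, using hypothetical mode names (not the PyTorch modes themselves):

```python
from contextlib import ExitStack, contextmanager

@contextmanager
def mode(name):
    # Stand-in for fake_tensor_mode / proxy_mode in the hunk above.
    print("enter", name)
    try:
        yield
    finally:
        print("exit", name)

stack = ExitStack()
stack.enter_context(mode("fake_tensor"))
stack.enter_context(mode("proxy"))  # sym dispatch no longer needs its own entry
stack.close()  # unwinds in reverse order: proxy exits before fake_tensor
```
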
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 74c82aad3944..4785b65cb28b 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -25,7 +25,6 @@ from weakref import ReferenceType

 import torch
 import torch._logging
-import torch.fx.experimental._sym_dispatch_mode
 from torch._C._dynamo.guards import GlobalStateGuard
 from torch._dynamo.distributed import get_compile_pg
 from torch._guards import compile_context, CompileContext, CompileId, tracing
@@ -1234,9 +1233,7 @@ class CatchErrorsWrapper:
                 frame, cache_entry, self.hooks, frame_state
             )

-        with (
-            compile_lock
-        ), _disable_current_modes(), torch.fx.experimental._sym_dispatch_mode.disable_sym_dispatch():
+        with compile_lock, _disable_current_modes():
             # skip=1: skip this frame
             return self._torchdynamo_orig_callable(
                 frame, cache_entry, self.hooks, frame_state, skip=1
diff --git a/torch/fx/experimental/_sym_dispatch_mode.py b/torch/fx/experimental/_sym_dispatch_mode.py
deleted file mode 100644
index b5fd2d2caf6c..000000000000
--- a/torch/fx/experimental/_sym_dispatch_mode.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# mypy: allow-untyped-defs
-import contextlib
-from typing import List, Optional, Type
-
-
-__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"]
-
-SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
-
-
-# SymDispatchMode gets invoked whenever an operation is processed on
-# a PySymInt. When this occurs, you get called at __sym_dispatch__
-# with the operation in question. This is symmetric to TorchDispatchMode
-# but with some caveats:
-#
-# - In TorchDispatchMode, you get the same arguments as what a user
-#   invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
-#   you get (a, b) as args to your call. In SymDispatchMode, if
-#   you call a + b (where a and b are SymInts), you will get
-#   (a.node, b.node) as your args (these are PySymInts)
-#
-# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
-#   So you have to manually call Tracer/create_node to write into
-#   the graph. See ProxySymDispatchMode for an example
-#
-class SymDispatchMode:
-    def __sym_dispatch__(self, func, types, args, kwargs):
-        raise NotImplementedError
-
-    def __enter__(self):
-        global SYM_FUNCTION_MODE
-        old = SYM_FUNCTION_MODE
-        if hasattr(self, "inner"):
-            raise RuntimeError(
-                f"{self} has already been used as a mode. Please use a fresh version"
-            )
-        else:
-            self.inner = old
-        SYM_FUNCTION_MODE = self
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        global SYM_FUNCTION_MODE
-        SYM_FUNCTION_MODE = self.inner
-
-
-def handle_sym_dispatch(func, args, kwargs):
-    global SYM_FUNCTION_MODE
-    mode = sym_function_mode()
-    assert mode
-    SYM_FUNCTION_MODE = mode.inner
-    try:
-        # TODO: properly compute types
-        types: List[Type] = []
-        return mode.__sym_dispatch__(func, types, args, kwargs)
-    finally:
-        SYM_FUNCTION_MODE = mode
-
-
-def sym_function_mode():
-    return SYM_FUNCTION_MODE
-
-
-@contextlib.contextmanager
-def disable_sym_dispatch():
-    global SYM_FUNCTION_MODE
-    old = SYM_FUNCTION_MODE
-    SYM_FUNCTION_MODE = None
-    try:
-        yield
-    finally:
-        SYM_FUNCTION_MODE = old
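
The `__sym_dispatch__` protocol described in the deleted comment above survives; only its owner changes, from a standalone global mode to the proxy mode (see the `proxy_tensor.py` hunks below). A minimal sketch of the behavior that comment documents, driven through the public `make_fx` entry point:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # Under tracing_mode="symbolic", x.shape[0] is a SymInt; the multiply is
    # routed through __sym_dispatch__ (which receives the underlying SymNodes,
    # per the deleted comment above) and recorded in the graph.
    return x.new_zeros(x.shape[0] * 2)

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
print(gm.graph)  # expect sym_size and mul nodes alongside new_zeros
```
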
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index aaffbdff3f41..ce8625716798 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -22,13 +22,13 @@ import warnings
 import weakref
 from ._backward_state import BackwardState
-from ._sym_dispatch_mode import SymDispatchMode
 from .sym_node import SymNode
 from torch.utils._thunk import Thunk
 from collections import defaultdict
 from contextlib import contextmanager, nullcontext, AbstractContextManager, ExitStack
 from dataclasses import dataclass
 from torch import SymInt, SymBool, Tensor
+import torch._ops
 from torch._dispatch.python import enable_python_dispatcher
 from torch._library.fake_class_registry import FakeScriptObject
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
@@ -59,7 +59,10 @@ if TYPE_CHECKING:
     from torch.fx._symbolic_trace import PHBase
     from torch.types import IntLikeType

-__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
+__all__ = [
+    "PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter",
+    "py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch"
+]

 _ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]

@@ -1006,7 +1009,10 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):

 class ProxyTorchDispatchMode(TorchDispatchMode):
-    _managers: List[AbstractContextManager]
+    # Ensure this is read-only; this exists only for legacy reasons
+    @property
+    def enable_tracing(self) -> bool:
+        return True

     def __init__(
         self,
@@ -1020,12 +1026,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         super().__init__(dk)
         self.tracer = tracer
         self.tracing_mode = tracing_mode
-        self.enable_tracing = True
         self.pre_dispatch = pre_dispatch
         self._allow_fake_constant = _allow_fake_constant
         self._error_on_data_dependent_ops = _error_on_data_dependent_ops
-        self.sym_mode = ProxySymDispatchMode(tracer)
-        self._managers = []
         # Indicates to our torch_dispatch dispatching infra that
         # this is an "infra" mode with lower dispatching precedence.
         self._mode_key = torch._C._TorchDispatchModeKey.PROXY
@@ -1045,14 +1048,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         args: Tuple[object, ...] = (),
         kwargs: Optional[Dict[str, object]] = None
     ) -> object:
-        with self.sym_mode.enable(False), set_original_aten_op(func):
+        with set_original_aten_op(func):
             return self.inner_torch_dispatch(func, types, args, kwargs)

     def __enter__(self) -> Self:
-        # sym mode first, then us...
-        m = self.sym_mode.enable(True)
-        self._managers.append(m)
-        m.__enter__()
         # Stash and store the previous proxy mode (there may or may not be one)
         maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
         self.enter_stack.append(maybe_prev_proxy_mode)
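
Turning `enable_tracing` into a property (above) keeps legacy readers of the flag working while making accidental writes fail loudly, since the old mutable attribute is gone. A minimal sketch of that pattern in isolation (hypothetical class, not the PyTorch one):

```python
class Mode:
    # Read-only compatibility shim: old code may still read the flag.
    @property
    def enable_tracing(self) -> bool:
        return True

m = Mode()
assert m.enable_tracing  # reads keep working
try:
    m.enable_tracing = False  # writes now raise: property has no setter
except AttributeError as e:
    print(e)
```
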
@@ -1064,8 +1063,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         exc_value: Optional[BaseException],
         traceback: Optional[types.TracebackType]
     ) -> Optional[bool]:
-        m = self._managers.pop()
-        # ...exit us first, then sym mode
         b = super().__exit__(exc_type, exc_value, traceback)

         # Re-enable the previous proxy mode, if there was one.
@@ -1073,11 +1070,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
         if mb_previous_proxy_mode is not None:
             _push_mode(mb_previous_proxy_mode)

-        if not b:
-            return m.__exit__(exc_type, exc_value, traceback)
-        else:
-            return m.__exit__(None, None, None)
-
+        return b

     def inner_torch_dispatch(
         self,
@@ -1088,9 +1081,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
     ) -> object:
         kwargs = kwargs or {}

-        if not self.enable_tracing:
-            return func(*args, **kwargs)
-
         if func in (prim.device.default,):
             return func(*args, **kwargs)

@@ -1100,25 +1090,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
     def is_infra_mode(cls) -> bool:
         return True

-
-class ProxySymDispatchMode(SymDispatchMode):
-    def __init__(self, tracer: _ProxyTracer) -> None:
-        super().__init__()
-        self.tracer = tracer
-        # When false, we don't trace operations. If you do this, you MUST
-        # call track_tensor/track_tensor_tree on all results of the operation
-        # to ensure we can adequately track the results
-        self.enable_tracing = True
-
-    @contextmanager
-    def enable(self, b: bool) -> Generator[None, None, None]:
-        old = self.enable_tracing
-        self.enable_tracing = b
-        try:
-            yield
-        finally:
-            self.enable_tracing = old
-
     def _compute_proxy(self, func: OpOverload, args: Tuple[object, ...], out: PySymType) -> Proxy:
         n_args = tuple(
             get_proxy_slot(a, self.tracer).force().node if isinstance(a, py_sym_types) else a
@@ -1139,9 +1110,6 @@ class ProxySymDispatchMode(SymDispatchMode):
         args: Tuple[object, ...],
         kwargs: Dict[str, object]
     ) -> object:
-        if not self.enable_tracing:
-            return func(*args, **kwargs)
-
         # Peephole optimize multiply by one
         # NB: be careful not to trigger guards here!
         if func == operator.mul:
@@ -1727,7 +1695,6 @@ class _MakefxTracer:
             stack.enter_context(self.fake_tensor_mode)
             stack.enter_context(self.python_dispatcher_mode)
             stack.enter_context(self.proxy_function_mode)
-            stack.enter_context(proxy_mode.sym_mode)
             stack.enter_context(self.torch_fn_metadata_mode)
             stack.enter_context(proxy_mode)
             stack.enter_context(disable_autocast_cache())
@@ -1787,8 +1754,13 @@ def make_fx(
         _allow_fake_constant: bool = False,
         _error_on_data_dependent_ops: bool = True) -> Callable[..., GraphModule]:

-    assert tracing_mode in ["real", "fake", "symbolic"]
+    """
+    Given a function f, return a new function which when executed with valid
+    arguments to f, returns an FX GraphModule representing the set of operations that
+    were executed during the course of execution.
+    """
+    assert tracing_mode in ["real", "fake", "symbolic"]

     make_fx_tracer = _MakefxTracer(
         decomposition_table,
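
The docstring and assert above name the three accepted tracing modes. As a usage sketch (not part of the patch): `"real"` traces with the concrete input tensors, `"fake"` substitutes FakeTensors with static shapes, and `"symbolic"` additionally makes sizes SymInts.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return x * x.shape[0]

gm_real = make_fx(f, tracing_mode="real")(torch.randn(2))     # real tensors
gm_fake = make_fx(f, tracing_mode="fake")(torch.randn(2))     # FakeTensors, static shapes
gm_sym = make_fx(f, tracing_mode="symbolic")(torch.randn(2))  # FakeTensors + SymInt sizes
print(gm_sym.graph)
```
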
+ """ + assert tracing_mode in ["real", "fake", "symbolic"] make_fx_tracer = _MakefxTracer( decomposition_table, @@ -1810,8 +1782,38 @@ def get_torch_dispatch_modes() -> List[TorchDispatchMode]: return torch.utils._python_dispatch._get_current_dispatch_mode_stack() -def get_innermost_proxy_mode() -> ProxyTorchDispatchMode: - return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) +# TODO: this is a legacy name, there is only ever one proxy mode as it's an +# infra mode +def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + return get_proxy_mode() + + +def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + """ + Current the currently active proxy tracing mode, or None if + we are not currently tracing. This includes pre-dispatch proxy + tracing. + """ + pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY) + mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + assert pre_dispatch_mode is None or mode is None, f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" + return pre_dispatch_mode or mode + + +def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R: + """ + Call into the currently active proxy tracing mode to do a + SymInt/SymFloat/SymBool dispatch trace on a function that operates on + these arguments. + """ + mode = get_proxy_mode() + assert mode + # Have to do it manually, because we're not doing the normal torch + # dispatch machinery which disables it for us + with disable_proxy_modes_tracing(): + # TODO: properly compute types + types: List[Type] = [] + return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value] @contextmanager diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 08ff6d3d9b16..fc5f7442a89a 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -31,10 +31,6 @@ from torch import ( # noqa: F401 SymFloat, SymInt, ) -from torch.fx.experimental._sym_dispatch_mode import ( - handle_sym_dispatch, - sym_function_mode, -) if TYPE_CHECKING: @@ -1055,6 +1051,10 @@ def _make_node_magic(method, func): method_attr = method def binary_magic_impl(self, other): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) from torch.fx.experimental.symbolic_shapes import safe_expand op = method_to_operator(method) @@ -1067,7 +1067,7 @@ def _make_node_magic(method, func): if alternate_impl and out_hint is not None: return to_node(self, alternate_impl(wrap_node(self), wrap_node(other))) - if sym_function_mode(): + if get_proxy_mode(): return to_node( self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) ) @@ -1129,10 +1129,14 @@ def _make_node_magic(method, func): return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) def unary_magic_impl(self): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) from torch.fx.experimental.symbolic_shapes import safe_expand op = method_to_operator(method) - if sym_function_mode(): + if get_proxy_mode(): return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) # TODO: consider constant prop here expr = self.expr @@ -1167,10 +1171,14 @@ def _make_node_magic(method, func): elif method == "sym_ite": def sym_ite_impl(pred_node, then_node, else_node): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) from torch.fx.experimental.symbolic_shapes 
@@ -1167,10 +1171,14 @@ def _make_node_magic(method, func):
     elif method == "sym_ite":

         def sym_ite_impl(pred_node, then_node, else_node):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
             from torch.fx.experimental.symbolic_shapes import safe_expand

             out_hint = then_node.hint if pred_node.hint else else_node.hint
-            if sym_function_mode():
+            if get_proxy_mode():
                 return to_node(
                     pred_node,
                     handle_sym_dispatch(
@@ -1208,10 +1216,14 @@ def _make_node_magic(method, func):
     elif method == "round":

         def round_impl(self, ndigits=None):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
             from torch.fx.experimental.symbolic_shapes import safe_expand

             op = builtins.round
-            if sym_function_mode():
+            if get_proxy_mode():
                 return to_node(
                     self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                 )
@@ -1256,8 +1268,13 @@ def _make_node_sizes_strides(method, func):
     # NB: don't LRU cache, lots of arguments

     def sizes_strides_impl(self, sizes, strides):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
+
         op = getattr(sys.modules[__name__], method)
-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(
                 self,
                 handle_sym_dispatch(