Consolidate SymDispatchMode into ProxyTensorMode (#132674)

Instead of having a separate context variable for SymDispatchMode, we
now simply delegate to the current active proxy tensor mode when we
need to trace a SymInt.  We maintain a separate `__sym_dispatch__` magic
method as the calling convention is different than `__torch_dispatch__`.

Consolidating the modes in this ways means that we can consistently
disable both of these modes in tandem simply by removing the mode
from the proxy mode infra slot.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132674
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
This commit is contained in:
Edward Z. Yang
2024-08-08 04:59:11 -07:00
committed by PyTorch MergeBot
parent 0f19d4150b
commit 361db32d47
9 changed files with 90 additions and 136 deletions

View File

@ -31,10 +31,6 @@ from torch import ( # noqa: F401
SymFloat,
SymInt,
)
from torch.fx.experimental._sym_dispatch_mode import (
handle_sym_dispatch,
sym_function_mode,
)
if TYPE_CHECKING:
@ -1055,6 +1051,10 @@ def _make_node_magic(method, func):
method_attr = method
def binary_magic_impl(self, other):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
op = method_to_operator(method)
@ -1067,7 +1067,7 @@ def _make_node_magic(method, func):
if alternate_impl and out_hint is not None:
return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
if sym_function_mode():
if get_proxy_mode():
return to_node(
self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
)
@ -1129,10 +1129,14 @@ def _make_node_magic(method, func):
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
def unary_magic_impl(self):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
op = method_to_operator(method)
if sym_function_mode():
if get_proxy_mode():
return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
# TODO: consider constant prop here
expr = self.expr
@ -1167,10 +1171,14 @@ def _make_node_magic(method, func):
elif method == "sym_ite":
def sym_ite_impl(pred_node, then_node, else_node):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
out_hint = then_node.hint if pred_node.hint else else_node.hint
if sym_function_mode():
if get_proxy_mode():
return to_node(
pred_node,
handle_sym_dispatch(
@ -1208,10 +1216,14 @@ def _make_node_magic(method, func):
elif method == "round":
def round_impl(self, ndigits=None):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
from torch.fx.experimental.symbolic_shapes import safe_expand
op = builtins.round
if sym_function_mode():
if get_proxy_mode():
return to_node(
self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
)
@ -1256,8 +1268,13 @@ def _make_node_sizes_strides(method, func):
# NB: don't LRU cache, lots of arguments
def sizes_strides_impl(self, sizes, strides):
from torch.fx.experimental.proxy_tensor import (
get_proxy_mode,
handle_sym_dispatch,
)
op = getattr(sys.modules[__name__], method)
if sym_function_mode():
if get_proxy_mode():
return to_node(
self,
handle_sym_dispatch(