Consolidate SymDispatchMode into ProxyTensorMode (#132674)
Instead of having a separate context variable for SymDispatchMode, we now simply delegate to the currently active proxy tensor mode when we need to trace a SymInt. We keep a separate `__sym_dispatch__` magic method because its calling convention differs from `__torch_dispatch__`. Consolidating the modes in this way means we can consistently disable both of them in tandem simply by removing the mode from the proxy mode infra slot.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132674
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
Committed by: PyTorch MergeBot
Parent: 0f19d4150b
Commit: 361db32d47
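As a rough illustration of the consolidation described in the commit message, here is a minimal, self-contained sketch (toy names only, not the actual torch.fx.experimental.proxy_tensor implementation): SymInt tracing asks the currently active proxy tensor mode through a dedicated `__sym_dispatch__` hook rather than a separate SymDispatchMode context variable, so clearing the single mode slot disables both tensor and symbolic dispatch together.

# Illustrative sketch only: toy stand-ins, not the real PyTorch code.
class ToyProxyMode:
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Tensor ops are traced through the usual __torch_dispatch__ convention.
        return ("traced tensor op", func)

    def __sym_dispatch__(self, func, types, args, kwargs):
        # SymInt/SymFloat ops land here; the calling convention differs from
        # __torch_dispatch__ (operands are wrapped symbolic nodes, not Tensors).
        return ("traced sym op", func)

_proxy_mode_slot = None  # stand-in for the proxy mode infra slot

def get_proxy_mode():
    return _proxy_mode_slot

def handle_sym_dispatch(func, args, kwargs):
    # Delegate to whichever proxy mode is currently installed; when the slot is
    # empty, symbolic dispatch is disabled together with tensor dispatch.
    mode = get_proxy_mode()
    assert mode is not None
    return mode.__sym_dispatch__(func, (), args, kwargs)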
@@ -31,10 +31,6 @@ from torch import ( # noqa: F401
     SymFloat,
     SymInt,
 )
-from torch.fx.experimental._sym_dispatch_mode import (
-    handle_sym_dispatch,
-    sym_function_mode,
-)
 
 
 if TYPE_CHECKING:
@@ -1055,6 +1051,10 @@ def _make_node_magic(method, func):
         method_attr = method
 
     def binary_magic_impl(self, other):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
         from torch.fx.experimental.symbolic_shapes import safe_expand
 
         op = method_to_operator(method)
@@ -1067,7 +1067,7 @@ def _make_node_magic(method, func):
         if alternate_impl and out_hint is not None:
             return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
 
-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(
                 self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
             )
@@ -1129,10 +1129,14 @@ def _make_node_magic(method, func):
         return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
 
     def unary_magic_impl(self):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
         from torch.fx.experimental.symbolic_shapes import safe_expand
 
         op = method_to_operator(method)
-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
         # TODO: consider constant prop here
         expr = self.expr
@@ -1167,10 +1171,14 @@ def _make_node_magic(method, func):
     elif method == "sym_ite":
 
         def sym_ite_impl(pred_node, then_node, else_node):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
             from torch.fx.experimental.symbolic_shapes import safe_expand
 
             out_hint = then_node.hint if pred_node.hint else else_node.hint
-            if sym_function_mode():
+            if get_proxy_mode():
                 return to_node(
                     pred_node,
                     handle_sym_dispatch(
@@ -1208,10 +1216,14 @@ def _make_node_magic(method, func):
     elif method == "round":
 
         def round_impl(self, ndigits=None):
+            from torch.fx.experimental.proxy_tensor import (
+                get_proxy_mode,
+                handle_sym_dispatch,
+            )
             from torch.fx.experimental.symbolic_shapes import safe_expand
 
             op = builtins.round
-            if sym_function_mode():
+            if get_proxy_mode():
                 return to_node(
                     self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                 )
@@ -1256,8 +1268,13 @@ def _make_node_sizes_strides(method, func):
     # NB: don't LRU cache, lots of arguments
 
     def sizes_strides_impl(self, sizes, strides):
+        from torch.fx.experimental.proxy_tensor import (
+            get_proxy_mode,
+            handle_sym_dispatch,
+        )
+
         op = getattr(sys.modules[__name__], method)
-        if sym_function_mode():
+        if get_proxy_mode():
             return to_node(
                 self,
                 handle_sym_dispatch(
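For reference, a small end-to-end example of the path these hunks touch (example only; it assumes a PyTorch build that already includes this change). Tracing a function under symbolic shapes routes SymInt arithmetic through the active proxy tensor mode's `__sym_dispatch__`, so the size computation is recorded as nodes in the traced graph.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # Under tracing_mode="symbolic", x.shape[0] is a SymInt; multiplying it is
    # handled via sym dispatch on the active proxy tensor mode.
    return x.new_zeros(x.shape[0] * 2)

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
print(gm.graph)  # expect symbolic size and multiply nodes feeding new_zeros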