mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-26 08:34:52 +08:00
Consolidate SymDispatchMode into ProxyTensorMode (#132674)
Instead of having a separate context variable for SymDispatchMode, we now simply delegate to the current active proxy tensor mode when we need to trace a SymInt. We maintain a separate `__sym_dispatch__` magic method as the calling convention is different than `__torch_dispatch__`. Consolidating the modes in this ways means that we can consistently disable both of these modes in tandem simply by removing the mode from the proxy mode infra slot. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/132674 Approved by: https://github.com/zou3519, https://github.com/bdhirsh
This commit is contained in:
committed by
PyTorch MergeBot
parent
0f19d4150b
commit
361db32d47
1
.github/labeler.yml
vendored
1
.github/labeler.yml
vendored
@ -29,7 +29,6 @@
|
|||||||
- torch/fx/experimental/recording.py
|
- torch/fx/experimental/recording.py
|
||||||
- torch/fx/experimental/sym_node.py
|
- torch/fx/experimental/sym_node.py
|
||||||
- torch/fx/experimental/validator.py
|
- torch/fx/experimental/validator.py
|
||||||
- torch/fx/experimental/_sym_dispatch_mode.py
|
|
||||||
- torch/fx/experimental/proxy_tensor.py
|
- torch/fx/experimental/proxy_tensor.py
|
||||||
- test/distributed/_tensor/test_dtensor_compile.py
|
- test/distributed/_tensor/test_dtensor_compile.py
|
||||||
- test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
|
- test/distributed/tensor/parallel/test_fsdp_2d_parallel.py
|
||||||
|
|||||||
@ -850,7 +850,6 @@ coverage_ignore_functions = [
|
|||||||
"get_torch_dispatch_modes",
|
"get_torch_dispatch_modes",
|
||||||
"has_proxy_slot",
|
"has_proxy_slot",
|
||||||
"is_sym_node",
|
"is_sym_node",
|
||||||
"make_fx",
|
|
||||||
"maybe_disable_fake_tensor_mode",
|
"maybe_disable_fake_tensor_mode",
|
||||||
"maybe_handle_decomp",
|
"maybe_handle_decomp",
|
||||||
"proxy_call",
|
"proxy_call",
|
||||||
|
|||||||
@ -51,3 +51,17 @@ torch.fx.experimental.symbolic_shapes
|
|||||||
compute_unbacked_bindings
|
compute_unbacked_bindings
|
||||||
rebind_unbacked
|
rebind_unbacked
|
||||||
resolve_unbacked_bindings
|
resolve_unbacked_bindings
|
||||||
|
|
||||||
|
torch.fx.experimental.proxy_tensor
|
||||||
|
-------------------------------------
|
||||||
|
|
||||||
|
.. currentmodule:: torch.fx.experimental.proxy_tensor
|
||||||
|
.. automodule:: torch.fx.experimental.proxy_tensor
|
||||||
|
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
|
||||||
|
make_fx
|
||||||
|
handle_sym_dispatch
|
||||||
|
get_proxy_mode
|
||||||
|
|||||||
@ -1143,7 +1143,6 @@ API Reference
|
|||||||
.. py:module:: torch.fx.experimental.normalize
|
.. py:module:: torch.fx.experimental.normalize
|
||||||
.. py:module:: torch.fx.experimental.optimization
|
.. py:module:: torch.fx.experimental.optimization
|
||||||
.. py:module:: torch.fx.experimental.partitioner_utils
|
.. py:module:: torch.fx.experimental.partitioner_utils
|
||||||
.. py:module:: torch.fx.experimental.proxy_tensor
|
|
||||||
.. py:module:: torch.fx.experimental.recording
|
.. py:module:: torch.fx.experimental.recording
|
||||||
.. py:module:: torch.fx.experimental.refinement_types
|
.. py:module:: torch.fx.experimental.refinement_types
|
||||||
.. py:module:: torch.fx.experimental.rewriter
|
.. py:module:: torch.fx.experimental.rewriter
|
||||||
|
|||||||
@ -112,7 +112,6 @@ class AutogradCompilerInstance:
|
|||||||
# TODO(jansel): are all these modes needed?
|
# TODO(jansel): are all these modes needed?
|
||||||
self.stack.enter_context(decompose({}))
|
self.stack.enter_context(decompose({}))
|
||||||
self.stack.enter_context(self.fake_tensor_mode)
|
self.stack.enter_context(self.fake_tensor_mode)
|
||||||
self.stack.enter_context(self.proxy_mode.sym_mode)
|
|
||||||
self.stack.enter_context(self.proxy_mode)
|
self.stack.enter_context(self.proxy_mode)
|
||||||
self.stack.enter_context(disable_autocast_cache())
|
self.stack.enter_context(disable_autocast_cache())
|
||||||
self.stack.enter_context(preserve_node_meta())
|
self.stack.enter_context(preserve_node_meta())
|
||||||
|
|||||||
@ -25,7 +25,6 @@ from weakref import ReferenceType
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._logging
|
import torch._logging
|
||||||
import torch.fx.experimental._sym_dispatch_mode
|
|
||||||
from torch._C._dynamo.guards import GlobalStateGuard
|
from torch._C._dynamo.guards import GlobalStateGuard
|
||||||
from torch._dynamo.distributed import get_compile_pg
|
from torch._dynamo.distributed import get_compile_pg
|
||||||
from torch._guards import compile_context, CompileContext, CompileId, tracing
|
from torch._guards import compile_context, CompileContext, CompileId, tracing
|
||||||
@ -1234,9 +1233,7 @@ class CatchErrorsWrapper:
|
|||||||
frame, cache_entry, self.hooks, frame_state
|
frame, cache_entry, self.hooks, frame_state
|
||||||
)
|
)
|
||||||
|
|
||||||
with (
|
with compile_lock, _disable_current_modes():
|
||||||
compile_lock
|
|
||||||
), _disable_current_modes(), torch.fx.experimental._sym_dispatch_mode.disable_sym_dispatch():
|
|
||||||
# skip=1: skip this frame
|
# skip=1: skip this frame
|
||||||
return self._torchdynamo_orig_callable(
|
return self._torchdynamo_orig_callable(
|
||||||
frame, cache_entry, self.hooks, frame_state, skip=1
|
frame, cache_entry, self.hooks, frame_state, skip=1
|
||||||
|
|||||||
@ -1,72 +0,0 @@
|
|||||||
# mypy: allow-untyped-defs
|
|
||||||
import contextlib
|
|
||||||
from typing import List, Optional, Type
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["SymDispatchMode", "handle_sym_dispatch", "sym_function_mode"]
|
|
||||||
|
|
||||||
SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
|
|
||||||
|
|
||||||
|
|
||||||
# SymDispatchMode gets invoked whenever an operation is processed on
|
|
||||||
# a PySymInt. When this occurs, you get called at __sym_dispatch__
|
|
||||||
# with the operation in question. This is symmetric to TorchDispatchMode
|
|
||||||
# but with some caveats:
|
|
||||||
#
|
|
||||||
# - In TorchDispatchMode, you get the same arguments as what a user
|
|
||||||
# invoked your API with; e.g., if you call torch.ops.aten.foo(a, b),
|
|
||||||
# you get (a, b) as args to your call. In SymDispatchMode, if
|
|
||||||
# you call a + b (where a and b are SymInts), you will get
|
|
||||||
# (a.node, b.node) as your args (these are PySymInts)
|
|
||||||
#
|
|
||||||
# - SymInt/PySymInt don't have FX proxy support (unlike, e.g., Tensor).
|
|
||||||
# So you have to manually call Tracer/create_node to write into
|
|
||||||
# the graph. See ProxySymDispatchMode for an example
|
|
||||||
#
|
|
||||||
class SymDispatchMode:
|
|
||||||
def __sym_dispatch__(self, func, types, args, kwargs):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
global SYM_FUNCTION_MODE
|
|
||||||
old = SYM_FUNCTION_MODE
|
|
||||||
if hasattr(self, "inner"):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"{self} has already been used as a mode. Please use a fresh version"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.inner = old
|
|
||||||
SYM_FUNCTION_MODE = self
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
global SYM_FUNCTION_MODE
|
|
||||||
SYM_FUNCTION_MODE = self.inner
|
|
||||||
|
|
||||||
|
|
||||||
def handle_sym_dispatch(func, args, kwargs):
|
|
||||||
global SYM_FUNCTION_MODE
|
|
||||||
mode = sym_function_mode()
|
|
||||||
assert mode
|
|
||||||
SYM_FUNCTION_MODE = mode.inner
|
|
||||||
try:
|
|
||||||
# TODO: properly compute types
|
|
||||||
types: List[Type] = []
|
|
||||||
return mode.__sym_dispatch__(func, types, args, kwargs)
|
|
||||||
finally:
|
|
||||||
SYM_FUNCTION_MODE = mode
|
|
||||||
|
|
||||||
|
|
||||||
def sym_function_mode():
|
|
||||||
return SYM_FUNCTION_MODE
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def disable_sym_dispatch():
|
|
||||||
global SYM_FUNCTION_MODE
|
|
||||||
old = SYM_FUNCTION_MODE
|
|
||||||
SYM_FUNCTION_MODE = None
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
SYM_FUNCTION_MODE = old
|
|
||||||
@ -22,13 +22,13 @@ import warnings
|
|||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
from ._backward_state import BackwardState
|
from ._backward_state import BackwardState
|
||||||
from ._sym_dispatch_mode import SymDispatchMode
|
|
||||||
from .sym_node import SymNode
|
from .sym_node import SymNode
|
||||||
from torch.utils._thunk import Thunk
|
from torch.utils._thunk import Thunk
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from contextlib import contextmanager, nullcontext, AbstractContextManager, ExitStack
|
from contextlib import contextmanager, nullcontext, AbstractContextManager, ExitStack
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from torch import SymInt, SymBool, Tensor
|
from torch import SymInt, SymBool, Tensor
|
||||||
|
import torch._ops
|
||||||
from torch._dispatch.python import enable_python_dispatcher
|
from torch._dispatch.python import enable_python_dispatcher
|
||||||
from torch._library.fake_class_registry import FakeScriptObject
|
from torch._library.fake_class_registry import FakeScriptObject
|
||||||
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
|
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily, is_fake
|
||||||
@ -59,7 +59,10 @@ if TYPE_CHECKING:
|
|||||||
from torch.fx._symbolic_trace import PHBase
|
from torch.fx._symbolic_trace import PHBase
|
||||||
from torch.types import IntLikeType
|
from torch.types import IntLikeType
|
||||||
|
|
||||||
__all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"]
|
__all__ = [
|
||||||
|
"PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter",
|
||||||
|
"py_sym_types", "get_innermost_proxy_mode", "get_proxy_mode", "handle_sym_dispatch"
|
||||||
|
]
|
||||||
|
|
||||||
_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
|
_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]
|
||||||
|
|
||||||
@ -1006,7 +1009,10 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
|
|||||||
|
|
||||||
|
|
||||||
class ProxyTorchDispatchMode(TorchDispatchMode):
|
class ProxyTorchDispatchMode(TorchDispatchMode):
|
||||||
_managers: List[AbstractContextManager]
|
# Ensure this is read-only; this exists only for legacy reasons
|
||||||
|
@property
|
||||||
|
def enable_tracing(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -1020,12 +1026,9 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||||||
super().__init__(dk)
|
super().__init__(dk)
|
||||||
self.tracer = tracer
|
self.tracer = tracer
|
||||||
self.tracing_mode = tracing_mode
|
self.tracing_mode = tracing_mode
|
||||||
self.enable_tracing = True
|
|
||||||
self.pre_dispatch = pre_dispatch
|
self.pre_dispatch = pre_dispatch
|
||||||
self._allow_fake_constant = _allow_fake_constant
|
self._allow_fake_constant = _allow_fake_constant
|
||||||
self._error_on_data_dependent_ops = _error_on_data_dependent_ops
|
self._error_on_data_dependent_ops = _error_on_data_dependent_ops
|
||||||
self.sym_mode = ProxySymDispatchMode(tracer)
|
|
||||||
self._managers = []
|
|
||||||
# Indicates to our torch_dispatch dispatching infra that
|
# Indicates to our torch_dispatch dispatching infra that
|
||||||
# this is an "infra" mode with lower dispatching precedence.
|
# this is an "infra" mode with lower dispatching precedence.
|
||||||
self._mode_key = torch._C._TorchDispatchModeKey.PROXY
|
self._mode_key = torch._C._TorchDispatchModeKey.PROXY
|
||||||
@ -1045,14 +1048,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||||||
args: Tuple[object, ...] = (),
|
args: Tuple[object, ...] = (),
|
||||||
kwargs: Optional[Dict[str, object]] = None
|
kwargs: Optional[Dict[str, object]] = None
|
||||||
) -> object:
|
) -> object:
|
||||||
with self.sym_mode.enable(False), set_original_aten_op(func):
|
with set_original_aten_op(func):
|
||||||
return self.inner_torch_dispatch(func, types, args, kwargs)
|
return self.inner_torch_dispatch(func, types, args, kwargs)
|
||||||
|
|
||||||
def __enter__(self) -> Self:
|
def __enter__(self) -> Self:
|
||||||
# sym mode first, then us...
|
|
||||||
m = self.sym_mode.enable(True)
|
|
||||||
self._managers.append(m)
|
|
||||||
m.__enter__()
|
|
||||||
# Stash and store the previous proxy mode (there may or may not be one)
|
# Stash and store the previous proxy mode (there may or may not be one)
|
||||||
maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
|
maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
|
||||||
self.enter_stack.append(maybe_prev_proxy_mode)
|
self.enter_stack.append(maybe_prev_proxy_mode)
|
||||||
@ -1064,8 +1063,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||||||
exc_value: Optional[BaseException],
|
exc_value: Optional[BaseException],
|
||||||
traceback: Optional[types.TracebackType]
|
traceback: Optional[types.TracebackType]
|
||||||
) -> Optional[bool]:
|
) -> Optional[bool]:
|
||||||
m = self._managers.pop()
|
|
||||||
# ...exit us first, then sym mode
|
|
||||||
b = super().__exit__(exc_type, exc_value, traceback)
|
b = super().__exit__(exc_type, exc_value, traceback)
|
||||||
|
|
||||||
# Re-enable the previous proxy mode, if there was one.
|
# Re-enable the previous proxy mode, if there was one.
|
||||||
@ -1073,11 +1070,7 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||||||
if mb_previous_proxy_mode is not None:
|
if mb_previous_proxy_mode is not None:
|
||||||
_push_mode(mb_previous_proxy_mode)
|
_push_mode(mb_previous_proxy_mode)
|
||||||
|
|
||||||
if not b:
|
return b
|
||||||
return m.__exit__(exc_type, exc_value, traceback)
|
|
||||||
else:
|
|
||||||
return m.__exit__(None, None, None)
|
|
||||||
|
|
||||||
|
|
||||||
def inner_torch_dispatch(
|
def inner_torch_dispatch(
|
||||||
self,
|
self,
|
||||||
@ -1088,9 +1081,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||||||
) -> object:
|
) -> object:
|
||||||
kwargs = kwargs or {}
|
kwargs = kwargs or {}
|
||||||
|
|
||||||
if not self.enable_tracing:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
if func in (prim.device.default,):
|
if func in (prim.device.default,):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
@ -1100,25 +1090,6 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
|
|||||||
def is_infra_mode(cls) -> bool:
|
def is_infra_mode(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ProxySymDispatchMode(SymDispatchMode):
|
|
||||||
def __init__(self, tracer: _ProxyTracer) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.tracer = tracer
|
|
||||||
# When false, we don't trace operations. If you do this, you MUST
|
|
||||||
# call track_tensor/track_tensor_tree on all results of the operation
|
|
||||||
# to ensure we can adequately track the results
|
|
||||||
self.enable_tracing = True
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def enable(self, b: bool) -> Generator[None, None, None]:
|
|
||||||
old = self.enable_tracing
|
|
||||||
self.enable_tracing = b
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
self.enable_tracing = old
|
|
||||||
|
|
||||||
def _compute_proxy(self, func: OpOverload, args: Tuple[object, ...], out: PySymType) -> Proxy:
|
def _compute_proxy(self, func: OpOverload, args: Tuple[object, ...], out: PySymType) -> Proxy:
|
||||||
n_args = tuple(
|
n_args = tuple(
|
||||||
get_proxy_slot(a, self.tracer).force().node if isinstance(a, py_sym_types) else a
|
get_proxy_slot(a, self.tracer).force().node if isinstance(a, py_sym_types) else a
|
||||||
@ -1139,9 +1110,6 @@ class ProxySymDispatchMode(SymDispatchMode):
|
|||||||
args: Tuple[object, ...],
|
args: Tuple[object, ...],
|
||||||
kwargs: Dict[str, object]
|
kwargs: Dict[str, object]
|
||||||
) -> object:
|
) -> object:
|
||||||
if not self.enable_tracing:
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
|
|
||||||
# Peephole optimize multiply by one
|
# Peephole optimize multiply by one
|
||||||
# NB: be careful not to trigger guards here!
|
# NB: be careful not to trigger guards here!
|
||||||
if func == operator.mul:
|
if func == operator.mul:
|
||||||
@ -1727,7 +1695,6 @@ class _MakefxTracer:
|
|||||||
stack.enter_context(self.fake_tensor_mode)
|
stack.enter_context(self.fake_tensor_mode)
|
||||||
stack.enter_context(self.python_dispatcher_mode)
|
stack.enter_context(self.python_dispatcher_mode)
|
||||||
stack.enter_context(self.proxy_function_mode)
|
stack.enter_context(self.proxy_function_mode)
|
||||||
stack.enter_context(proxy_mode.sym_mode)
|
|
||||||
stack.enter_context(self.torch_fn_metadata_mode)
|
stack.enter_context(self.torch_fn_metadata_mode)
|
||||||
stack.enter_context(proxy_mode)
|
stack.enter_context(proxy_mode)
|
||||||
stack.enter_context(disable_autocast_cache())
|
stack.enter_context(disable_autocast_cache())
|
||||||
@ -1787,8 +1754,13 @@ def make_fx(
|
|||||||
_allow_fake_constant: bool = False,
|
_allow_fake_constant: bool = False,
|
||||||
_error_on_data_dependent_ops: bool = True) -> Callable[..., GraphModule]:
|
_error_on_data_dependent_ops: bool = True) -> Callable[..., GraphModule]:
|
||||||
|
|
||||||
assert tracing_mode in ["real", "fake", "symbolic"]
|
"""
|
||||||
|
Given a function f, return a new function which when executed with valid
|
||||||
|
arguments to f, returns an FX GraphModule representing the set of operations that
|
||||||
|
were executed during the course of execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert tracing_mode in ["real", "fake", "symbolic"]
|
||||||
|
|
||||||
make_fx_tracer = _MakefxTracer(
|
make_fx_tracer = _MakefxTracer(
|
||||||
decomposition_table,
|
decomposition_table,
|
||||||
@ -1810,8 +1782,38 @@ def get_torch_dispatch_modes() -> List[TorchDispatchMode]:
|
|||||||
return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
|
return torch.utils._python_dispatch._get_current_dispatch_mode_stack()
|
||||||
|
|
||||||
|
|
||||||
def get_innermost_proxy_mode() -> ProxyTorchDispatchMode:
|
# TODO: this is a legacy name, there is only ever one proxy mode as it's an
|
||||||
return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
|
# infra mode
|
||||||
|
def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
|
||||||
|
return get_proxy_mode()
|
||||||
|
|
||||||
|
|
||||||
|
def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
|
||||||
|
"""
|
||||||
|
Current the currently active proxy tracing mode, or None if
|
||||||
|
we are not currently tracing. This includes pre-dispatch proxy
|
||||||
|
tracing.
|
||||||
|
"""
|
||||||
|
pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
|
||||||
|
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
|
||||||
|
assert pre_dispatch_mode is None or mode is None, f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
|
||||||
|
return pre_dispatch_mode or mode
|
||||||
|
|
||||||
|
|
||||||
|
def handle_sym_dispatch(func: Callable[_P, R], args: _P.args, kwargs: _P.kwargs) -> R:
|
||||||
|
"""
|
||||||
|
Call into the currently active proxy tracing mode to do a
|
||||||
|
SymInt/SymFloat/SymBool dispatch trace on a function that operates on
|
||||||
|
these arguments.
|
||||||
|
"""
|
||||||
|
mode = get_proxy_mode()
|
||||||
|
assert mode
|
||||||
|
# Have to do it manually, because we're not doing the normal torch
|
||||||
|
# dispatch machinery which disables it for us
|
||||||
|
with disable_proxy_modes_tracing():
|
||||||
|
# TODO: properly compute types
|
||||||
|
types: List[Type] = []
|
||||||
|
return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value]
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|||||||
@ -31,10 +31,6 @@ from torch import ( # noqa: F401
|
|||||||
SymFloat,
|
SymFloat,
|
||||||
SymInt,
|
SymInt,
|
||||||
)
|
)
|
||||||
from torch.fx.experimental._sym_dispatch_mode import (
|
|
||||||
handle_sym_dispatch,
|
|
||||||
sym_function_mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -1055,6 +1051,10 @@ def _make_node_magic(method, func):
|
|||||||
method_attr = method
|
method_attr = method
|
||||||
|
|
||||||
def binary_magic_impl(self, other):
|
def binary_magic_impl(self, other):
|
||||||
|
from torch.fx.experimental.proxy_tensor import (
|
||||||
|
get_proxy_mode,
|
||||||
|
handle_sym_dispatch,
|
||||||
|
)
|
||||||
from torch.fx.experimental.symbolic_shapes import safe_expand
|
from torch.fx.experimental.symbolic_shapes import safe_expand
|
||||||
|
|
||||||
op = method_to_operator(method)
|
op = method_to_operator(method)
|
||||||
@ -1067,7 +1067,7 @@ def _make_node_magic(method, func):
|
|||||||
if alternate_impl and out_hint is not None:
|
if alternate_impl and out_hint is not None:
|
||||||
return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
|
return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
|
||||||
|
|
||||||
if sym_function_mode():
|
if get_proxy_mode():
|
||||||
return to_node(
|
return to_node(
|
||||||
self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
|
self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
|
||||||
)
|
)
|
||||||
@ -1129,10 +1129,14 @@ def _make_node_magic(method, func):
|
|||||||
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
|
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
|
||||||
|
|
||||||
def unary_magic_impl(self):
|
def unary_magic_impl(self):
|
||||||
|
from torch.fx.experimental.proxy_tensor import (
|
||||||
|
get_proxy_mode,
|
||||||
|
handle_sym_dispatch,
|
||||||
|
)
|
||||||
from torch.fx.experimental.symbolic_shapes import safe_expand
|
from torch.fx.experimental.symbolic_shapes import safe_expand
|
||||||
|
|
||||||
op = method_to_operator(method)
|
op = method_to_operator(method)
|
||||||
if sym_function_mode():
|
if get_proxy_mode():
|
||||||
return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
|
return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
|
||||||
# TODO: consider constant prop here
|
# TODO: consider constant prop here
|
||||||
expr = self.expr
|
expr = self.expr
|
||||||
@ -1167,10 +1171,14 @@ def _make_node_magic(method, func):
|
|||||||
elif method == "sym_ite":
|
elif method == "sym_ite":
|
||||||
|
|
||||||
def sym_ite_impl(pred_node, then_node, else_node):
|
def sym_ite_impl(pred_node, then_node, else_node):
|
||||||
|
from torch.fx.experimental.proxy_tensor import (
|
||||||
|
get_proxy_mode,
|
||||||
|
handle_sym_dispatch,
|
||||||
|
)
|
||||||
from torch.fx.experimental.symbolic_shapes import safe_expand
|
from torch.fx.experimental.symbolic_shapes import safe_expand
|
||||||
|
|
||||||
out_hint = then_node.hint if pred_node.hint else else_node.hint
|
out_hint = then_node.hint if pred_node.hint else else_node.hint
|
||||||
if sym_function_mode():
|
if get_proxy_mode():
|
||||||
return to_node(
|
return to_node(
|
||||||
pred_node,
|
pred_node,
|
||||||
handle_sym_dispatch(
|
handle_sym_dispatch(
|
||||||
@ -1208,10 +1216,14 @@ def _make_node_magic(method, func):
|
|||||||
elif method == "round":
|
elif method == "round":
|
||||||
|
|
||||||
def round_impl(self, ndigits=None):
|
def round_impl(self, ndigits=None):
|
||||||
|
from torch.fx.experimental.proxy_tensor import (
|
||||||
|
get_proxy_mode,
|
||||||
|
handle_sym_dispatch,
|
||||||
|
)
|
||||||
from torch.fx.experimental.symbolic_shapes import safe_expand
|
from torch.fx.experimental.symbolic_shapes import safe_expand
|
||||||
|
|
||||||
op = builtins.round
|
op = builtins.round
|
||||||
if sym_function_mode():
|
if get_proxy_mode():
|
||||||
return to_node(
|
return to_node(
|
||||||
self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
|
self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
|
||||||
)
|
)
|
||||||
@ -1256,8 +1268,13 @@ def _make_node_sizes_strides(method, func):
|
|||||||
# NB: don't LRU cache, lots of arguments
|
# NB: don't LRU cache, lots of arguments
|
||||||
|
|
||||||
def sizes_strides_impl(self, sizes, strides):
|
def sizes_strides_impl(self, sizes, strides):
|
||||||
|
from torch.fx.experimental.proxy_tensor import (
|
||||||
|
get_proxy_mode,
|
||||||
|
handle_sym_dispatch,
|
||||||
|
)
|
||||||
|
|
||||||
op = getattr(sys.modules[__name__], method)
|
op = getattr(sys.modules[__name__], method)
|
||||||
if sym_function_mode():
|
if get_proxy_mode():
|
||||||
return to_node(
|
return to_node(
|
||||||
self,
|
self,
|
||||||
handle_sym_dispatch(
|
handle_sym_dispatch(
|
||||||
|
|||||||
Reference in New Issue
Block a user