mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via `git diff --ignore-all-space --ignore-blank-lines HEAD~1`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129769 Approved by: https://github.com/ezyang
73 lines
2.1 KiB
Python
73 lines
2.1 KiB
Python
# mypy: allow-untyped-defs
|
|
import contextlib
|
|
from typing import List, Optional, Type
|
|
|
|
|
|
# Public names exported by ``from <module> import *``.
# NOTE(review): ``disable_sym_dispatch`` is defined in this module but is not
# listed here — confirm whether that omission is intentional.
__all__ = [
    "SymDispatchMode",
    "handle_sym_dispatch",
    "sym_function_mode",
]

# The SymDispatchMode currently in effect, or None when no mode is active.
# Reassigned by SymDispatchMode.__enter__/__exit__, handle_sym_dispatch and
# disable_sym_dispatch.
SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
|
|
|
|
|
|
# SymDispatchMode gets invoked whenever an operation is processed on a
# PySymInt. When this occurs, ``__sym_dispatch__`` is called with the
# operation in question. This is analogous to TorchDispatchMode, with some
# caveats:
#
# - In TorchDispatchMode, you receive the same arguments the user invoked
#   the API with; e.g., if you call torch.ops.aten.foo(a, b), you get (a, b)
#   as the args to your call. In SymDispatchMode, if you compute a + b
#   (where a and b are SymInts), you instead get (a.node, b.node) as your
#   args (these are PySymInts).
#
# - SymInt/PySymInt have no FX proxy support (unlike, e.g., Tensor), so you
#   must manually call Tracer/create_node to write into the graph. See
#   ProxySymDispatchMode for an example.
class SymDispatchMode:
    """Base class for modes that intercept operations on symbolic ints.

    Subclasses override ``__sym_dispatch__``. Entering an instance
    (``with mode:``) installs it as the module-wide active mode and records
    the previously active mode in ``self.inner``; exiting reinstates
    ``self.inner``. A given instance may be entered only once.
    """

    def __sym_dispatch__(self, func, types, args, kwargs):
        """Handle ``func`` applied to ``args``/``kwargs``.

        Subclasses must override this; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError

    def __enter__(self):
        """Make this mode the active one and return it.

        Raises:
            RuntimeError: if this instance has already been entered before.
        """
        global SYM_FUNCTION_MODE
        # ``inner`` is assigned by __enter__ and never removed, so its
        # presence means this instance was already used; re-entering would
        # overwrite the saved outer mode.
        if hasattr(self, "inner"):
            raise RuntimeError(
                f"{self} has already been used as a mode. Please use a fresh version"
            )
        self.inner = SYM_FUNCTION_MODE
        SYM_FUNCTION_MODE = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Reinstate the mode that was active when this one was entered."""
        global SYM_FUNCTION_MODE
        SYM_FUNCTION_MODE = self.inner
|
|
|
|
|
|
def handle_sym_dispatch(func, args, kwargs):
    """Route ``func`` with ``args``/``kwargs`` to the active SymDispatchMode.

    While the mode's ``__sym_dispatch__`` runs, the module-wide active mode is
    temporarily set to ``mode.inner`` (the mode that was active when this one
    was entered), so any lookup of the active mode inside the handler sees the
    outer mode rather than this one. The original mode is reinstated
    afterwards, even if the handler raises.

    Args:
        func: the operation being dispatched.
        args: positional arguments passed through to ``__sym_dispatch__``.
        kwargs: keyword arguments passed through to ``__sym_dispatch__``.

    Returns:
        Whatever the active mode's ``__sym_dispatch__`` returns.

    Raises:
        AssertionError: if no SymDispatchMode is currently active.
    """
    global SYM_FUNCTION_MODE
    mode = SYM_FUNCTION_MODE
    # Explicit check instead of ``assert`` so it still fires under
    # ``python -O``; AssertionError is kept so callers see the same type.
    if mode is None:
        raise AssertionError(
            "handle_sym_dispatch called with no active SymDispatchMode"
        )
    SYM_FUNCTION_MODE = mode.inner
    try:
        # TODO: properly compute types
        types: List[Type] = []
        return mode.__sym_dispatch__(func, types, args, kwargs)
    finally:
        SYM_FUNCTION_MODE = mode
|
|
|
|
|
|
def sym_function_mode():
    """Return the SymDispatchMode currently in effect, or None if there is none."""
    active = SYM_FUNCTION_MODE
    return active
|
|
|
|
|
|
@contextlib.contextmanager
def disable_sym_dispatch():
    """Context manager that deactivates any SymDispatchMode for its body.

    The active mode is cleared to None on entry and the previously active
    mode is reinstated on exit, including when the body raises.
    """
    global SYM_FUNCTION_MODE
    # The right-hand side is evaluated first, so ``saved_mode`` captures the
    # mode that was active before it is cleared.
    saved_mode, SYM_FUNCTION_MODE = SYM_FUNCTION_MODE, None
    try:
        yield
    finally:
        SYM_FUNCTION_MODE = saved_mode
|