Files
pytorch/torch/fx/experimental/_sym_dispatch_mode.py
Xuehai Pan f3fce597e9 [BE][Easy][17/19] enforce style for empty lines in import segments in torch/[a-c]*/ and torch/[e-n]*/ (#129769)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129769
Approved by: https://github.com/ezyang
2024-08-04 10:24:09 +00:00

73 lines
2.1 KiB
Python

# mypy: allow-untyped-defs
import contextlib
from typing import List, Optional, Type
# Names exported by ``from ... import *``.
__all__ = [
    "SymDispatchMode",
    "handle_sym_dispatch",
    "sym_function_mode",
]

# The SymDispatchMode that handle_sym_dispatch currently routes operations to,
# or None when no mode is active. Each entered mode stores the previously
# active mode in its ``inner`` attribute, so entered modes form a stack.
SYM_FUNCTION_MODE: Optional["SymDispatchMode"] = None
class SymDispatchMode:
    """Base class for modes that intercept operations on PySymInts.

    Whenever an operation is processed on a PySymInt, the active mode's
    ``__sym_dispatch__`` is called with the operation in question. This is
    symmetric to TorchDispatchMode, with some caveats:

    - In TorchDispatchMode you receive the same arguments the user invoked
      the API with; e.g., calling ``torch.ops.aten.foo(a, b)`` hands you
      ``(a, b)``. In SymDispatchMode, computing ``a + b`` (where ``a`` and
      ``b`` are SymInts) hands you ``(a.node, b.node)``, i.e. the
      underlying PySymInts.

    - SymInt/PySymInt have no FX proxy support (unlike, e.g., Tensor), so a
      mode must call Tracer/create_node itself to write into the graph. See
      ProxySymDispatchMode for an example.

    Instances are single-use: entering one records the previously active
    mode in ``self.inner``, and entering the same instance again raises
    RuntimeError.
    """

    def __sym_dispatch__(self, func, types, args, kwargs):
        """Handle one dispatched operation; subclasses must override this."""
        raise NotImplementedError

    def __enter__(self):
        """Make this mode the active one and remember the mode it replaces."""
        global SYM_FUNCTION_MODE
        previous = SYM_FUNCTION_MODE
        # ``inner`` is assigned just below on first entry, so its presence
        # means this instance has already been used as a mode.
        if hasattr(self, "inner"):
            raise RuntimeError(
                f"{self} has already been used as a mode. Please use a fresh version"
            )
        self.inner = previous
        SYM_FUNCTION_MODE = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Reinstate the mode that was active before ``__enter__``."""
        global SYM_FUNCTION_MODE
        SYM_FUNCTION_MODE = self.inner
def handle_sym_dispatch(func, args, kwargs):
    """Route ``func`` with ``args``/``kwargs`` to the active SymDispatchMode.

    While the mode's ``__sym_dispatch__`` runs, that mode is swapped out for
    its ``inner`` mode. Presumably this is so operations performed by the
    handler are not dispatched back to the same mode. The original mode is
    reinstated afterwards, even if the handler raises.

    Args:
        func: the operation being dispatched.
        args: positional arguments of the operation.
        kwargs: keyword arguments of the operation.

    Returns:
        Whatever the active mode's ``__sym_dispatch__`` returns.

    Raises:
        AssertionError: if no SymDispatchMode is currently active.
    """
    global SYM_FUNCTION_MODE
    mode = SYM_FUNCTION_MODE
    # Explicit check rather than a bare ``assert``, so the guard survives
    # ``python -O``. Without it, a missing mode would surface as an opaque
    # AttributeError on ``None.inner``.
    if mode is None:
        raise AssertionError(
            "handle_sym_dispatch called without an active SymDispatchMode"
        )
    SYM_FUNCTION_MODE = mode.inner
    try:
        # TODO: properly compute types
        types: List[Type] = []
        return mode.__sym_dispatch__(func, types, args, kwargs)
    finally:
        SYM_FUNCTION_MODE = mode
def sym_function_mode() -> Optional["SymDispatchMode"]:
    """Return the currently active SymDispatchMode, or None if none is active."""
    active = SYM_FUNCTION_MODE
    return active
@contextlib.contextmanager
def disable_sym_dispatch():
    """Temporarily deactivate symbolic dispatch.

    Inside the ``with`` body no SymDispatchMode is active. On exit, whether
    normal or via an exception, the previously active mode is reinstated.
    """
    global SYM_FUNCTION_MODE
    saved_mode, SYM_FUNCTION_MODE = SYM_FUNCTION_MODE, None
    try:
        yield
    finally:
        SYM_FUNCTION_MODE = saved_mode