# Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00).
# This reverts commit 702838637d63936460ea2bf00b64ffec86ed6687.
# Reverted https://github.com/pytorch/pytorch/pull/90770 on behalf of
# https://github.com/DanilBaibak due to: Break internal build.
# File: 102 lines, 3.6 KiB, Python.
import contextlib
|
|
|
|
import warnings
|
|
from torch._C import _len_torch_dispatch_stack, _get_dispatch_stack_at,\
|
|
_pop_torch_dispatch_stack, _push_on_torch_dispatch_stack
|
|
|
|
|
|
# TODO: Limitations of enable_torch_dispatch_mode that we should fix before exposing it:
# - We need a better user-facing API for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor).
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
|
|
|
|
class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or re-enter the mode with
    ``with self:`` to make the PyTorch API self-referential (beware of
    infinite loops in this case!).
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Handle one intercepted call.  Subclasses must override this.

        Args:
            func: the callable being dispatched (subclasses such as
                ``BaseTorchDispatchMode`` invoke it as ``func(*args, **kwargs)``).
            types: passed through by the dispatcher; NOTE(review): its exact
                contents are defined by the caller, not visible here.
            args: positional arguments intended for ``func``.
            kwargs: keyword arguments intended for ``func``; may be ``None``.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError()

    def __enter__(self):
        """Push this mode onto the dispatch mode stack and return it."""
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Pop the top of the dispatch mode stack.

        This pops whatever is on top, so it relies on ``with`` blocks being
        properly nested.  Returns ``None``, so exceptions are never suppressed.
        """
        _pop_mode()

    @classmethod
    def push(cls, *args, **kwargs):
        """Deprecated constructor kept for backward compatibility.

        Emits a warning and returns ``cls(*args, **kwargs)``.  The instance is
        NOT pushed onto the mode stack; enter it with ``with Mode():`` instead.
        """
        # stacklevel=2 attributes the warning to the caller that should migrate,
        # not to this module.
        warnings.warn(
            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`",
            stacklevel=2,
        )
        instance = cls(*args, **kwargs)
        return instance
|
|
|
|
def _get_current_dispatch_mode():
    """Return the current (last-indexed) entry of the dispatch mode stack.

    The entry at index ``len - 1`` is returned.  When the stack holds no
    entries, ``None`` is returned instead.
    """
    depth = _len_torch_dispatch_stack()
    if depth <= 0:
        return None
    return _get_dispatch_stack_at(depth - 1)
|
|
|
|
|
|
def _get_current_dispatch_mode_stack():
    """Return a list snapshot of every entry on the dispatch mode stack.

    Entries are ordered by stack index, from 0 up to ``len - 1``.
    """
    depth = _len_torch_dispatch_stack()
    return list(map(_get_dispatch_stack_at, range(depth)))
|
|
|
|
def _push_mode(mode):
    """Place ``mode`` on the dispatch mode stack.

    This is a thin wrapper over the ``torch._C`` binding and returns ``None``.
    """
    _push_on_torch_dispatch_stack(mode)
|
|
|
|
|
|
def _pop_mode():
    """Pop the dispatch mode stack and hand back what the binding returns.

    The returned value is expected to be the removed mode; callers such as
    ``_pop_mode_temporarily`` push it back later.
    """
    popped = _pop_torch_dispatch_stack()
    return popped
|
|
|
|
|
|
@contextlib.contextmanager
def _pop_mode_temporarily():
    """Remove the top dispatch mode for the duration of a ``with`` block.

    Yields the removed mode.  When the block exits, normally or via an
    exception, the same mode is pushed back onto the stack.
    """
    top = _pop_mode()
    try:
        yield top
    finally:
        # Put the stack back exactly as it was found.
        _push_mode(top)
|
|
|
|
|
|
@contextlib.contextmanager
def _disable_current_modes():
    """Empty the dispatch mode stack for the duration of a ``with`` block.

    Yields the list of removed modes in the order they were popped (the first
    entry was popped first).  On exit, normally or via an exception, every
    removed mode is pushed back in reverse order, restoring the original stack.
    """
    old_modes = []
    try:
        # Pop inside the ``try`` so that, if a pop fails part-way, the modes
        # already removed are still restored by the ``finally`` clause.
        for _ in range(_len_torch_dispatch_stack()):
            old_modes.append(_pop_mode())
        yield old_modes
    finally:
        for mode in reversed(old_modes):
            _push_mode(mode)
|
|
|
|
|
|
class BaseTorchDispatchMode(TorchDispatchMode):
    """Pass-through mode: each intercepted call is forwarded unchanged to ``func``."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # A missing kwargs becomes an empty mapping before re-invoking func.
        call_kwargs = {} if kwargs is None else kwargs
        return func(*args, **call_kwargs)
|