mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360

This PR adds a (private) enable_python_mode context manager (see torch/utils/_python_dispatch.py). enable_python_mode accepts the type of a __torch_dispatch__ object as its argument. Whenever an operator gets called inside of the context manager, it dispatches to the __torch_dispatch__ of the passed-in type.

Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```

There are quite a few changes that were made to support this.

First, we added TorchDispatchTypeObject, a C++ struct that represents the type of a `__torch_dispatch__` object (e.g. LoggingTensor). It holds both the PyObject* representing the class and a PyInterpreter* so we know which Python interpreter it came from.

Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this is null, dispatching happens as usual. When it is non-null, we prepend the TorchDispatchTypeObject's PyObject* to the overloaded args list so that it is considered first for dispatch.

To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser` works. The "overloaded args list" previously consisted only of Tensor PyObjects, but now it can hold types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_tensor`
- We added a new `append_overloaded_type` that appends a type to overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser` and `append_overloaded_arg` to handle types in addition to Tensors.

Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen. This split is due to the libtorch_python library boundary (because we need to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up the relevant TorchDispatchTypeObject (if Python Mode is active) and dispatch using it.

There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an exclude guard. enable_python_mode currently does not handle torch.tensor and the exclude guard is to prevent a bug.

Future:
- This PR does not allow for the nesting of Python modes. In the future we should be able to enable this with a more sane no_dispatch API and by changing the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.

Test Plan:
- new tests

Reviewed By: ezyang

Differential Revision: D30698082

Pulled By: zou3519

fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
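The "prepend the mode type to the overloaded args list" step described above can be illustrated with a small pure-Python model. This is only a sketch of the idea, not the C++ implementation: `_tls`, `set_mode_type`, `collect_overloaded`, `call_op`, `LoggingType`, and the `__toy_dispatch__` hook are all hypothetical names invented for this example, and no PyTorch code is involved.

```python
import threading
from typing import Any, List, Optional, Tuple

# Hypothetical stand-in for PythonModeTLS: holds the active mode type, if any.
_tls = threading.local()


def set_mode_type(cls: Optional[type]) -> None:
    """Toy analogue of entering (cls) or exiting (None) the mode."""
    _tls.mode_type = cls


def collect_overloaded(args: Tuple[Any, ...]) -> List[Any]:
    """Build a toy 'overloaded args list': the mode type first, then instances."""
    overloaded: List[Any] = []
    mode = getattr(_tls, "mode_type", None)
    if mode is not None:
        # The list may now contain a type, not only instances.
        overloaded.append(mode)
    for arg in args:
        if hasattr(type(arg), "__toy_dispatch__"):
            overloaded.append(arg)
    return overloaded


def call_op(name: str, *args: Any) -> Any:
    """Try each overloaded entry in order; fall back to a default result."""
    for entry in collect_overloaded(args):
        # An entry is either a type (from the mode) or an instance (an argument).
        cls = entry if isinstance(entry, type) else type(entry)
        result = cls.__toy_dispatch__(name, args)
        if result is not NotImplemented:
            return result
    return ("default", name)


class LoggingType:
    """Hypothetical class playing the role of a __torch_dispatch__ type."""

    @classmethod
    def __toy_dispatch__(cls, name: str, args: Tuple[Any, ...]) -> Any:
        return ("logged", name)
```

With the mode type set, even a call that receives no `LoggingType` instance (the analogue of `torch.empty([])`) reaches the hook, because the type itself sits at the front of the list.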
35 lines
1.6 KiB
Python
import torch
import contextlib
from typing import Iterator

# Context manager that causes all pytorch operators to dispatch to the passed-in
# type's __torch_dispatch__ function. This includes operations that accept no
# tensors but return a tensor (e.g. torch.empty).
#
# enable_python_mode is affected by torch._C._DisableTorchDispatch.
#
# NB: Calling an operator inside __torch_dispatch__ does go through
# __torch_dispatch__ again. Please use _DisableTorchDispatch inside
# __torch_dispatch__ to prevent infinite recursion.
#
# TODO: Limitations and things about enable_python_mode we should fix before exposing it:
# - it currently cannot be nested. This should be simple to implement; we need a
#   stack of TorchDispatchTypeObjects and the next bullet point.
# - We need a better user-facing api for torch._C._DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)
@contextlib.contextmanager
def enable_python_mode(cls) -> Iterator[None]:
    if not hasattr(cls, '__torch_dispatch__'):
        raise ValueError('The class passed to enable_python_mode '
                         'must have a __torch_dispatch__ classmethod')
    if not isinstance(cls, type) or not issubclass(cls, (torch.Tensor,)):
        raise ValueError('The argument passed to enable_python_mode '
                         'must be the type of a Tensor subclass')
    torch._C._enter_python_mode(cls)
    try:
        yield
    finally:
        torch._C._exit_python_mode()
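The nesting limitation noted in the TODO (and the "change the TLS to a stack" idea from the commit message) could look roughly like the following pure-Python sketch. It is an assumption about one possible design, not what PyTorch does: `_state`, `_stack`, `current_mode`, `toy_python_mode`, `ModeA`, and `ModeB` are hypothetical names, and the sketch only checks for a `__torch_dispatch__` attribute rather than requiring a Tensor subclass.

```python
import contextlib
import threading
from typing import Iterator, List, Optional

# Hypothetical thread-local state: a stack of mode types instead of a single slot.
_state = threading.local()


def _stack() -> List[type]:
    """Return this thread's mode stack, creating it on first use."""
    if not hasattr(_state, "stack"):
        _state.stack = []
    return _state.stack


def current_mode() -> Optional[type]:
    """Innermost active mode type, or None when no mode is active."""
    stack = _stack()
    return stack[-1] if stack else None


@contextlib.contextmanager
def toy_python_mode(cls: type) -> Iterator[None]:
    """Nestable toy analogue of enable_python_mode (no real dispatch happens)."""
    if not isinstance(cls, type) or not hasattr(cls, "__torch_dispatch__"):
        raise ValueError("expected a type with a __torch_dispatch__ classmethod")
    _stack().append(cls)
    try:
        yield
    finally:
        # Pop in finally so the outer mode is restored even if the body raises.
        _stack().pop()


class ModeA:
    """Hypothetical mode type; the hook body is irrelevant to this sketch."""

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        return NotImplemented


class ModeB(ModeA):
    """Second hypothetical mode type, used to show nesting."""
```

Entering `toy_python_mode(ModeB)` inside `toy_python_mode(ModeA)` makes `current_mode()` return `ModeB`, and leaving the inner block restores `ModeA`.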