mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

From @ezyang's original PR:

There are a number of situations where we have non-backend kernels (e.g., CompositeImplicitAutograd, batching rules) which we would like to port to Python, but we have no way to integrate these ports with the overall system while otherwise using the preexisting C++ registrations. This PR changes that by introducing a Python dispatcher (which can have its own kernels directly in Python), which can be interposed over ordinary C++ dispatch. The ingredients:

- We introduce a new PythonDispatcher dispatch key that has the same tenor as FuncTorchDynamicLayerFrontMode: it works by getting triggered before every other dispatch key in the dispatch key set, and shunting to a Python implementation.
- The Python dispatcher is a per-interpreter global object that is enabled/disabled via the guards EnablePythonDispatcher/DisablePythonDispatcher. We don't make it compositional as I have no idea what a compositional version of this feature would look like. Because it is global, we don't need to memory manage it, and so I use a simpler SafePyHandle (newly added) to control access to this pointer from non-Python C++. Like __torch_dispatch__, we use PyInterpreter to get to the Python interpreter to handle the dispatch.
- I need to reimplement dispatch table computation logic in Python. To do this, I expose a lot more helper functions for doing computations on alias dispatch keys and similar. I also improve the pybind11 handling for DispatchKey so that you can either accept the pybind11 bound enum or a string; this simplifies our binding code. See https://github.com/pybind/pybind11/issues/483#issuecomment-1237418106 for how this works; the technique is generally useful.
- I need to be able to call backend fallbacks. I do this by permitting you to call at a dispatch key which doesn't have a kernel for the operator; if the kernel doesn't exist, we check the backend fallback table instead.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84826
Approved by: https://github.com/ezyang
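The two lookup rules described in the ingredients (the highest-priority key in the key set is handled first, and a call at a key with no operator kernel falls through to the backend fallback table) can be sketched with a toy model. Everything here is hypothetical and for illustration only: `ToyKey`, `ToyOperator`, `ToyDispatcher`, and the priority order are assumptions, not the real c10 dispatcher types or its actual key ordering.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <set>
#include <stdexcept>
#include <string>

// Hypothetical keys, declared from highest to lowest priority.
// The ordering is an assumption made for this sketch.
enum class ToyKey { PythonDispatcher = 0, Autograd = 1, CPU = 2 };

using ToyKernel = std::function<std::string()>;

// Hypothetical operator: a per-operator table of kernels.
struct ToyOperator {
  std::map<ToyKey, ToyKernel> kernels;
};

// Hypothetical dispatcher holding a table of backend fallbacks.
struct ToyDispatcher {
  std::map<ToyKey, ToyKernel> backend_fallbacks;

  // Call the operator at an explicit key. If the operator has no kernel
  // for that key, consult the backend fallback table instead.
  std::string call_at_key(const ToyOperator& op, ToyKey key) const {
    auto it = op.kernels.find(key);
    if (it != op.kernels.end()) {
      return it->second();
    }
    auto fb = backend_fallbacks.find(key);
    if (fb != backend_fallbacks.end()) {
      return fb->second();
    }
    throw std::runtime_error("no kernel or backend fallback for key");
  }

  // Ordinary dispatch: pick the highest-priority key in the set
  // (the smallest enum value in this toy model) and call at that key.
  std::string dispatch(const ToyOperator& op,
                       const std::set<ToyKey>& keys) const {
    if (keys.empty()) {
      throw std::runtime_error("empty key set");
    }
    return call_at_key(op, *keys.begin());
  }
};

// Small self-check exercising both rules; returns true on success.
bool toy_dispatch_demo() {
  ToyDispatcher d;
  d.backend_fallbacks[ToyKey::PythonDispatcher] = [] {
    return std::string("python");
  };
  d.backend_fallbacks[ToyKey::Autograd] = [] {
    return std::string("autograd-fallback");
  };

  ToyOperator op;
  op.kernels[ToyKey::CPU] = [] { return std::string("cpu"); };

  // PythonDispatcher is in the set, so it is triggered before the others.
  bool python_first =
      d.dispatch(op, {ToyKey::CPU, ToyKey::PythonDispatcher}) == "python";
  // No Autograd kernel on the operator, so the fallback table is used.
  bool fallback_used =
      d.call_at_key(op, ToyKey::Autograd) == "autograd-fallback";
  // A key with a registered kernel uses that kernel directly.
  bool direct = d.call_at_key(op, ToyKey::CPU) == "cpu";
  return python_first && fallback_used && direct;
}
```

In this sketch the fallback table is only consulted after the per-operator lookup misses, which mirrors the order the commit message describes.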
28 lines
598 B
C++
#pragma once

#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>
#include <c10/util/Optional.h>

namespace c10 {
namespace impl {

// Accessors for the current Python dispatcher handle.
struct C10_API PythonDispatcherTLS {
  static void set_state(SafePyHandle state);
  static SafePyHandle get_state();
  static void reset_state();
};

// RAII guard: saves the current Python dispatcher handle, replaces it with an
// empty handle for the lifetime of the guard, and restores the saved handle
// on destruction.
struct C10_API DisablePythonDispatcher {
  DisablePythonDispatcher() : old_(PythonDispatcherTLS::get_state()) {
    PythonDispatcherTLS::set_state({});
  }
  ~DisablePythonDispatcher() {
    PythonDispatcherTLS::set_state(old_);
  }
  c10::SafePyHandle old_;
};

} // namespace impl
} // namespace c10
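The save/clear/restore pattern used by DisablePythonDispatcher can be tried outside of PyTorch with stand-in types. The following is a minimal sketch under stated assumptions: `FakeHandle`, `FakeDispatcherTLS`, and `FakeDisableGuard` are hypothetical names invented for this example, and backing the state with a `thread_local` variable is an assumption suggested by the `TLS` suffix, not something this header shows.

```cpp
#include <cassert>

// Hypothetical stand-in for c10::SafePyHandle: a nullable, non-owning handle.
struct FakeHandle {
  FakeHandle() = default;
  explicit FakeHandle(const void* p) : ptr(p) {}
  explicit operator bool() const { return ptr != nullptr; }
  const void* ptr = nullptr;
};

// Hypothetical stand-in for PythonDispatcherTLS. The thread_local storage
// is an assumption made for this sketch.
struct FakeDispatcherTLS {
  static void set_state(FakeHandle s) { state() = s; }
  static FakeHandle get_state() { return state(); }
  static void reset_state() { state() = FakeHandle{}; }

 private:
  static FakeHandle& state() {
    thread_local FakeHandle s;
    return s;
  }
};

// Same shape as DisablePythonDispatcher: save the handle, clear it,
// and restore the saved handle when the guard goes out of scope.
struct FakeDisableGuard {
  FakeDisableGuard() : old_(FakeDispatcherTLS::get_state()) {
    FakeDispatcherTLS::set_state({});
  }
  ~FakeDisableGuard() { FakeDispatcherTLS::set_state(old_); }
  // Copying a guard would restore the state twice, so forbid it here.
  FakeDisableGuard(const FakeDisableGuard&) = delete;
  FakeDisableGuard& operator=(const FakeDisableGuard&) = delete;
  FakeHandle old_;
};

// True if a non-empty handle is currently installed on this thread.
inline bool dispatcher_enabled() {
  return static_cast<bool>(FakeDispatcherTLS::get_state());
}

// Self-check: the handle is visible before the guard, hidden inside its
// scope, and visible again after the guard is destroyed.
bool guard_restores_state() {
  static const int dummy = 0;  // any address works as a fake handle target
  FakeDispatcherTLS::set_state(FakeHandle{&dummy});
  const bool before = dispatcher_enabled();
  bool inside = true;
  {
    FakeDisableGuard guard;
    inside = dispatcher_enabled();
  }
  const bool after = dispatcher_enabled();
  FakeDispatcherTLS::reset_state();
  assert(!dispatcher_enabled());
  return before && !inside && after;
}
```

Because the guard restores whatever handle it saved rather than unconditionally re-enabling, nesting two guards in this sketch still leaves the original handle in place once both scopes exit.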