Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-06 00:54:56 +08:00)
Python Dispatcher integration with C++ dispatcher (#84826)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

From @ezyang's original PR:

There are a number of situations where we have non-backend kernels (e.g., CompositeImplicitAutograd, batching rules) which we would like to port to Python, but we have no way to integrate these ports with the overall system while otherwise using preexisting C++ registrations. This PR changes that by introducing a Python dispatcher (which can have its own kernels directly in Python), which can interpose over ordinary C++ dispatch. The ingredients:

- We introduce a new PythonDispatcher dispatch key that has the same tenor as FuncTorchDynamicLayerFrontMode: it works by getting triggered before every other dispatch key in the dispatch key set and shunting to a Python implementation.
- The Python dispatcher is a per-interpreter global object that is enabled/disabled via the guards EnablePythonDispatcher/DisablePythonDispatcher. We don't make it compositional, as I have no idea what a compositional version of this feature would look like. Because it is global, we don't need to memory-manage it, so I use a simpler SafePyHandle (newly added) to control access to this pointer from non-Python C++. Like __torch_dispatch__, we use PyInterpreter to get to the Python interpreter to handle the dispatch.
- I need to reimplement dispatch table computation logic in Python. To do this, I expose many more helper functions for doing computations on alias dispatch keys and the like. I also improve the pybind11 handling for DispatchKey so that you can pass either the pybind11-bound enum or a string; this simplifies our binding code. See https://github.com/pybind/pybind11/issues/483#issuecomment-1237418106 for how this works; the technique is generally useful.
- I need to be able to call backend fallbacks. I do this by permitting you to call at a dispatch key that doesn't have a kernel for the operator; if the kernel doesn't exist, we check the backend fallback table instead.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84826
Approved by: https://github.com/ezyang
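The flow the PR describes can be sketched as a toy Python model (all names here are illustrative, not the actual PyTorch C++ implementation): a globally registered Python interposer that, when enabled, is consulted before the normal kernel table, plus a backend fallback table consulted when an operator has no kernel for the requested key.

```python
from enum import Enum, auto

class DispatchKey(Enum):
    # Toy analogues of a few real dispatch keys.
    CPU = auto()
    CompositeImplicitAutograd = auto()

class ToyDispatcher:
    """Toy model: a Python interposer over a kernel table with backend fallbacks."""
    def __init__(self):
        self.kernels = {}            # (op, key) -> kernel
        self.backend_fallback = {}   # key -> fallback kernel
        self.python_dispatcher = None  # global, enabled/disabled like a guard

    def register(self, op, key, fn):
        self.kernels[(op, key)] = fn

    def call(self, op, key, *args):
        # When the Python dispatcher is enabled, it interposes before any key.
        if self.python_dispatcher is not None:
            return self.python_dispatcher(self, op, key, *args)
        return self.call_at_key(op, key, *args)

    def call_at_key(self, op, key, *args):
        # Calling at a key with no kernel falls through to the backend fallback.
        fn = self.kernels.get((op, key))
        if fn is None:
            fn = self.backend_fallback.get(key)
        if fn is None:
            raise RuntimeError(f"no kernel or fallback for {op} at {key}")
        return fn(*args)

d = ToyDispatcher()
d.register("add", DispatchKey.CPU, lambda a, b: a + b)

# A Python "kernel" that logs the call, then redispatches to the normal table.
log = []
def python_interposer(disp, op, key, *args):
    log.append(op)
    return disp.call_at_key(op, key, *args)

d.python_dispatcher = python_interposer  # EnablePythonDispatcher analogue
print(d.call("add", DispatchKey.CPU, 1, 2))  # prints 3; "add" is logged
```

The real implementation keeps the interposer per-interpreter and reaches it through PyInterpreter, but the control flow (interpose first, fall back to the kernel table, then to backend fallbacks) is the same shape.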
This commit is contained in:
Committed by: PyTorch MergeBot
Parent: 44c30c5d1c
Commit: 35f6a69191
@@ -1713,13 +1713,11 @@ void initJITBindings(PyObject* module) {
             return _get_operation_for_overload_or_packet(
                 {op}, symbol, args, kwargs, /*is_overload*/ true);
           });
-      auto func_dk =
-          py::cpp_function([op, symbol, allow_numbers_as_tensors](
-                               const std::string& str_dk,
-                               py::args args,
-                               py::kwargs kwargs) {
+      auto func_dk = py::cpp_function(
+          [op, symbol, allow_numbers_as_tensors](
+              c10::DispatchKey dk_, py::args args, py::kwargs kwargs) {
             c10::optional<c10::DispatchKey> dk =
-                c10::make_optional(c10::parseDispatchKey(str_dk));
+                c10::make_optional(dk_);
             ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
             return _get_operation_for_overload_or_packet(
                 {op}, symbol, args, kwargs, /*is_overload*/ true, dk);
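The diff moves string parsing out of the binding: the lambda now accepts a `c10::DispatchKey` enum value directly instead of a `std::string` it must run through `parseDispatchKey`, which works because the improved pybind11 binding lets callers pass either the bound enum or a string. A minimal pure-Python sketch of that "enum or string" acceptance (illustrative only; pybind11 achieves it via an implicit conversion on the bound enum, per the issue comment linked above):

```python
from enum import Enum

class DispatchKey(Enum):
    # A small illustrative subset of key names.
    CPU = "CPU"
    CompositeImplicitAutograd = "CompositeImplicitAutograd"

def parse_dispatch_key(key):
    """Accept either a DispatchKey member or its string name (toy analogue
    of the enum-or-string handling the commit adds on the C++ side)."""
    if isinstance(key, DispatchKey):
        return key
    return DispatchKey[key]  # raises KeyError for an unknown name

assert parse_dispatch_key("CPU") is DispatchKey.CPU
assert parse_dispatch_key(DispatchKey.CPU) is DispatchKey.CPU
```

With this convention a single bound function can serve both Python callers that hold the enum and callers that only have the key's name, which is what simplifies the binding code.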