Python Dispatcher integration with C++ dispatcher (#84826)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>

From @ezyang's original PR:

There are a number of situations where we have non-backend kernels (e.g., CompositeImplicitAutograd, batching rules) which we would like to port to Python, but we have no way to integrate these ports with the overall system while using preexisting C++ registrations otherwise. This PR changes that by introducing a Python dispatcher (which can have its own kernels directly in Python), which can interpose over ordinary C++ dispatch. The ingredients:

- We introduce a new PythonDispatcher dispatch key that has the same tenor as FuncTorchDynamicLayerFrontMode: it works by getting triggered before every other dispatch key in the dispatch key set, and shunting to a Python implementation.
- The Python dispatcher is a per-interpreter global object that is enabled/disabled via the guards EnablePythonDispatcher/DisablePythonDispatcher. We don't make it compositional, as I have no idea what a compositional version of this feature would look like. Because it is global, we don't need to memory-manage it, and so I use a simpler SafePyHandle (newly added) to control access to this pointer from non-Python C++. Like __torch_dispatch__, we use PyInterpreter to get to the Python interpreter to handle the dispatch.
- I need to reimplement dispatch table computation logic in Python. To do this, I expose a lot more helper functions for doing computations on alias dispatch keys and similar. I also improve the pybind11 handling for DispatchKey so that bindings can accept either the pybind11-bound enum or a string; this simplifies our binding code. See https://github.com/pybind/pybind11/issues/483#issuecomment-1237418106 for how this works; the technique is generally useful.
- I need to be able to call backend fallbacks. I do this by permitting you to call at a dispatch key which doesn't have a kernel for the operator; if the kernel doesn't exist, we check the backend fallback table instead.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84826
Approved by: https://github.com/ezyang
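To see how the pieces fit together before reading the diff, here is an illustrative sketch (not part of this commit) of the user-facing surface it adds; it assumes only the torch._dispatch.python module and the torch._C guards introduced below:

    import torch
    from torch._dispatch.python import enable_python_dispatcher, no_python_dispatcher

    x = torch.randn(2, requires_grad=True)

    with enable_python_dispatcher():
        # While enabled, every dispatcher call is intercepted by the
        # PythonDispatcher key and routed to the Python-side resolve_key logic.
        y = torch.add(x, x)

        with no_python_dispatcher():
            # Temporarily drops back to pure C++ dispatch.
            z = torch.mul(x, x)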
Committed by PyTorch MergeBot
parent: 44c30c5d1c
commit: 35f6a69191
@@ -14,6 +14,7 @@ ThreadLocalState::ThreadLocalState()
      debug_info_(c10::ThreadLocalDebugInfo::current()),
      functorch_tls_(functorch::getCopyOfFuncTorchTLS()),
      autograd_tls_(c10::AutogradState::get_tls_state()),
+     python_dispatcher_state_(c10::impl::PythonDispatcherTLS::get_state()),
      python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) {
  rf_tls_ = at::get_record_function_tls_();

@@ -41,6 +42,8 @@ void ThreadLocalState::setThreadLocalState(
  at::SavedTensorDefaultHooks::set_stack(state.saved_tensors_default_hooks_);

+ c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
+
  c10::ThreadLocalDebugInfo::_forceCurrentDebugInfo(state.debug_info_);

  c10::impl::_force_tls_local_dispatch_key_set(state.dispatch_key_);
@@ -10,6 +10,7 @@
 #include <ATen/FuncTorchTLS.h>
 #include <ATen/PythonTorchFunctionTLS.h>
 #include <ATen/record_function.h>
+#include <c10/core/impl/PythonDispatcherTLS.h>
 #include <c10/core/impl/TorchDispatchModeTLS.h>

 namespace at {

@@ -57,6 +58,9 @@ class TORCH_API ThreadLocalState {
   // TLS for enable_torch_dispatch_mode
   std::shared_ptr<SafePyObject> torch_dispatch_mode_state_;

+  // TLS for enable_python_dispatcher
+  SafePyHandle python_dispatcher_state_;
+
   // TLS for __torch_function__ (mode and disable_torch_function)
   at::impl::PythonTorchFunctionTLS python_torch_function_state_;
@@ -1,4 +1,5 @@
 #include <c10/core/impl/TorchDispatchModeTLS.h>
+#include <c10/core/impl/PythonDispatcherTLS.h>
 #include <ATen/core/PythonFallbackKernel.h>
 #include <c10/core/SafePyObject.h>

@@ -87,6 +88,12 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   TORCH_INTERNAL_ASSERT(0, "Hit Python dispatch key but no arguments had PyInterpreter (no tensor args?)");
 }

+void pythonDispatcherFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
+  auto state = c10::impl::PythonDispatcherTLS::get_state();
+  TORCH_INTERNAL_ASSERT(state, "Hit PythonDispatcher dispatch key but PythonDispatcherTLS was not set");
+  state.pyinterpreter()->python_dispatcher(op, dispatch_keys.remove(c10::DispatchKey::PythonDispatcher), stack);
+}
+
 void pythonTLSSnapshotFallback(const c10::OperatorHandle &op, c10::DispatchKeySet dispatch_keys, torch::jit::Stack* stack) {
   // It is ok for the tls to be already set here.
   // It means that there are multiple calls into the dispatcher not originating from python code.

@@ -134,6 +141,10 @@ TORCH_LIBRARY_IMPL(_, Python, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonFallback>());
 }

+TORCH_LIBRARY_IMPL(_, PythonDispatcher, m) {
+  m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonDispatcherFallback>());
+}
+
 TORCH_LIBRARY_IMPL(_, PythonTLSSnapshot, m) {
   m.fallback(torch::CppFunction::makeFromBoxedFunction<&pythonTLSSnapshotFallback>());
 }
@@ -168,6 +168,12 @@ public:
   // See Note [Plumbing Keys Through The Dispatcher]
   void redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const;

+  bool hasBackendFallbackForDispatchKey(DispatchKey dk) {
+    auto dispatch_ix = getDispatchTableIndexForDispatchKey(dk);
+    if (dispatch_ix < 0) return false;
+    return backendFallbackKernels_[dispatch_ix].kernel.isValid();
+  }
+
   // ------------------------------------------------------------------------
   //

@@ -333,6 +339,10 @@ public:
     return operatorDef_->op.hasKernelForDispatchKey(k);
   }

+  bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
+    return operatorDef_->op.hasKernelForAnyDispatchKey(k);
+  }
+
   bool hasComputedKernelForDispatchKey(DispatchKey k) const {
     return operatorDef_->op.hasComputedKernelForDispatchKey(k);
   }

@@ -635,11 +645,18 @@ inline void Dispatcher::callBoxedForDispatchKey(const OperatorHandle& op, DispatchKey dk, Stack* stack) const {
   // We still compute this as we're obligated to pass it on to the internal
   // kernel, if it is a boxed fallback
   auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
-  const auto& kernel = entry.kernelForDispatchKey(dk);
+  const auto& kernel = ([&]() {
+    if (op.hasKernelForDispatchKey(dk)) {
+      return entry.kernelForDispatchKey(dk);
+    } else {
+      auto idx = getDispatchTableIndexForDispatchKey(dk);
+      TORCH_INTERNAL_ASSERT(idx >= 0);
+      return backendFallbackKernels_[idx].kernel;
+    }
+  })();
   kernel.callBoxed(op, dispatchKeySet, stack);
 }

 inline void Dispatcher::redispatchBoxed(const OperatorHandle& op, DispatchKeySet dispatchKeySet, Stack* stack) const {
   // note: this doesn't need the mutex because write operations on the list keep iterators intact.
   const auto& entry = op.operatorDef_->op;
@@ -172,6 +172,9 @@ const char* toString(DispatchKey t) {
     case DispatchKey::TESTING_ONLY_GenericMode:
       return "TESTING_ONLY_GenericMode";

+    case DispatchKey::PythonDispatcher:
+      return "PythonDispatcher";
+
     // Aliases

     case DispatchKey::Autograd:

@@ -283,6 +286,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
     {"TESTING_ONLY_GenericWrapper",
      c10::DispatchKey::TESTING_ONLY_GenericWrapper},
     {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
+    {"PythonDispatcher", c10::DispatchKey::PythonDispatcher},

     {"CPU", c10::DispatchKey::CPU},
     {"CUDA", c10::DispatchKey::CUDA},
@@ -401,6 +401,10 @@ enum class DispatchKey : uint16_t {
   // for a usage example
   TESTING_ONLY_GenericMode,

+  // This is a bypass that allows you to skip running the C++ dispatcher
+  // entirely
+  PythonDispatcher,
+
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
   EndOfFunctionalityKeys, // End of functionality keys.
@@ -172,7 +172,9 @@ class DispatchKeySet final {
             (1ULL
              << (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
                  1)) -
-            1) {}
+            1) {
+    *this = add(DispatchKey::PythonDispatcher);
+  }

   // Public version of DispatchKeySet(uint64_t) API; external users
   // must be explicit when they do this!
@@ -8,4 +8,9 @@ PyObject* SafePyObject::ptr(const c10::impl::PyInterpreter* interpreter) const {
   return data_;
 }

+PyObject* SafePyHandle::ptr(const c10::impl::PyInterpreter* interpreter) const {
+  TORCH_INTERNAL_ASSERT(interpreter == pyinterpreter_);
+  return data_;
+}
+
 } // namespace c10
@@ -42,4 +42,29 @@ struct C10_API SafePyObject {
   c10::impl::PyInterpreter* pyinterpreter_;
 };

+// Like SafePyObject, but non-owning. Good for references to global PyObjects
+// that will be leaked on interpreter exit. You get a copy constructor/assign
+// this way.
+struct C10_API SafePyHandle {
+  SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {}
+  SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
+      : data_(data), pyinterpreter_(pyinterpreter) {}
+
+  c10::impl::PyInterpreter& pyinterpreter() const {
+    return *pyinterpreter_;
+  }
+  PyObject* ptr(const c10::impl::PyInterpreter*) const;
+  void reset() {
+    data_ = nullptr;
+    pyinterpreter_ = nullptr;
+  }
+  operator bool() {
+    return data_;
+  }
+
+ private:
+  PyObject* data_;
+  c10::impl::PyInterpreter* pyinterpreter_;
+};
+
 } // namespace c10
@@ -27,6 +27,13 @@ struct NoopPyInterpreterVTable final : public PyInterpreterVTable {
     PANIC(dispatch);
   }

+  void python_dispatcher(
+      const c10::OperatorHandle& op,
+      c10::DispatchKeySet,
+      torch::jit::Stack* stack) const override {
+    PANIC(python_dispatcher);
+  }
+
   bool is_contiguous(const TensorImpl* self) const override {
     PANIC(is_contiguous);
   }

@@ -140,6 +140,12 @@ struct C10_API PyInterpreterVTable {
   virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
       const = 0;

+  // Invoke the Python dispatcher to handle this call
+  virtual void python_dispatcher(
+      const c10::OperatorHandle& op,
+      c10::DispatchKeySet,
+      torch::jit::Stack* stack) const = 0;
+
   virtual bool is_contiguous(const TensorImpl* self) const = 0;
   virtual c10::Device device(const TensorImpl* self) const = 0;
   virtual int64_t dim(const TensorImpl* self) const = 0;
c10/core/impl/PythonDispatcherTLS.cpp (new file, 32 lines)
@@ -0,0 +1,32 @@
#include <c10/core/DispatchKeySet.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/PythonDispatcherTLS.h>

namespace c10 {
namespace impl {

thread_local SafePyHandle pythonDispatcherState;

void PythonDispatcherTLS::set_state(SafePyHandle state) {
  if (state) {
    c10::impl::tls_set_dispatch_key_included(
        DispatchKey::PythonDispatcher, true);
  } else {
    PythonDispatcherTLS::reset_state();
  }
  pythonDispatcherState = state;
}

SafePyHandle PythonDispatcherTLS::get_state() {
  return pythonDispatcherState;
}

void PythonDispatcherTLS::reset_state() {
  pythonDispatcherState.reset();
  c10::impl::tls_set_dispatch_key_included(
      DispatchKey::PythonDispatcher, false);
}

} // namespace impl
} // namespace c10
c10/core/impl/PythonDispatcherTLS.h (new file, 27 lines)
@@ -0,0 +1,27 @@
#pragma once

#include <c10/core/SafePyObject.h>
#include <c10/macros/Macros.h>
#include <c10/util/Optional.h>

namespace c10 {
namespace impl {

struct C10_API PythonDispatcherTLS {
  static void set_state(SafePyHandle state);
  static SafePyHandle get_state();
  static void reset_state();
};

struct C10_API DisablePythonDispatcher {
  DisablePythonDispatcher() : old_(PythonDispatcherTLS::get_state()) {
    PythonDispatcherTLS::set_state({});
  }
  ~DisablePythonDispatcher() {
    PythonDispatcherTLS::set_state(old_);
  }
  c10::SafePyHandle old_;
};

} // namespace impl
} // namespace c10
@@ -1,11 +1,12 @@
 import torch
 from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
-from functorch.experimental.ops import PyOperator, fallthrough_fn
+from torch._ops import PyOperator
 from torch.utils._pytree import tree_flatten
 from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule, get_proxy_slot
 import torch.utils._pytree as pytree
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.fx.experimental.proxy_tensor import track_tensor_tree
+from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode


 """

@@ -14,6 +15,9 @@ In order to do this, we need implementations for each of the dispatch keys.
 """
 from contextlib import contextmanager

+cond = PyOperator('cond')
+
+
 # TODO(voz): Move out somewhere else once other py dispatched ops need it
 @contextmanager
 def suspend_mode(mode):

@@ -98,6 +102,7 @@ def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
     return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


+@cond.py_impl(DispatchKey.CPU)
 def cond_dense(pred, true_fn, false_fn, operands):
     mode = torch._C._get_torch_dispatch_mode()
     assert (mode is None), "Mode should never be enabled for CPU key"

@@ -107,6 +112,7 @@ def cond_dense(pred, true_fn, false_fn, operands):
         return false_fn(*operands)


+@cond.py_impl(DispatchKey.AutogradCPU)
 def cond_autograd(pred, true_fn, false_fn, *operands):
     # TODO: support autograd
     flat_operands, _ = tree_flatten([true_fn, false_fn] + [operands])

@@ -117,21 +123,16 @@ def cond_autograd(pred, true_fn, false_fn, *operands):
     return cond(pred, true_fn, false_fn, *operands)


-def python_fallback(op):
-    def inner(pred, true_fn, false_fn, operands):
-        mode = torch._C._get_torch_dispatch_mode()
-        assert (mode is not None), "Mode should always be enabled for python fallback key"
-        with suspend_mode(mode):
-            res = trace_cond(mode, op, pred, true_fn, false_fn, operands)
-        return res
-
-    return inner
+@cond.py_impl(ProxyTorchDispatchMode)
+def inner(pred, true_fn, false_fn, operands):
+    mode = torch._C._get_torch_dispatch_mode()
+    assert (mode is not None), "Mode should always be enabled for python fallback key"
+    with suspend_mode(mode):
+        res = trace_cond(mode, cond, pred, true_fn, false_fn, operands)
+    return res


-cond = PyOperator('cond')
-cond.impl(DispatchKey.CPU, cond_dense)
-cond.impl(DispatchKey.Python, python_fallback(cond))
-cond.impl(DispatchKey.PythonTLSSnapshot, fallthrough_fn)
-cond.impl(DispatchKey.AutogradCPU, cond_autograd)
-cond.impl(DispatchKey.ADInplaceOrView, fallthrough_fn)
-cond.impl(DispatchKey.BackendSelect, fallthrough_fn)
+# TODO(voz): Make this automatic for keys, this is very ugly atm
+cond.fallthrough(DispatchKey.PythonTLSSnapshot)
+cond.fallthrough(DispatchKey.ADInplaceOrView)
+cond.fallthrough(DispatchKey.BackendSelect)
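The PyOperator class itself moves to torch._ops (see that file below); registration switches from impl(...) calls to the py_impl decorator plus explicit fallthroughs. As an illustrative sketch of the new style (not part of this commit; my_op and its kernels are hypothetical):

    import torch
    from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
    from torch._ops import PyOperator

    my_op = PyOperator('my_op')  # also reachable as torch.ops.pyop.my_op

    @my_op.py_impl(DispatchKey.CPU)
    def my_op_cpu(x):
        return x.sin()

    @my_op.py_impl(DispatchKey.AutogradCPU)
    def my_op_autograd(x):
        # No autograd support in this sketch: drop the key and redispatch.
        guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
        return my_op(x)

    # Keys with no kernel redispatch past themselves:
    my_op.fallthrough(DispatchKey.ADInplaceOrView)
    my_op.fallthrough(DispatchKey.BackendSelect)

    y = my_op(torch.randn(3))  # resolves AutogradCPU -> CPU -> my_op_cpu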
@@ -1,36 +1 @@
-from torch._dispatch._dispatcher import PyDispatcher, to_flat_tuple, compute_keyset
-from torch.nn.functional import handle_torch_function
-from torch.overrides import has_torch_function
-import torch._C as _C
-
-class PyOperator:
-    def __init__(self, name):
-        self.name = name
-        self.table = {}
-
-        self.__name__ = name
-
-    def impl(self, dispatch_key, fn):
-        assert dispatch_key not in self.table
-        if fn is fallthrough_fn:
-            self.table[dispatch_key] = fn(self, dispatch_key)
-        else:
-            self.table[dispatch_key] = fn
-
-    def lookup(self, keyset):
-        dispatch_key = keyset.highestPriorityTypeId()
-        return self.table[dispatch_key]
-
-    def __call__(self, *args, **kwargs):
-        flat_args = to_flat_tuple(args, kwargs)
-        if has_torch_function(flat_args):
-            return handle_torch_function(self, flat_args, *args, **kwargs)
-
-        return PyDispatcher.call(self, *args, **kwargs)
-
-def fallthrough_fn(operator, dispatch_key):
-    def inner(*args, **kwargs):
-        all_keys_after_current = _C._dispatch_keyset_full_after(dispatch_key)
-        all_keys_after_current_masked = all_keys_after_current & compute_keyset(args, kwargs)
-        return PyDispatcher.redispatch(operator, all_keys_after_current_masked, *args, **kwargs)
-    return inner
+from torch._ops import PyOperator  # noqa: F401
@@ -19,7 +19,7 @@ class TestControlFlow(TestCase):


 class TestControlFlowTraced(TestCase):
-    def test_cond_traced(self):
+    def test_cond_traced_not_nested(self):
         def true_fn(x):
             return x.sin()

@@ -126,14 +126,12 @@ class TestControlFlowTraced(TestCase):
 def forward(self, x_1, pred_1, pred2_1):
     true_graph_0 = self.true_graph_0
     false_graph_0 = self.false_graph_0
-    conditional = functorch_experimental_ops_cond(pred_1,
-    true_graph_0, false_graph_0, [[x_1]]); pred_1 = true_graph_0 = false_graph_0 = None
+    conditional = torch.ops.pyop.cond(pred_1, true_graph_0, false_graph_0, [[x_1]]); pred_1 = true_graph_0 = false_graph_0 = None # noqa: B950
     true_graph_1 = self.true_graph_1
     false_graph_1 = self.false_graph_1
-    conditional_1 = functorch_experimental_ops_cond(pred2_1,
-    true_graph_1, false_graph_1, [[x_1, x_1]]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
-    add_tensor = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
-    return add_tensor
+    conditional_1 = torch.ops.pyop.cond(pred2_1, true_graph_1, false_graph_1, [[x_1, x_1]]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None # noqa: B950
+    add = torch.ops.aten.add.Tensor(conditional, conditional_1); conditional = conditional_1 = None
+    return add
 """
     code = graph.code
     # Normalization hack, cause .code makes some weird whitespace

@@ -145,8 +143,8 @@ class TestControlFlowTraced(TestCase):
     out = """
 def forward(self, flat_args):
     flat_args_1, = fx_pytree.tree_flatten_spec([flat_args], self._in_spec)
-    mul_tensor = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1); flat_args_1 = None
-    return pytree.tree_unflatten([mul_tensor], self._out_spec)
+    mul = torch.ops.aten.mul.Tensor(flat_args_1, flat_args_1); flat_args_1 = None
+    return pytree.tree_unflatten([mul], self._out_spec)
 """
     # Normalization hack, cause .code makes some weird whitespace
     code = "".join(code.split())
@@ -327,7 +327,7 @@ class TestPrims(TestCase):
     # Check that all call_function nodes are prims
     call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
     all_prims_namespace = all(
-        node.target.name.startswith("prims") for node in call_function_nodes
+        node.target.name().startswith("prims") for node in call_function_nodes
     )
     self.assertTrue(all_prims_namespace)
@@ -952,7 +952,6 @@ symbolic_tensor_failures = {
     xfail('bfloat16', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('bmm', ''), # aten.bmm.default - couldn't find symbolic meta function/decomposition
     xfail('bool', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
-    xfail('broadcast_tensors', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('bucketize', ''), # aten.bucketize.Tensor - couldn't find symbolic meta function/decomposition
     xfail('byte', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
     xfail('cartesian_prod', ''), # Tensors of type TensorImpl do not have numel

@@ -1120,7 +1119,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
     xfail('nn.functional.group_norm', ''), # 'torch._C.SymIntNode' and 'int'
     xfail('nn.functional.hinge_embedding_loss', ''), # aten.empty_like.default - couldn't find symbolic meta function/deco...
-    xfail('nn.functional.huber_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.instance_norm', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.interpolate', 'area'), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.interpolate', 'bicubic'), # aten.upsample_bicubic2d.vec - couldn't find symbolic meta function/d...

@@ -1128,7 +1126,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec...
     xfail('nn.functional.interpolate', 'nearest'), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/d...
     xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
-    xfail('nn.functional.l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.linear', ''), # aten.mv.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.local_response_norm', ''), # Tensors of type TensorImpl do not have numel
     xfail('nn.functional.margin_ranking_loss', ''), # The underlying op of 'aten.stride' has no overload name '_schema'

@@ -1138,7 +1135,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
     xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
     xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta function/decom...
-    xfail('nn.functional.mse_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.multi_margin_loss', ''), # Could not run 'aten::multi_margin_loss' with arguments from the...
     xfail('nn.functional.multilabel_margin_loss', ''), # Could not run 'aten::multilabel_margin_loss_forward' with ...
     xfail('nn.functional.normalize', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
@@ -1833,5 +1833,11 @@ $1 = torch._ops.aten.add.Tensor($0, $0)""")
         e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
         self.assertEqual(e.layout, torch.strided)

+class TestPythonDispatcher(TestCase):
+    def test_basic(self):
+        x = torch.randn(2, requires_grad=True)
+        r = torch._C._EnablePythonDispatcher()
+        torch.add(x, x)
+
 if __name__ == '__main__':
     run_tests()
@@ -136,6 +136,9 @@ blocklist = [
     "floor_divide",
     "floor_divide_",
     "floor_divide_out",
+    "to",
+    "_to_copy",
+    "copy_",
 ]

 binary_ops = (
@@ -997,11 +997,26 @@ class Generator(object):
def _dispatch_library(kind: str, name: str, dispatch: str, file: str = "", linenum: Any = 0) -> Any: ...
def _dispatch_has_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
def _dispatch_has_computed_kernel_for_dispatch_key(name: str, dispatch: str) -> _bool: ...
def _dispatch_has_kernel_for_any_dispatch_key(name: str, dispatch_key_set: object) -> _bool: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
def _dispatch_tls_is_dispatch_key_excluded(dispatch: str) -> _bool: ...
def _dispatch_tls_set_dispatch_key_excluded(dispatch: str, val: _bool) -> None: ...
def _dispatch_isTensorSubclassLike(tensor: Tensor) -> _bool: ...
def _dispatch_dump(dispatch: str) -> str: ...
def _dispatch_is_included_in_alias(dispatch_a: str, dispatch_b: str) -> _bool: ...
def _dispatch_get_backend_keyset_from_autograd(dispatch: str) -> Any: ...
def _dispatch_has_backend_fallback(dispatch: str) -> _bool: ...
def _dispatch_key_name(dispatch: str) -> str: ...

# Defined in torch/csrc/autograd/init.cpp
class _DisablePythonDispatcher(object):
    pass

class _EnablePythonDispatcher(object):
    pass

def _set_python_dispatcher(dispatcher: object) -> None: ...


class _AutoDispatchBelowAutograd:
    pass
@@ -962,3 +962,5 @@ if 'TORCH_CUDA_SANITIZER' in os.environ:
     import torch.cuda._sanitizer as csan

     csan.enable_cuda_sanitizer()
+
+from ._dispatch import python
@@ -1,50 +0,0 @@ (file deleted)
import torch
import torch._C as _C
from torch.utils._pytree import tree_flatten

"""
This is a dispatcher (in Python)
- You can define new operations (in Python) without schemas
- It interfaces with the PyTorch dispatcher
"""

class PyDispatcher:
    # operator is a PyOperator
    @staticmethod
    def call(operator, *args, **kwargs):
        dispatch_key_set = compute_keyset(args, kwargs)
        kernel = operator.lookup(dispatch_key_set)
        return kernel(*args, **kwargs)

    # operator is a PyOperator
    @staticmethod
    def redispatch(operator, dispatch_key_set, *args, **kwargs):
        kernel = operator.lookup(dispatch_key_set)
        return kernel(*args, **kwargs)


def compute_keyset(args, kwargs):
    tensors = get_tensors(args, kwargs)
    return key_extractor(tensors)


# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
# at ATen/core/dispatch/DispatchKeyExtractor.h
def key_extractor(tensors):
    key_set = _C._dispatch_tls_local_include_set()  # type: ignore[attr-defined]
    for tensor in tensors:
        key_set = key_set | _C._dispatch_keys(tensor)  # type: ignore[attr-defined]
    key_set = key_set - _C._dispatch_tls_local_exclude_set()  # type: ignore[attr-defined]
    return key_set


def to_flat_tuple(args, kwargs):
    flat_args, _ = tree_flatten(args)
    flat_kwargs, _ = tree_flatten(kwargs)
    flat_all = flat_args + flat_kwargs
    return flat_all

def get_tensors(args, kwargs):
    flat_all = to_flat_tuple(args, kwargs)
    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
    return tuple(tensor_args)
torch/_dispatch/python.py (new file, 88 lines)
@@ -0,0 +1,88 @@
import torch
from contextlib import contextmanager

__all__ = ['enable_python_dispatcher', 'no_python_dispatcher']

DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]

def has_key(op, k):
    return (
        torch._C._dispatch_has_kernel_for_dispatch_key(op.name(), k)
        or k in op.py_kernels
    )

is_included_in_alias = torch._C._dispatch_is_included_in_alias

# Equivalent to computeDispatchTableEntryWithDebug
# TODO: memoize this or something
def resolve_key(op: torch._ops.PyOperatorABC, k: DispatchKey):  # type: ignore[valid-type]
    # 1. (Direct) operator registration
    if has_key(op, k):
        return k
    # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available
    cand = DispatchKey.CompositeExplicitAutogradNonFunctional
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(op, cand):
        return cand
    # 2.2 Use CompositeExplicitAutograd kernel if available
    cand = DispatchKey.CompositeExplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(op, cand):
        return cand
    has_backend_kernel = (
        torch._C._dispatch_has_kernel_for_any_dispatch_key(op.name(), torch._C._dispatch_get_backend_keyset_from_autograd(k))
        or has_key(op, DispatchKey.CompositeExplicitAutograd)
    )
    # 2.3. Use CompositeImplicitAutograd kernel if available
    cand = DispatchKey.CompositeImplicitAutogradNestedTensor
    if (
        (k != DispatchKey.Undefined and is_included_in_alias(k, cand))  # type: ignore[attr-defined]
        and has_key(op, cand) and not has_backend_kernel):
        return cand
    cand = DispatchKey.CompositeImplicitAutograd
    if (k == DispatchKey.Undefined or is_included_in_alias(k, cand)) and has_key(op, cand):
        if (
            k == DispatchKey.AutogradOther
            and torch._C._dispatch_has_kernel_for_any_dispatch_key(op.name(), torch._C._dispatch_autogradother_backends)  # type: ignore[attr-defined] # noqa: B950
        ):
            raise RuntimeError("ambiguous autogradother kernel")
        elif not has_backend_kernel:
            return cand
    # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available
    cand = DispatchKey.Autograd
    if is_included_in_alias(k, cand) and has_key(op, cand):
        return cand
    # Backend fallback
    if torch._C._dispatch_has_backend_fallback(k):
        # The dispatch key itself will implicitly route to backend fallback.
        # This is probably not great for the pure Python implementation.
        return k
    raise RuntimeError("could not find kernel")

@contextmanager
def no_python_dispatcher():
    g = torch._C._DisablePythonDispatcher()
    try:
        yield
    finally:
        del g

@contextmanager
def enable_python_dispatcher():
    g = torch._C._EnablePythonDispatcher()
    try:
        yield
    finally:
        del g

# The Python dispatcher
def python_dispatcher(op, ks, args, kwargs):
    """
    with no_python_dispatcher():
        print(op, ks, args, kwargs)
    """
    k = resolve_key(op, ks.highestPriorityTypeId())
    source = f'torch.ops.{op}.dispatch(k, *args, **kwargs)'
    filename = f'{op}[{torch._C._dispatch_key_name(k)}]'
    compiled = compile(source, filename, 'eval')  # TODO: maybe cache?
    return eval(compiled, {'torch': torch, 'k': k, 'args': args, 'kwargs': kwargs})

torch._C._set_python_dispatcher(python_dispatcher)
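An illustrative probe of resolve_key (not part of this commit); the concrete result assumes a standard CPU build:

    import torch
    from torch._C import DispatchKey
    from torch._dispatch.python import resolve_key

    # aten::add.Tensor has a direct CPU kernel, so step 1 returns the key itself:
    resolve_key(torch.ops.aten.add.Tensor, DispatchKey.CPU)  # DispatchKey.CPU
    # For an op with only a CompositeImplicitAutograd kernel, an autograd key
    # instead resolves to that alias kernel via step 2.3.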
torch/_ops.py
@@ -1,12 +1,19 @@
import contextlib
import ctypes
import inspect
import sys
import types
from abc import ABC
from typing import Any, Dict

import torch._C

import torch.jit
from torch import _utils_internal
from torch._C import DispatchKey  # type: ignore[attr-defined]
from torch.overrides import handle_torch_function, has_torch_function
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten

# Query `hasattr` only once.
_SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")

@@ -26,9 +33,129 @@ def dl_open_guard():
     sys.setdlopenflags(old_flags)


+# TODO(voz) We are missing an entire axis of registration - Modes for the python key
+class PyOperatorABC(ABC):
+    def __call__(self, *args, **kwargs):
+        pass
+
+    def py_impl(self, dispatch_key, fn):
+        pass
+
+    def name(self):
+        pass
+
+    def dispatch(self, dispatch_key, *args, **kwargs):
+        pass
+
+
+class PyOperator(PyOperatorABC):
+    def __init__(self, name):
+        self._name = name
+        self.table = {}
+        self.python_key_mode_table = {}
+
+        # Make _OPNamespace not scream, this whole name based association needs a good hard look
+        self.__name__ = "pyop." + name
+        pyop_namespace.py_ops[name] = self
+
+    def fallthrough(self, dispatch_key):
+        self.table[dispatch_key] = self._fallthrough_fn(self, dispatch_key)
+
+    def py_impl(self, dispatch_key_or_mode):
+        def inner(fn):
+            if inspect.isclass(dispatch_key_or_mode) and issubclass(
+                dispatch_key_or_mode, TorchDispatchMode
+            ):
+                mode = dispatch_key_or_mode
+                assert mode not in self.python_key_mode_table
+                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
+                self.python_key_mode_table[mode] = fn
+                return fn
+
+            dispatch_key = dispatch_key_or_mode
+            assert (
+                dispatch_key != torch._C.DispatchKey.Python  # type: ignore[attr-defined]
+            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
+            assert isinstance(dispatch_key, torch._C.DispatchKey)  # type: ignore[attr-defined]
+            assert dispatch_key not in self.table
+            self.table[dispatch_key] = fn
+            return fn
+
+        return inner
+
+    def dispatch(self, dispatch_key, *args, **kwargs):
+        if dispatch_key == torch._C.DispatchKey.Python:  # type: ignore[attr-defined]
+            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
+            curr_mode = type(torch._C._get_torch_dispatch_mode())
+            assert (
+                curr_mode is not None
+            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
+            assert (
+                curr_mode in self.python_key_mode_table
+            ), f"Current active mode {curr_mode} not registered"
+            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
+            return self.python_key_mode_table[curr_mode](*args, **kwargs)
+
+        assert dispatch_key in self.table
+        return self.table[dispatch_key](*args, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        flat_args = _to_flat_tuple(args, kwargs)
+        if has_torch_function(flat_args):
+            return handle_torch_function(self, flat_args, *args, **kwargs)
+
+        dispatch_key_set = _compute_keyset(args, kwargs)
+        return self.dispatch(dispatch_key_set.highestPriorityTypeId(), *args, **kwargs)
+
+    def name(self):
+        return self._name
+
+    # TODO(voz): Should rewrite fallthrough register as the impl for keys we do not specify
+    # as opposed to being this sort of explicit thing where ops are a little too key aware...
+    def _fallthrough_fn(self, operator, dispatch_key):
+        def inner(*args, **kwargs):
+            all_keys_after_current = torch._C._dispatch_keyset_full_after(dispatch_key)  # type: ignore[attr-defined]
+            all_keys_after_current_masked = all_keys_after_current & _compute_keyset(
+                args, kwargs
+            )  # type: ignore[attr-defined]
+            return self.dispatch(
+                all_keys_after_current_masked.highestPriorityTypeId(), *args, **kwargs
+            )
+
+        return inner
+
+
+def _to_flat_tuple(args, kwargs):
+    flat_args, _ = tree_flatten(args)
+    flat_kwargs, _ = tree_flatten(kwargs)
+    flat_all = flat_args + flat_kwargs
+    return flat_all
+
+
+def _compute_keyset(args, kwargs):
+    tensors = _get_tensors(args, kwargs)
+    return key_extractor(tensors)
+
+
+def _get_tensors(args, kwargs):
+    flat_all = _to_flat_tuple(args, kwargs)
+    tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)]
+    return tuple(tensor_args)
+
+
+# Note - this should maintain identical impl to the C++ dispatcher key extraction logic
+# at ATen/core/dispatch/DispatchKeyExtractor.h
+def key_extractor(tensors):
+    key_set = torch._C._dispatch_tls_local_include_set()  # type: ignore[attr-defined]
+    for tensor in tensors:
+        key_set = key_set | torch._C._dispatch_keys(tensor)  # type: ignore[attr-defined]
+    key_set = key_set - torch._C._dispatch_tls_local_exclude_set()  # type: ignore[attr-defined]
+    return key_set
+
+
 # Each OpOverload object contains a pointer to a specific operator overload and a pointer to the parent `OpOverloadPacket` object.
 # You can obtain an OpOverload object through attribute query on OpOverloadPacket.
-class OpOverload:
+class OpOverload(PyOperatorABC):
     def __init__(self, overloadpacket, op, op_dk, schema, tags):
         self._op = op
         self._op_dk = op_dk

@@ -38,12 +165,15 @@ class OpOverload(PyOperatorABC):
         self._overloadname = (
             "default" if schema.overload_name == "" else schema.overload_name
         )
-        self.name = self._schema.name
+        self._name = self._schema.name
         if schema.overload_name:
-            self.name += "." + schema.overload_name
+            self._name += "." + schema.overload_name
+        self.py_kernels: Dict[DispatchKey, Any] = {}
         self.__name__ = "{}.{}".format(
             self._schema.name.split("::")[1], self._overloadname
         )
+        # TODO(voz): Lots of shared logic around python_key_mode_table, maybe pull into base...
+        self.python_key_mode_table = {}
         self.__module__ = overloadpacket.__module__
         op.__module__ = overloadpacket.__module__

@@ -70,12 +200,56 @@ class OpOverload(PyOperatorABC):
         return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)

     def decompose(self, *args, **kwargs):
-        dk = "CompositeImplicitAutograd"
-        if torch._C._dispatch_has_kernel_for_dispatch_key(self.name, dk):
-            return self._op_dk(dk, *args, **kwargs)
+        dk = torch._C.DispatchKey.CompositeImplicitAutograd  # type: ignore[attr-defined]
+        if (
+            torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk)
+            or dk in self.py_kernels
+        ):
+            return self.dispatch(dk, *args, **kwargs)
         else:
             return NotImplemented

+    def py_impl(self, dispatch_key_or_mode):
+        def inner(fn):
+            if inspect.isclass(dispatch_key_or_mode) and issubclass(
+                dispatch_key_or_mode, TorchDispatchMode
+            ):
+                mode = dispatch_key_or_mode
+                assert mode not in self.python_key_mode_table
+                # TODO(voz): Should we replace setting torch._C.DispatchKey.Python entirely with setting mode keys?
+                self.python_key_mode_table[mode] = fn
+                return fn
+
+            assert isinstance(dispatch_key_or_mode, torch._C.DispatchKey)  # type: ignore[attr-defined]
+            assert (
+                dispatch_key_or_mode != torch._C.DispatchKey.Python  # type: ignore[attr-defined]
+            ), "Please register a mode for the torch._C.DispatchKey.Python key instead."
+
+            self.py_kernels[dispatch_key_or_mode] = fn
+            return fn
+
+        return inner
+
+    def dispatch(self, dispatch_key, *args, **kwargs):
+        if dispatch_key == torch._C.DispatchKey.Python:  # type: ignore[attr-defined]
+            # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now.
+            curr_mode = type(torch._C._get_torch_dispatch_mode())
+            assert (
+                curr_mode is not None
+            ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
+            if curr_mode not in self.python_key_mode_table:
+                return self._op_dk(dispatch_key, *args, **kwargs)
+            # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
+            return self.python_key_mode_table[curr_mode](*args, **kwargs)
+
+        if dispatch_key in self.py_kernels:
+            return self.py_kernels[dispatch_key](*args, **kwargs)
+        else:
+            return self._op_dk(dispatch_key, *args, **kwargs)
+
+    def name(self):
+        return self._name
+
     @property
     def overloadpacket(self):
         return self._overloadpacket

@@ -218,8 +392,15 @@ class _OpNamespace(types.ModuleType):
     def __init__(self, name):
         super(_OpNamespace, self).__init__("torch.ops." + name)
         self.name = name
+        if self.name == "pyop":
+            self.pyops = pyop_namespace
+        else:
+            self.pyops = None  # type: ignore[assignment]

     def __getattr__(self, op_name):
+        pyops = object.__getattribute__(self, "pyops")
+        if pyops is not None:
+            return pyops.py_ops[op_name]
         # It is not a valid op_name when __file__ is passed in
         if op_name == "__file__":
             return "torch.ops"

@@ -253,6 +434,15 @@ class _OpNamespace(types.ModuleType):
         return opoverloadpacket


+class _PyOpNamespace(_OpNamespace):
+    def __init__(self):
+        super(_PyOpNamespace, self).__init__("torch.ops.pyop")
+        self.py_ops = {}
+
+
+pyop_namespace = _PyOpNamespace()
+
+
 class _Ops(types.ModuleType):
     __file__ = "_ops.py"
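The key extraction above can be reproduced interactively with the bindings this PR exposes. An illustrative sketch (not part of the commit); the exact printed set depends on build and TLS state:

    import torch

    t = torch.randn(2, requires_grad=True)
    ks = torch._C._dispatch_tls_local_include_set() | torch._C._dispatch_keys(t)
    ks = ks - torch._C._dispatch_tls_local_exclude_set()
    print(ks)                          # e.g. DispatchKeySet(CPU, ADInplaceOrView, AutogradCPU, ...)
    print(ks.highestPriorityTypeId())  # the key the dispatcher would try first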
@@ -271,6 +271,7 @@ __all__ = [
 ]

 Tensor = torch.Tensor
+DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]


 def _broadcast_shapes(*_shapes):

@@ -2190,7 +2191,11 @@ def broadcast_shapes(*shapes) -> ShapeType:
     return torch.Size(_broadcast_shapes(*shapes))


+@torch.ops.aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
+@torch.ops.aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
 def broadcast_tensors(*tensors) -> List[TensorLikeType]:
     if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
         tensors = tensors[0]
     return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
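These decorators attach the Python reference as a kernel on the existing ATen overload. As a sketch of the intended effect (not part of this commit, and assuming torch._refs has been imported so the decorators have run), enabling the Python dispatcher should route the op to this Python kernel:

    import torch
    import torch._refs  # ensure the py_impl registrations above have executed
    from torch._dispatch.python import enable_python_dispatcher

    x, y = torch.randn(3, 1), torch.randn(1, 4)
    with enable_python_dispatcher():
        # resolve_key lands on CompositeImplicitAutograd, whose kernel is now
        # the Python _refs.broadcast_tensors rather than the C++ composite one.
        a, b = torch.ops.aten.broadcast_tensors.default(x, y)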
@@ -7,6 +7,7 @@
 #include <c10/core/DeviceType.h>
 #include <c10/core/InferenceMode.h>
 #include <c10/core/ScalarType.h>
+#include <c10/core/impl/PythonDispatcherTLS.h>
 #include <torch/csrc/Exceptions.h>
 #include <torch/csrc/autograd/autograd.h>
 #include <torch/csrc/autograd/function.h>

@@ -52,6 +53,19 @@ struct EnableTorchFunction {
   bool old_;
 };

+PyObject* globalPythonDispatcher = nullptr;
+
+struct EnablePythonDispatcher {
+  EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) {
+    c10::impl::PythonDispatcherTLS::set_state(
+        {globalPythonDispatcher, getPyInterpreter()});
+  }
+  ~EnablePythonDispatcher() {
+    c10::impl::PythonDispatcherTLS::set_state(old_);
+  }
+  c10::SafePyHandle old_;
+};
+
 } // namespace

 PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {

@@ -328,6 +342,18 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
       .def(py::init<>());
   py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction")
       .def(py::init<>());
+  py::class_<EnablePythonDispatcher>(_C_m, "_EnablePythonDispatcher")
+      .def(py::init<>());
+  py::class_<c10::impl::DisablePythonDispatcher>(
+      _C_m, "_DisablePythonDispatcher")
+      .def(py::init<>());
+  _C_m.def("_set_python_dispatcher", [](py::object dispatcher) {
+    TORCH_CHECK(
+        !globalPythonDispatcher,
+        "overwriting the global python dispatcher is not supported; if you need this, file an issue");
+    // NB: intentionally leak
+    globalPythonDispatcher = dispatcher.release().ptr();
+  });
   py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>());

   py::class_<torch::autograd::SavedVariable>(m, "SavedTensor")
@@ -3,6 +3,7 @@
 #include <c10/core/DeviceType.h>
 #include <c10/core/SafePyObject.h>
 #include <c10/core/impl/GPUTrace.h>
+#include <c10/core/impl/PythonDispatcherTLS.h>
 #include <c10/util/DeadlockDetection.h>
 #include <c10/util/irange.h>
 #include <pybind11/pytypes.h>

@@ -235,6 +236,10 @@ struct ConcretePyInterpreterVTable final

   void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
       const override;
+  void python_dispatcher(
+      const c10::OperatorHandle& op,
+      c10::DispatchKeySet,
+      torch::jit::Stack* stack) const override;

   bool is_contiguous(const TensorImpl* self) const override;
   c10::Device device(const TensorImpl* self) const override;

@@ -701,6 +706,7 @@ static PyObject* THPVariable_make_subclass(
         "cls must be a type (got %s)", Py_TYPE(cls)->tp_name);
   }
   torch_dispatch_mode::StashTorchDispatchModeGuard td_g;
+  c10::impl::DisablePythonDispatcher dpd_g;
   auto data =
       r.tensor(1).detach(); // creates a fresh Tensor (DEFINITELY_UNINITIALIZED)
   // We set `data`'s `allow_tensor_metadata_change` to true here, because we

@@ -2265,6 +2271,67 @@ void ConcretePyInterpreterVTable::dispatch(
       op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__");
 }

+void ConcretePyInterpreterVTable::python_dispatcher(
+    const c10::OperatorHandle& op,
+    c10::DispatchKeySet ks,
+    torch::jit::Stack* stack) const {
+  const auto& schema = op.schema();
+  const auto num_arguments = schema.arguments().size();
+  auto arguments = torch::jit::pop(*stack, num_arguments);
+
+  // Parse the name into namespace and name (no overload_name)
+  // TODO: put this into the library
+  const auto& qualified_name = op.operator_name().name;
+  const auto& overload_name = schema.overload_name();
+  auto pos = qualified_name.find("::");
+  TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name);
+  // Make me some null terminated strings
+  std::string ns_str = qualified_name.substr(0, pos);
+  const char* ns = ns_str.c_str();
+  const char* func_name = qualified_name.c_str() + pos + strlen("::");
+
+  // The plan: convert all the arguments back into PyObjects,
+  // extracting out the tensor handles, then call
+  // handle_torch_function_no_python_arg_parser
+  // NB: at the point arguments are pushed to the stack, ALL defaults
+  // are already present
+
+  py::gil_scoped_acquire g;
+
+  std::vector<py::handle> overloaded_args;
+  py::handle torch_api_function =
+      py::module::import("torch").attr("ops").attr(ns).attr(func_name);
+  py::handle torch_api_function_overload;
+  if (overload_name == "") {
+    torch_api_function_overload = torch_api_function.attr("default");
+  } else {
+    torch_api_function_overload =
+        torch_api_function.attr(overload_name.c_str());
+  }
+  std::string module_name_str = "torch.ops." + ns_str;
+
+  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
+  auto args = std::move(args_kwargs.first);
+  auto kwargs = std::move(args_kwargs.second);
+
+  auto python_dispatcher =
+      c10::impl::PythonDispatcherTLS::get_state().ptr(getPyInterpreter());
+  TORCH_INTERNAL_ASSERT(python_dispatcher);
+
+  py::object obj = py::reinterpret_steal<py::object>(PyObject_CallFunction(
+      python_dispatcher,
+      "OOOO",
+      torch_api_function_overload,
+      py::cast(ks).ptr(),
+      args.ptr(),
+      kwargs.ptr()));
+
+  if (obj == nullptr)
+    throw python_error();
+
+  pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher");
+}
+
 c10::intrusive_ptr<TensorImpl> ConcretePyInterpreterVTable::detach(
     const c10::TensorImpl* self) const {
   pybind11::gil_scoped_acquire gil;

@@ -2377,7 +2444,7 @@ c10::IntArrayRef ConcretePyInterpreterVTable::strides(
   if (out == Py_None) {
     TORCH_CHECK(
         !self->has_symbolic_sizes_strides(),
-        "Cannot call sizes on a tensor with symbolic shapes/strides");
+        "Cannot call strides on a tensor with symbolic shapes/strides");
     return self->strides_default();
   }
@@ -1713,13 +1713,11 @@ void initJITBindings(PyObject* module) {
         return _get_operation_for_overload_or_packet(
             {op}, symbol, args, kwargs, /*is_overload*/ true);
       });
-  auto func_dk =
-      py::cpp_function([op, symbol, allow_numbers_as_tensors](
-                           const std::string& str_dk,
-                           py::args args,
-                           py::kwargs kwargs) {
+  auto func_dk = py::cpp_function(
+      [op, symbol, allow_numbers_as_tensors](
+          c10::DispatchKey dk_, py::args args, py::kwargs kwargs) {
         c10::optional<c10::DispatchKey> dk =
-            c10::make_optional(c10::parseDispatchKey(str_dk));
+            c10::make_optional(dk_);
         ToIValueAllowNumbersAsTensors g(allow_numbers_as_tensors);
         return _get_operation_for_overload_or_packet(
             {op}, symbol, args, kwargs, /*is_overload*/ true, dk);
@@ -136,6 +136,33 @@ struct type_caster<at::Device> {
   }
 };

+template <>
+struct type_caster<c10::DispatchKey>
+    : public type_caster_base<c10::DispatchKey> {
+  using base = type_caster_base<c10::DispatchKey>;
+  c10::DispatchKey tmp;
+
+ public:
+  bool load(handle src, bool convert) {
+    if (base::load(src, convert)) {
+      return true;
+    } else if (py::isinstance(
+                   src, py::module_::import("builtins").attr("str"))) {
+      tmp = c10::parseDispatchKey(py::cast<std::string>(src));
+      value = &tmp;
+      return true;
+    }
+    return false;
+  }
+
+  static handle cast(
+      c10::DispatchKey src,
+      return_value_policy policy,
+      handle parent) {
+    return base::cast(src, policy, parent);
+  }
+};
+
 // Pybind11 bindings for our optional and variant types.
 // http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers
 template <typename T>
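With this caster, any binding that takes a c10::DispatchKey accepts either the bound enum or its string name. An illustrative sketch (not part of this commit):

    import torch

    # Both spellings hit the same binding; the caster parses the string form
    # via c10::parseDispatchKey.
    torch._C._dispatch_has_kernel_for_dispatch_key("aten::add.Tensor", torch._C.DispatchKey.CPU)
    torch._C._dispatch_has_kernel_for_dispatch_key("aten::add.Tensor", "CPU")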
@@ -92,6 +92,7 @@ bool should_allow_numbers_as_tensors(const std::string& name) {
       "sub", "sub_", "sub_out",
       "subtract", "subtract_", "subtract_out", // alias of sub
       "true_divide", "true_divide_", "true_divide_out",
+      "to", "_to_copy", "copy_",
       "floor_divide", "floor_divide_", "floor_divide_out"};
   return allowed.find(name) != allowed.end();
 }
@@ -295,11 +295,20 @@ void initDispatchBindings(PyObject* module) {
       // Returns whether or not a direct kernel registration exists
       // for this <op_name, dispatch_key> pair.
       "_dispatch_has_kernel_for_dispatch_key",
-      [](const char* name, const char* dispatch) -> bool {
+      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
-       return op->hasKernelForDispatchKey(c10::parseDispatchKey(dispatch));
+       return op->hasKernelForDispatchKey(dispatch);
      });

+  m.def(
+      "_dispatch_has_kernel_for_any_dispatch_key",
+      [](const char* name, c10::DispatchKeySet ks) -> bool {
+        auto op =
+            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
+        TORCH_CHECK(op, "operator ", name, " does not exist");
+        return op->hasKernelForAnyDispatchKey(ks);
+      });
+
   m.def(

@@ -332,38 +341,48 @@ void initDispatchBindings(PyObject* module) {

   m.def(
       "_dispatch_tls_set_dispatch_key_excluded",
-      [](const char* dispatch_key, bool desired_state) {
-        c10::impl::tls_set_dispatch_key_excluded(
-            c10::parseDispatchKey(dispatch_key), desired_state);
+      [](c10::DispatchKey dispatch_key, bool desired_state) {
+        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
       });
-  m.def("_dispatch_tls_is_dispatch_key_excluded", [](const char* dispatch_key) {
-    return c10::impl::tls_is_dispatch_key_excluded(
-        c10::parseDispatchKey(dispatch_key));
-  });
+  m.def(
+      "_dispatch_tls_is_dispatch_key_excluded",
+      [](c10::DispatchKey dispatch_key) {
+        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
+      });

   m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
     return at::isTensorSubclassLike(tensor);
   });

-  m.def("_dispatch_key_name", [](uint64_t dispatch_key) {
-    auto dt = (c10::DispatchKey)dispatch_key;
-    return c10::toString(dt);
+  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
+    return c10::toString(k);
   });
+  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
+  m.def("_dispatch_num_backends", []() { return c10::num_backends; });

-  py::enum_<c10::DispatchKey>(m, "DispatchKey")
-      .value("Undefined", c10::DispatchKey::Undefined)
-      .value("Dense", c10::DispatchKey::Dense)
-      .value("BackendSelect", c10::DispatchKey::BackendSelect)
-      .value("CPU", c10::DispatchKey::CPU)
-      .value("CUDA", c10::DispatchKey::CUDA)
-      .value("AutocastCPU", c10::DispatchKey::AutocastCPU)
-      .value("AutocastCUDA", c10::DispatchKey::AutocastCUDA)
-      .value("AutogradCPU", c10::DispatchKey::AutogradCPU)
-      .value("ADInplaceOrView", c10::DispatchKey::ADInplaceOrView)
-      .value("AutogradCUDA", c10::DispatchKey::AutogradCUDA)
-      .value("PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot)
-      .value("Python", c10::DispatchKey::Python);
+#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)
+
+  py::enum_<c10::DispatchKey>(m, "DispatchKey") DEF_ONE(Undefined)
+      DEF_ONE(CompositeExplicitAutogradNonFunctional)
+      DEF_ONE(CompositeExplicitAutograd)
+      DEF_ONE(CompositeImplicitAutogradNestedTensor)
+      DEF_ONE(CompositeImplicitAutograd) DEF_ONE(AutogradOther)
+      DEF_ONE(Autograd) DEF_ONE(BackendSelect)
+      DEF_ONE(ADInplaceOrView) DEF_ONE(PythonTLSSnapshot)
+      DEF_ONE(Python)
+
+#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
+#define DEF_MULTIPLE(fullname, prefix)              \
+  DEF_SINGLE(, fullname)                            \
+  DEF_SINGLE(, StartOf##fullname##Backends)         \
+  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
+  DEF_SINGLE(, EndOf##fullname##Backends)
+
+      C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
+
+#undef DEF_MULTIPLE
+#undef DEF_SINGLE
+      ;

   py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
       .def(py::init<c10::DispatchKey>())

@@ -371,7 +390,15 @@ void initDispatchBindings(PyObject* module) {
       .def("__sub__", &c10::DispatchKeySet::operator-)
       .def("__and__", &c10::DispatchKeySet::operator&)
       .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
-      .def("has", &c10::DispatchKeySet::has);
+      .def("has", &c10::DispatchKeySet::has)
+      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });

+  m.attr("_dispatch_autogradother_backends") =
+      py::cast(c10::autogradother_backends);
+
+  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
+    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
+  });
+
   m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
     return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);

@@ -381,6 +408,10 @@ void initDispatchBindings(PyObject* module) {
     return c10::toString(keyset);
   });

+  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
+    return c10::getBackendKeySetFromAutograd(k);
+  });
+
   m.def("_dispatch_keys", [](const at::Tensor& tensor) {
     auto* impl = tensor.unsafeGetTensorImpl();
     return impl->key_set();

@@ -391,6 +422,11 @@ void initDispatchBindings(PyObject* module) {
   m.def("_dispatch_tls_local_exclude_set", []() {
     return c10::impl::tls_local_dispatch_key_set().excluded_;
   });
+  m.def(
+      "_dispatch_is_included_in_alias",
+      [](c10::DispatchKey a, c10::DispatchKey b) {
+        return c10::isIncludedInAlias(a, b);
+      });
   py::class_<c10::impl::ExcludeDispatchKeyGuard>(m, "ExcludeDispatchKeyGuard")
       .def(py::init<c10::DispatchKeySet>());
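Together these bindings give the Python-side resolve_key everything the C++ table computation uses. An illustrative session (not part of this commit); results in comments assume a standard build:

    import torch
    DispatchKey = torch._C.DispatchKey

    torch._C._dispatch_is_included_in_alias(DispatchKey.AutogradCPU, DispatchKey.Autograd)  # True
    torch._C._dispatch_get_backend_keyset_from_autograd(DispatchKey.AutogradCPU)  # key set containing CPU
    torch._C._dispatch_keyset_full_after(DispatchKey.AutogradCPU)  # everything after AutogradCPU, for fallthroughs
    torch._C._dispatch_has_backend_fallback(DispatchKey.PythonTLSSnapshot)  # True: boxed fallback registered above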
@@ -10,6 +10,7 @@ import torch
 import torch.utils._pytree as pytree
 from torch.fx import Tracer, GraphModule
 from torch._subclasses.fake_tensor import FakeTensorMode
+from torch._dispatch.python import enable_python_dispatcher
 import torch.fx as fx
 from torch.fx.passes.shape_prop import _extract_tensor_metadata
 from contextlib import contextmanager, nullcontext

@@ -579,6 +580,10 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):
     else:
         raise AssertionError(f"Unexpected tracing type: {tracing_mode}")

+    python_dispatcher_mode: Any = nullcontext()
+    if tracing_mode == "symbolic":
+        python_dispatcher_mode = enable_python_dispatcher()
+
     proxy_mode = ProxyTorchDispatchMode(fx_tracer)

     def wrap_fake_concrete(x):

@@ -617,7 +622,7 @@ def make_fx(f, decomposition_table=None, tracing_mode="real"):

     # We disable the autocast cache as the autocast cache causes type conversions on parameters to
     # check a cache, which introduces untracked tensors into the graph
-    with decompose(decomposition_table), fake_tensor_mode, \
+    with decompose(decomposition_table), fake_tensor_mode, python_dispatcher_mode, \
             sym_mode, proxy_mode, disable_autocast_cache():  # type: ignore[attr-defined]
         t = dispatch_trace(wrap_key(func, args, fx_tracer), tracer=fx_tracer, concrete_args=tuple(phs))
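With this wiring, symbolic tracing runs under the Python dispatcher, so Python-level decompositions (like the broadcast_tensors ref above) participate in the trace. An illustrative sketch (not part of this commit); symbolic-shape coverage at this commit is still partial, per the xfail lists above:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(x, y):
        a, b = torch.broadcast_tensors(x, y)
        return a + b

    # tracing_mode="symbolic" wraps the trace in enable_python_dispatcher()
    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 1), torch.randn(1, 4))
    print(gm.graph)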