mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 14:34:54 +08:00
Add a default implementation of __torch_dispatch__
I was working on an explanation of how to call into the "super" implementation of some given ATen operation inside of __torch_dispatch__ (https://github.com/albanD/subclass_zoo/blob/main/trivial_tensors.py) and I kept thinking to myself "Why doesn't just calling super() on __torch_dispatch__ work"? Well, after this patch, it does! The idea is if you don't actually unwrap the input tensors, you can call super().__torch_dispatch__ to get at the original behavior. Internally, this is implemented by disabling PythonKey and then redispatching. This implementation of disabled_torch_dispatch is not /quite/ right, and some reasons why are commented in the code. There is then some extra work I have to do to make sure we recognize disabled_torch_dispatch as the "default" implementation (so we don't start slapping PythonKey on all tensors, including base Tensors), which is modeled the same way as how disabled_torch_function is done. Signed-off-by: Edward Z. Yang <ezyangfb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/73684 Approved by: albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
7c0f166b26
commit
35cfa74f97
@ -6,6 +6,7 @@
|
||||
namespace torch {
|
||||
static thread_local bool enable_torch_function = true;
|
||||
PyObject* disabled_torch_function = nullptr;
|
||||
PyObject* disabled_torch_dispatch = nullptr;
|
||||
|
||||
bool torch_function_enabled() {
|
||||
return enable_torch_function;
|
||||
@ -18,6 +19,14 @@ namespace torch {
|
||||
void set_disabled_torch_function_impl(PyObject* value) {
|
||||
disabled_torch_function = value;
|
||||
}
|
||||
|
||||
PyObject* disabled_torch_dispatch_impl() {
|
||||
return disabled_torch_dispatch;
|
||||
}
|
||||
|
||||
void set_disabled_torch_dispatch_impl(PyObject* value) {
|
||||
disabled_torch_dispatch = value;
|
||||
}
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
@ -127,6 +136,43 @@ PyObject* THPModule_disable_torch_function(PyObject *self, PyObject *a) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
PyObject* THPModule_disable_torch_dispatch(PyObject *self, PyObject *a) {
|
||||
HANDLE_TH_ERRORS
|
||||
PyObject *func=nullptr, *types=nullptr, *args=nullptr, *kwargs=nullptr;
|
||||
if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) {
|
||||
return nullptr;
|
||||
}
|
||||
py::tuple py_args;
|
||||
if (args == nullptr) {
|
||||
py_args = py::make_tuple();
|
||||
}
|
||||
else {
|
||||
py_args = py::reinterpret_borrow<py::tuple>(args);
|
||||
}
|
||||
|
||||
// This implementation is not completely correct. The moral
|
||||
// meaning of this function is that we should do a redispatch
|
||||
// "after" PythonKey, aka a redispatch() call. But we don't have a
|
||||
// dispatcher call here; we have an opaque Python object.
|
||||
//
|
||||
// What we have here is a close approximation: instead of redispatch(), we
|
||||
// just exclude Python and all the keys before it, so that we will go
|
||||
// to the next key after Python. The difference, however, is we are
|
||||
// now PERMANENTLY after Python. We don't think there are any legitimate
|
||||
// cases where we want to go for another round on the entire dispatcher key
|
||||
// set, but if there are, then we will have to do something else here.
|
||||
c10::impl::ExcludeDispatchKeyGuard guard_(
|
||||
// TODO: add constructor for this specifically
|
||||
c10::DispatchKeySet(c10::DispatchKeySet::FULL) -
|
||||
c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Python)
|
||||
// NB: off by one hazard here, but it works out: python key is not
|
||||
// included in AFTER, so it is included in the negation (and that's
|
||||
// correct: we want to exclude Python key and everything BEFORE it.)
|
||||
);
|
||||
return PyObject_Call(func, py_args.ptr(), kwargs);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// Makes sure that we don't check for __torch_function__ on basic Python types
|
||||
static bool is_basic_python_type(PyTypeObject *tp)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user