#include #include #include #include #include #include namespace torch { PyObject* disabled_torch_function = nullptr; PyObject* disabled_torch_dispatch = nullptr; bool torch_function_enabled() { return at::impl::PythonTorchFunctionTLS::get_disabled_state() == at::impl::TorchFunctionDisabledState::ENABLED; } PyObject* disabled_torch_function_impl() { return disabled_torch_function; } void set_disabled_torch_function_impl(PyObject* value) { disabled_torch_function = value; } PyObject* disabled_torch_dispatch_impl() { return disabled_torch_dispatch; } void set_disabled_torch_dispatch_impl(PyObject* value) { disabled_torch_dispatch = value; } } // namespace torch typedef struct { PyObject_HEAD /* Type-specific fields go here. */ at::impl::TorchFunctionDisabledState old_state; } DisableTorchFunctionSubclass; PyObject* DisableTorchFunctionSubclass__enter( PyObject* self, PyObject* unused) { const auto old_state = at::impl::PythonTorchFunctionTLS::get_disabled_state(); ((DisableTorchFunctionSubclass*)self)->old_state = old_state; if (old_state == at::impl::TorchFunctionDisabledState::ENABLED) { at::impl::PythonTorchFunctionTLS::set_disabled_state( at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED); } Py_RETURN_NONE; } PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) { at::impl::PythonTorchFunctionTLS::set_disabled_state( ((DisableTorchFunctionSubclass*)self)->old_state); Py_RETURN_NONE; } PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) { if (torch::torch_function_enabled()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } } PyObject* THPModule_isAllDisabledTorchFunction( PyObject* self, PyObject* unused) { if (at::impl::torch_function_all_disabled()) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; } } static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT {"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr}, {"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; PyTypeObject DisableTorchFunctionSubclassType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */ sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */ 0, /* tp_itemsize */ nullptr, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ nullptr, /* tp_repr */ nullptr, /* tp_as_number */ nullptr, /* tp_as_sequence */ nullptr, /* tp_as_mapping */ nullptr, /* tp_hash */ nullptr, /* tp_call */ nullptr, /* tp_str */ nullptr, /* tp_getattro */ nullptr, /* tp_setattro */ nullptr, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ nullptr, /* tp_doc */ nullptr, /* tp_traverse */ nullptr, /* tp_clear */ nullptr, /* tp_richcompare */ 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ DisableTorchFunctionSubclass_methods, /* tp_methods */ nullptr, /* tp_members */ nullptr, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ nullptr, /* tp_descr_get */ nullptr, /* tp_descr_set */ 0, /* tp_dictoffset */ nullptr, /* tp_init */ PyType_GenericAlloc, /* tp_alloc */ PyType_GenericNew, /* tp_new */ }; PyObject* THPModule_DisableTorchFunctionSubclassType() { if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) { return nullptr; } return (PyObject*)(&DisableTorchFunctionSubclassType); } typedef struct { PyObject_HEAD /* Type-specific fields go here. */ at::impl::TorchFunctionDisabledState old_state; } DisableTorchFunction; PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) { ((DisableTorchFunctionSubclass*)self)->old_state = at::impl::PythonTorchFunctionTLS::get_disabled_state(); at::impl::PythonTorchFunctionTLS::set_disabled_state( at::impl::TorchFunctionDisabledState::ALL_DISABLED); Py_RETURN_NONE; } PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) { at::impl::PythonTorchFunctionTLS::set_disabled_state( ((DisableTorchFunctionSubclass*)self)->old_state); Py_RETURN_NONE; } static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr}, {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; PyTypeObject DisableTorchFunctionType = { PyVarObject_HEAD_INIT(nullptr, 0) "torch._C.DisableTorchFunction", /* tp_name */ sizeof(DisableTorchFunction), /* tp_basicsize */ 0, /* tp_itemsize */ nullptr, /* tp_dealloc */ 0, /* tp_vectorcall_offset */ nullptr, /* tp_getattr */ nullptr, /* tp_setattr */ nullptr, /* tp_reserved */ nullptr, /* tp_repr */ nullptr, /* tp_as_number */ nullptr, /* tp_as_sequence */ nullptr, /* tp_as_mapping */ nullptr, /* tp_hash */ nullptr, /* tp_call */ nullptr, /* tp_str */ nullptr, /* tp_getattro */ nullptr, /* tp_setattro */ nullptr, /* tp_as_buffer */ Py_TPFLAGS_DEFAULT, /* tp_flags */ nullptr, /* tp_doc */ nullptr, /* tp_traverse */ nullptr, /* tp_clear */ nullptr, /* tp_richcompare */ 0, /* tp_weaklistoffset */ nullptr, /* tp_iter */ nullptr, /* tp_iternext */ DisableTorchFunction_methods, /* tp_methods */ nullptr, /* tp_members */ nullptr, /* tp_getset */ nullptr, /* tp_base */ nullptr, /* tp_dict */ nullptr, /* tp_descr_get */ nullptr, /* tp_descr_set */ 0, /* tp_dictoffset */ nullptr, /* tp_init */ PyType_GenericAlloc, /* tp_alloc */ PyType_GenericNew, /* tp_new */ }; PyObject* THPModule_DisableTorchFunctionType() { if (PyType_Ready(&DisableTorchFunctionType) < 0) { return nullptr; } return (PyObject*)(&DisableTorchFunctionType); } PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) { HANDLE_TH_ERRORS PyObject *func = nullptr, *types = nullptr, *args = nullptr, *kwargs = nullptr; if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) { return nullptr; } py::tuple py_args; if (args == nullptr) { py_args = py::make_tuple(); } else if (PyList_Check(args)) { py_args = py::reinterpret_steal(PyList_AsTuple(args)); } else if (PyTuple_Check(args)) { py_args = py::reinterpret_borrow(args); } else { throw torch::TypeError( "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name); } // These are all C-API calls so no exceptions will be raised // and therefore no need for RAII approach to storing // the old value. auto old_value = at::impl::PythonTorchFunctionTLS::get_disabled_state(); if (old_value == at::impl::TorchFunctionDisabledState::ENABLED) { at::impl::PythonTorchFunctionTLS::set_disabled_state( at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED); } // kwargs can safely be nullptr here. PyObject* result = PyObject_Call(func, py_args.ptr(), kwargs); at::impl::PythonTorchFunctionTLS::set_disabled_state(old_value); return result; END_HANDLE_TH_ERRORS } PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* a) { HANDLE_TH_ERRORS PyObject *func = nullptr, *types = nullptr, *args = nullptr, *kwargs = nullptr; if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) { return nullptr; } py::tuple py_args; if (args == nullptr) { py_args = py::make_tuple(); } else if (PyList_Check(args)) { py_args = py::reinterpret_steal(PyList_AsTuple(args)); } else if (PyTuple_Check(args)) { py_args = py::reinterpret_borrow(args); } else { throw torch::TypeError( "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name); } // This implementation is not completely correct. The moral // meaning of this function is that we should do a redispatch // "after" PythonKey, aka a redispatch() call. But we don't have a // dispatcher call here; we have an opaque Python object. // // What we have here is a close approximation: instead of redispatch(), we // just exclude Python and all the keys before it, so that we will go // to the next key after Python. The difference, however, is we are // now PERMANENTLY after Python. We don't think there are any legitimate // cases where we want to go for another round on the entire dispatcher key // set, but if there are, then we will have to do something else here. c10::impl::ExcludeDispatchKeyGuard guard_( // TODO: add constructor for this specifically c10::DispatchKeySet(c10::DispatchKeySet::FULL) - c10::DispatchKeySet( c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Python) // NB: off by one hazard here, but it works out: python key is not // included in AFTER, so it is included in the negation (and that's // correct: we want to exclude Python key and everything BEFORE it.) ); auto r = PyObject_Call(func, py_args.ptr(), kwargs); if (r == nullptr) throw python_error(); return r; END_HANDLE_TH_ERRORS } // Makes sure that we don't check for __torch_function__ on basic Python types static bool is_basic_python_type(PyTypeObject* tp) { return ( /* Basic number types */ tp == &PyBool_Type || tp == &PyLong_Type || tp == &PyFloat_Type || tp == &PyComplex_Type || /* Basic sequence types */ tp == &PyList_Type || tp == &PyTuple_Type || tp == &PyDict_Type || tp == &PySet_Type || tp == &PyFrozenSet_Type || tp == &PyUnicode_Type || tp == &PyBytes_Type || /* other builtins */ tp == &PySlice_Type || tp == Py_TYPE(Py_None) || tp == Py_TYPE(Py_Ellipsis) || tp == Py_TYPE(Py_NotImplemented) || PyModule_Check(tp) || /* sentinel to swallow trailing || */ false); } inline bool has_torch_function_attr(PyObject* obj) { auto attr = PyObject_FastGetAttrString(obj, "__torch_function__"); return ( attr.ptr() != nullptr && attr.ptr() != torch::disabled_torch_function); } namespace torch { auto check_has_torch_function(PyObject* obj, bool ignore_mode) -> bool { if (!ignore_mode && at::impl::torch_function_mode_enabled()) return true; PyTypeObject* tp = Py_TYPE(obj); return ( !THPVariable_CheckTypeExact(tp) && !is_basic_python_type(tp) && torch::torch_function_enabled() && has_torch_function_attr(obj)); } } // namespace torch inline bool sequence_has_torch_function(PyObject* args) { // NOLINTNEXTLINE(bugprone-branch-clone) Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args); for (Py_ssize_t i = 0; i < nargs; i++) { PyObject* obj = PySequence_Fast_GET_ITEM(args, i); if (torch::check_has_torch_function(obj)) { return true; } } return false; } inline bool array_has_torch_function(PyObject* const* args, Py_ssize_t nargs) { for (Py_ssize_t i = 0; i < nargs; i++) { if (torch::check_has_torch_function(args[i])) { return true; } } return false; } PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg) { bool result; // NOLINT(cppcoreguidelines-init-variables) if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) { // Fast path: // If we know that we have a tuple or list, we can skip an INCREF and // DECREF from PySequence_Fast. Core functions will always follow this // convention (almost always tuples), and it shaves ~3.5% off the cost of // the check. result = sequence_has_torch_function(arg); } else { auto args = py::reinterpret_steal( PySequence_Fast(arg, "expected a sequence")); if (!args) { return nullptr; } result = sequence_has_torch_function(args.ptr()); } if (result) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject* obj) { // Special case `THPModule_has_torch_function` for the single arg case. if (torch::check_has_torch_function(obj)) { Py_RETURN_TRUE; } Py_RETURN_FALSE; } PyObject* THPModule_has_torch_function_variadic( PyObject*, PyObject* const* args, Py_ssize_t nargs) { if (array_has_torch_function(args, nargs)) { Py_RETURN_TRUE; } Py_RETURN_FALSE; }