#include <torch/csrc/utils/python_dispatch.h>

#include <c10/core/DeviceType.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/LocalDispatchKeySet.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <c10/util/flat_hash_map.h>
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/library.h>

#include <ATen/ATen.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <ATen/functorch/TensorWrapper.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>

#include <iostream>
#include <sstream>
#include <utility>

namespace py = pybind11;

namespace torch {
namespace impl {
namespace dispatch {

// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_;

static torch::Library::Kind parseKind(const std::string& k) {
  static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
      {"DEF", torch::Library::DEF},
      {"IMPL", torch::Library::IMPL},
      {"FRAGMENT", torch::Library::FRAGMENT},
  };
  auto it = kind_map.find(k);
  TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
  return it->second;
}

static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
      {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
      {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
      {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
      {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
  };
  auto it = key_map.find(k);
  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
  return it->second;
}

template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  auto mb_key = std::string(key).empty()
      ? c10::nullopt
      : c10::make_optional(c10::parseDispatchKey(key));
  if (mb_key) {
    return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
  } else {
    torch::CppFunction f(std::forward<Func>(raw_f));
    return f;
  }
}
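// For illustration: dispatch_str("CPU", fn) is equivalent to
// torch::dispatch(c10::DispatchKey::CPU, fn), while dispatch_str("", fn)
// simply wraps fn in a CppFunction with no dispatch key attached.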
struct EnableHermeticPyObject {
  EnableHermeticPyObject()
      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
        old_excluded_python_(
            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
        old_python_(
            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
            at::DispatchKey::PythonTLSSnapshot)) {
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }
  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }
  bool old_;
  bool old_excluded_python_;
  bool old_python_;
  bool old_python_snapshot_;
};

class PythonKernelHolder : public c10::OperatorKernel {
  c10::SafePyObject func_;
  c10::DispatchKey dispatch_key_;

 public:
  PythonKernelHolder(py::object func, c10::DispatchKey dispatch_key)
      : func_(func.release().ptr(), getPyInterpreter()),
        dispatch_key_(dispatch_key) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // Figure out if we can handle it hermetically, or if we have
    // to double dispatch

    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    if (mode_stack_len > 0) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      cur_torch_dispatch_mode_state->pyinterpreter()
          ->python_op_registration_trampoline(op, dispatch_key_, stack);
      return;
    }

    const auto& schema = op.schema();
    const auto num_arguments = schema.arguments().size();

    // Otherwise, find a PyInterpreter on a Tensor if it has the Python key
    // (which means it's a nontrivial tensor subclass)
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
      if (ivalue.isTensor()) {
        auto* interpreter =
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
        if (interpreter &&
            ivalue.unsafeToTensorImpl()->key_set().has(
                at::DispatchKey::Python)) {
          (*interpreter)
              ->python_op_registration_trampoline(op, dispatch_key_, stack);
          return;
        }
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          if (nv.isNone()) {
            continue;
          }
          auto* interpreter =
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
          if (interpreter &&
              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
            (*interpreter)
                ->python_op_registration_trampoline(op, dispatch_key_, stack);
            return;
          }
        }
      }
    }

    // Nothing requires the operator to be homed to a specific interpreter, so
    // run it on the current interpreter
    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
    py::gil_scoped_acquire g;
    // Jan 2024: We're slated to get rid of multipy, so stop forcing hermetic
    // mode unconditionally in all situations when you're using multipy.
    // Eventually just delete this entirely.  (Note that you may break multipy
    // anyway this way with dispatcher registered functions that require
    // hermetic to be off.)
#if defined(USE_DEPLOY)
    EnableHermeticPyObject g2;
#endif
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
    auto obj = py::reinterpret_steal<py::object>(PyObject_Call(
        func_.ptr(getPyInterpreter()),
        args_kwargs.first.ptr(),
        args_kwargs.second.ptr()));
    if (!obj) {
      throw python_error();
    }
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
  }
};

static torch::_RegisterOrVerify register_or_verify() {
  if (isMainPyInterpreter()) {
    return torch::_RegisterOrVerify::REGISTER;
  } else {
    return torch::_RegisterOrVerify::VERIFY;
  }
}

static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    py::args args,
    const py::kwargs& kwargs) {
  auto stack = torch::jit::createStackForSchema(
      handle.schema(), std::move(args), kwargs, /*self=*/c10::nullopt);
  {
    pybind11::gil_scoped_release no_gil_guard;
    handle.callBoxed(stack);
  }
  return torch::jit::createPyObjectForStack(std::move(stack));
}
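// Rough usage sketch from the Python side (illustrative only):
//   op = torch._C._dispatch_find_schema_or_throw("aten::sin", "")
//   out = torch._C._dispatch_call_boxed(op, torch.randn(3))
// The GIL is released around callBoxed so that a kernel which needs to
// reacquire it (e.g. PythonKernelHolder above) does not deadlock.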
// A small RAII guard that lets you explicitly *remove* a key from the TLS
// exclude set.
class SetExcludeDispatchKeyGuard {
 public:
  SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded)
      : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) {
    c10::impl::tls_set_dispatch_key_excluded(k, set_excluded);
  }
  ~SetExcludeDispatchKeyGuard() {
    c10::impl::tls_set_dispatch_key_excluded(k, old);
  }
  SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete;
  SetExcludeDispatchKeyGuard& operator=(const SetExcludeDispatchKeyGuard&) =
      delete;
  SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete;
  SetExcludeDispatchKeyGuard& operator=(SetExcludeDispatchKeyGuard&&) = delete;

 private:
  at::DispatchKey k;
  bool old;
};

void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema)
      .def("debug", &c10::OperatorHandle::debug);

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);
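
  // The _DispatchModule bindings below back torch.library.  A rough sketch of
  // how they are typically reached from Python (illustrative only):
  //   lib = torch.library.Library("mylib", "DEF")  # -> _dispatch_library("DEF", "mylib", "")
  //   lib.define("foo(Tensor x) -> Tensor")        # -> _DispatchModule.define
  //   lib.impl("foo", my_python_fn, "CPU")         # -> _DispatchModule.impl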
  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      .def(
          "reset",
          [](const py::object& self) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().reset();
            return;
          },
          "")
      // Some of these APIs are only for testing and do not work in a multipy
      // environment
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from the RegisterOperators() API
      // but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher.  So instead we provide a bunch of precanned
      // functions for testing purposes.  You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something.
      //
      // Mangling scheme: args_rets.  One character per.
      //  t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                name,
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().impl(
                name,
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl",
          [](const py::object& self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.impl(
                  name,
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
                  register_or_verify());
            } else {
              lib.impl(
                  name,
                  torch::dispatch(
                      dispatch,
                      CppFunction::makeFromBoxedFunctor(
                          std::make_unique<PythonKernelHolder>(
                              func, dispatch))),
                  register_or_verify());
              python_registrations_[lib._resolve(name)].insert_or_assign(
                  dispatch,
                  std::make_shared<c10::SafePyObject>(
                      func.release().ptr(), getPyInterpreter()));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"))
      .def(
          "define",
          [](const py::object& self,
             const char* schema,
             const char* alias_analysis,
             const std::vector<at::Tag>& tags) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), tags, register_or_verify());
            // TODO: this is dumb, had to make a second copy
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                .name();
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "",
          py::arg("tags") = std::vector<at::Tag>())
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "");
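
  // _dispatch_library is the entry point used by torch.library.Library's
  // constructor; an empty dispatch string means "no dispatch key", i.e. the
  // library is not restricted to a single key.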
  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* file,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? c10::nullopt
                : c10::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

  m.def("_dispatch_check_invariants", [](const char* name) {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (op) {
      op->checkInvariants();
    }
  });

  m.def("_dispatch_check_all_invariants", []() {
    c10::Dispatcher::singleton().checkInvariants();
  });

  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    return static_cast<bool>(op);
  });

  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <operator name, dispatch key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table for this <operator name, dispatch key> pair.  For
      // example, if "op" has a `CompositeImplicitAutograd` kernel, then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will
      // return true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });
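
  // Illustrative introspection queries from Python (using the bindings above):
  //   torch._C._dispatch_has_kernel("aten::add.Tensor")
  //   torch._C._dispatch_has_kernel_for_dispatch_key(
  //       "aten::add.Tensor", torch._C.DispatchKey.CPU)
  //   print(torch._C._dispatch_dump_table("aten::add.Tensor"))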
  m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
    auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();

    std::vector<std::string> states;
    states.reserve(danglingImpls.size());
    for (auto& danglingImpl : danglingImpls) {
      states.emplace_back(danglingImpl.dumpState());
    }

    return states;
  });

  m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
    auto op_names = c10::Dispatcher::singleton().getAllOpNames();

    std::vector<std::string> names;
    names.reserve(op_names.size());
    for (auto& op : op_names) {
      std::stringstream ss;
      ss << op.name;
      if (!op.overload_name.empty()) {
        ss << "." << op.overload_name;
      }
      names.emplace_back(ss.str());
    }

    return names;
  });

  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

  m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
    return at::isTensorSubclassLike(tensor);
  });

  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
    return c10::toString(k);
  });
  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
  m.def("_to_functionality_key", [](c10::DispatchKey k) {
    return c10::toFunctionalityKey(k);
  });
  // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
  //  AutogradCPU
  //  AutogradCUDA
  //  ...
  //  AutogradPrivateUse3
  m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
    std::vector<c10::DispatchKey> keys;
    if (c10::isPerBackendFunctionalityKey(key)) {
      auto ks = c10::DispatchKeySet(key) |
          c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
      for (auto k : ks) {
        keys.push_back(k);
      }
    } else {
      keys.push_back(key);
    }
    return keys;
  });
  m.def("_dispatch_num_backends", []() { return c10::num_backends; });

#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      // clang-format off
      DEF_ONE(Undefined)
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      // NestedTensor is not a backend key
      DEF_ONE(AutogradNestedTensor)
      DEF_ONE(AutogradOther)
      DEF_ONE(Autograd)
      DEF_ONE(Conjugate)
      DEF_ONE(ZeroTensor)
      DEF_ONE(Negative)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(Python)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(PreDispatch)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCPU)
      DEF_ONE(AutocastXPU)
      DEF_ONE(AutocastHPU)
      DEF_ONE(AutocastIPU)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)
      // clang-format on

#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

      // clang-format off
      C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
      // clang-format on

#undef DEF_MULTIPLE
#undef DEF_SINGLE
      ;
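
  // DispatchKeySet is exposed as a small value type.  From Python, roughly:
  //   ks = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU) \
  //       | torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
  //   ks.has(torch._C.DispatchKey.CPU)   # -> True
  //   ks.highestPriorityTypeId()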
  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });

  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);

  m.attr("_additional_keys_to_prop_for_wrapper_tensors") =
      py::cast(at::functorch::kKeysToPropagateToWrapper);

  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
  });

  m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
  });

  m.def("_dispatch_keyset_full", []() {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
  });

  m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);

  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
    return c10::toString(keyset);
  });

  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
    return c10::getBackendKeySetFromAutograd(k);
  });

  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
    auto* impl = tensor.unsafeGetTensorImpl();
    return impl->key_set();
  });

  m.def("_dispatch_tls_local_include_set", []() {
    return c10::impl::tls_local_dispatch_key_set().included_;
  });

  m.def("_dispatch_tls_local_exclude_set", []() {
    return c10::impl::tls_local_dispatch_key_set().excluded_;
  });

  m.def("_functionalization_reapply_views_tls", []() {
    return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  });

  m.def(
      "_dispatch_is_included_in_alias",
      [](c10::DispatchKey a, c10::DispatchKey b) {
        return c10::isIncludedInAlias(a, b);
      });

  // DEPRECATED, please don't use this.  Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");

  py_context_manager<
      c10::impl::ForceDispatchKeyGuard,
      c10::DispatchKeySet,
      c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
  py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_IncludeDispatchKeyGuard");
  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");
  py_context_manager<SetExcludeDispatchKeyGuard, at::DispatchKey, bool>(
      m, "_SetExcludeDispatchKeyGuard");
  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");
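
  // These guards are bound as Python context managers, e.g. (roughly):
  //   ks = torch._C.DispatchKeySet(torch._C.DispatchKey.Python)
  //   with torch._C._ExcludeDispatchKeyGuard(ks):
  //       ...  # ops dispatched inside this block skip the Python key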
  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key].  If no arguments are specified, it'll
  // print out the name of every operator that the Dispatcher knows of.  This
  // can be useful to answer questions like "list all operators that do not
  // have a CPU kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << std::endl;
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_parse_dispatch_key",
      [](const char* dispatch_key) -> c10::optional<c10::DispatchKey> {
        try {
          return c10::parseDispatchKey(dispatch_key);
        } catch (const c10::Error& err) {
          return c10::nullopt;
        }
      });

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def(
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });

  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
    return c10::Dispatcher::singleton().getAbstractImplPyStub(
        c10::OperatorName(name, overload));
  });

  m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
    return at::functionalization::impl::replace_(a, b);
  });

  m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
    at::functionalization::impl::propagate_xla_data(a, b);
  });

  m.def("_commit_update", [](const at::Tensor& a) {
    return at::functionalization::impl::commit_update(a);
  });

  m.def("_unsafe_reset_storage", [](const at::Tensor& a) {
    return at::functionalization::impl::unsafe_reset_storage(a);
  });

  m.def("_dispatch_key_for_device", [](const std::string& device_type) {
    auto device = c10::Device(device_type);
    TORCH_CHECK(
        !device.has_index(),
        "Expected device_type string to not have a device index; got ",
        device_type);
    return c10::toString(
        c10::computeDispatchKey(c10::nullopt, c10::nullopt, device));
  });

  m.def("_are_functorch_transforms_active", []() {
    auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
    return (
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
  });

  m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
  });

  m.def("_get_constant_bool_symnode", [](int64_t data) {
    return c10::SymNode(
        c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
  });

  m.def("_non_sym_sizes", [](const at::Tensor& a) {
    return a.sizes(); // NB: NOT sym_size
  });

  using c10::impl::TorchDispatchModeKey;
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
      .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
      .value("PROXY", TorchDispatchModeKey::PROXY)
      .value("FAKE", TorchDispatchModeKey::FAKE);
}

// TODO: dedupe with the kernel
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    torch::jit::Stack* stack) {
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  const auto& func = python_registrations_[op.operator_name()][key];
  TORCH_INTERNAL_ASSERT(func != nullptr);
  auto* pyobj = func->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  auto obj = py::reinterpret_steal<py::object>(
      PyObject_Call(pyobj, args_kwargs.first.ptr(), args_kwargs.second.ptr()));
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}

} // namespace dispatch
} // namespace impl
} // namespace torch