This PR abstracts the different kinds of aten operation parameters as `ParameterMetadata`. This structure is intended to represent and store the metadata of each aten operation parameter. Currently, it only supports `Tensor`, `TensorList`, and `Scalar`:

```C++
using ParameterMetadataValue =
    std::variant<TensorMetadata, std::vector<TensorMetadata>, c10::Scalar>;
```

With this PR in place, support for other parameter types, like `string`, `int`, and `double`, can be added in a more modular way. The remaining types are summarized in the following list, collected from all aten operations and ordered by frequency of use.

- `Tensor`
- `bool`
- `int64_t`
- `TensorList`
- `Scalar`
- `c10::SymIntArrayRef`
- `::std::optional<Tensor>`
- `IntArrayRef`
- `double`
- `c10::SymInt`
- `::std::optional<ScalarType>`
- `::std::optional<double>`
- `::std::optional<bool>`
- `::std::optional<Layout>`
- `::std::optional<Device>`
- `::std::optional<int64_t>`
- `Dimname`
- `::std::optional<Generator>`
- `c10::string_view`
- `::std::optional<c10::string_view>`
- `OptionalIntArrayRef`
- `::std::optional<Scalar>`
- `OptionalSymIntArrayRef`
- `::std::optional<MemoryFormat>`
- `::std::optional<c10::SymInt>`
- `ScalarType`
- `ArrayRef<Scalar>`
- `DimnameList`
- `::std::optional<ArrayRef<double>>`
- `::std::array<bool,3>`
- `::std::optional<DimnameList>`
- `c10::List<::std::optional<Tensor>>`
- `::std::array<bool,2>`
- `Storage`
- `::std::array<bool,4>`
- `Device`
- `DeviceIndex`
- `ITensorListRef`
- `Stream`
- `Layout`
- `MemoryFormat`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125308
Approved by: https://github.com/jgong5, https://github.com/jansel
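As a minimal sketch of how extending the variant might look, assuming hypothetical stand-in types (the real `TensorMetadata` and `ParameterMetadata` definitions live in the AOTI eager sources and differ in detail):

```C++
#include <cstdint>
#include <variant>
#include <vector>

// Stand-ins for the real types; illustrative only.
struct TensorMetadata { /* dtype, device, sizes, strides, ... */ };
struct FakeScalar {}; // stand-in for c10::Scalar

// Adding a new supported parameter type (here bool and int64_t,
// hypothetically) is a one-line change to the variant; code that inspects
// parameters keys off the variant alternatives.
using ParameterMetadataValue = std::variant<
    TensorMetadata,
    std::vector<TensorMetadata>,
    FakeScalar,
    bool,
    int64_t>;

struct ParameterMetadata {
  ParameterMetadataValue value_;
  bool is_tensor() const {
    return std::holds_alternative<TensorMetadata>(value_);
  }
};
```

Matching a runtime argument against cached metadata would then presumably be a `std::visit` (or `std::holds_alternative` check) over this variant, one case per supported type.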
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <ATen/functorch/BatchedTensorImpl.h>
#include <torch/library.h>

#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_new.h>

#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/inductor/aoti_eager/kernel_holder.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>

#include <iostream>
#include <utility>

namespace py = pybind11;

namespace torch::impl::dispatch {

// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_;

static torch::Library::Kind parseKind(const std::string& k) {
  static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
      {"DEF", torch::Library::DEF},
      {"IMPL", torch::Library::IMPL},
      {"FRAGMENT", torch::Library::FRAGMENT},
  };
  auto it = kind_map.find(k);
  TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
  return it->second;
}
static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
      {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
      {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
      {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
      {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
  };
  auto it = key_map.find(k);
  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
  return it->second;
}

template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  auto mb_key = std::string(key).empty()
      ? c10::nullopt
      : c10::make_optional(c10::parseDispatchKey(key));
  if (mb_key) {
    return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
  } else {
    torch::CppFunction f(std::forward<Func>(raw_f));
    return f;
  }
}

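// RAII guard that enables "hermetic PyObject" mode for its scope: it flips
// HermeticPyObjectTLS on, excludes the Python dispatch key, and drops
// Python/PythonTLSSnapshot from the TLS include set, restoring all previous
// state on destruction.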
struct EnableHermeticPyObject {
  EnableHermeticPyObject()
      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
        old_excluded_python_(
            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
        old_python_(
            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
            at::DispatchKey::PythonTLSSnapshot)) {
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }
  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }
  bool old_;
  bool old_excluded_python_;
  bool old_python_;
  bool old_python_snapshot_;
};

class PythonKernelHolder : public c10::OperatorKernel {
  c10::SafePyObject func_;
  c10::DispatchKey dispatch_key_;
  // If "with_keyset", then we expect a keyset as the first arg.
  bool with_keyset_;

 public:
  PythonKernelHolder(
      py::object func,
      c10::DispatchKey dispatch_key,
      bool with_keyset = false)
      : func_(func.release().ptr(), getPyInterpreter()),
        dispatch_key_(dispatch_key),
        with_keyset_(with_keyset) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // Figure out if we can handle it hermetically, or if we have
    // to double dispatch

    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    if (mode_stack_len > 0) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      cur_torch_dispatch_mode_state->pyinterpreter()
          ->python_op_registration_trampoline(
              op, dispatch_key_, keyset, stack, with_keyset_);
      return;
    }

    const auto& schema = op.schema();
    const auto num_arguments = schema.arguments().size();

    // Otherwise, find a PyInterpreter on a Tensor if it has the Python key
    // (which means it's a nontrivial tensor subclass)
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
      if (ivalue.isTensor()) {
        auto* interpreter =
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
        if (interpreter &&
            ivalue.unsafeToTensorImpl()->key_set().has(
                at::DispatchKey::Python)) {
          (*interpreter)
              ->python_op_registration_trampoline(
                  op, dispatch_key_, keyset, stack, with_keyset_);
          return;
        }
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          if (nv.isNone()) {
            continue;
          }
          auto* interpreter =
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
          if (interpreter &&
              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
            (*interpreter)
                ->python_op_registration_trampoline(
                    op, dispatch_key_, keyset, stack, with_keyset_);
            return;
          }
        }
      }
    }

    // Nothing requires the operator to be homed to a specific interpreter, so
    // run it on the current interpreter

    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
    py::gil_scoped_acquire g;
    // Jan 2024: We're slated to get rid of multipy, so stop forcing hermetic
    // mode unconditionally in all situations when you're using multipy.
    // Eventually just delete this entirely.  (Note that you may break multipy
    // anyway this way with dispatcher registered functions that require
    // hermetic to be off.)
#if defined(USE_DEPLOY)
    EnableHermeticPyObject g2;
#endif
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
    auto func =
        py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
    auto obj = with_keyset_
        ? func(keyset, *args_kwargs.first, **args_kwargs.second)
        : func(*args_kwargs.first, **args_kwargs.second);
    if (!obj) {
      throw python_error();
    }
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
  }
};

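// On the main Python interpreter we perform the actual registration; any
// other interpreter (e.g. under multipy / torch::deploy) only verifies that
// a matching registration already exists.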
static torch::_RegisterOrVerify register_or_verify() {
  if (isMainPyInterpreter()) {
    return torch::_RegisterOrVerify::REGISTER;
  } else {
    return torch::_RegisterOrVerify::VERIFY;
  }
}

static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    py::args args,
    const py::kwargs& kwargs) {
  auto stack = torch::jit::createStackForSchema(
      handle.schema(),
      std::move(args),
      kwargs,
      /*self=*/c10::nullopt);
  {
    pybind11::gil_scoped_release no_gil_guard;
    handle.callBoxed(stack);
  }
  return torch::jit::createPyObjectForStack(std::move(stack));
}

// A small RAII guard that lets you explicitly *remove* a key from the TLS
// exclude set.
class SetExcludeDispatchKeyGuard {
 public:
  SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded)
      : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) {
    c10::impl::tls_set_dispatch_key_excluded(k, set_excluded);
  }
  ~SetExcludeDispatchKeyGuard() {
    c10::impl::tls_set_dispatch_key_excluded(k, old);
  }
  SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete;
  SetExcludeDispatchKeyGuard operator=(const SetExcludeDispatchKeyGuard&) =
      delete;
  SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete;
  SetExcludeDispatchKeyGuard operator=(SetExcludeDispatchKeyGuard&&) = delete;

 private:
  at::DispatchKey k;
  bool old;
};

void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema)
      .def("debug", &c10::OperatorHandle::debug)
      .def(
          "redispatch_boxed",
          [](const py::object& self,
             c10::DispatchKeySet keyset,
             py::args args,
             const py::kwargs& kwargs) {
            auto& handle = self.cast<c10::OperatorHandle&>();
            auto stack = torch::jit::createStackForSchema(
                handle.schema(),
                std::move(args),
                kwargs,
                /*self=*/c10::nullopt);
            {
              pybind11::gil_scoped_release no_gil_guard;
              handle.redispatchBoxed(keyset, &stack);
            }
            return torch::jit::createPyObjectForStack(std::move(stack));
          });

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);

  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      .def(
          "reset",
          [](const py::object& self) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().reset();
            return;
          },
          "")
      // Some of these APIs are only for testing and do not work in a multipy
      // environment
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from the RegisterOperators()
      // API, but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher.  So instead we provide a bunch of precanned
      // functions for testing purposes.  You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something
      //
      // Mangling scheme: args_rets.  One character per argument/return.
      //  t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().impl(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl_with_aoti_compile",
          [](const py::object& self,
             const char* ns,
             const char* op_name_with_overload,
             c10::DispatchKey dispatch) {
            HANDLE_TH_ERRORS
            std::string reg_op_name =
                std::string(ns).append("::").append(op_name_with_overload);

            auto& lib = self.cast<torch::Library&>();
            lib.impl(
                reg_op_name.c_str(),
                torch::dispatch(
                    dispatch,
                    CppFunction::makeFromBoxedFunctor(
                        std::make_unique<
                            torch::inductor::AOTIPythonKernelHolder>(
                            dispatch, ns, op_name_with_overload))),
                register_or_verify());
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("ns"),
          py::arg("op_name_with_overload"),
          py::arg("dispatch"))
      .def(
          "impl",
          [](const py::object& self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.impl(
                  name,
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
                  register_or_verify());
            } else {
              lib.impl(
                  name,
                  torch::dispatch(
                      dispatch,
                      CppFunction::makeFromBoxedFunctor(
                          std::make_unique<PythonKernelHolder>(
                              func, dispatch, with_keyset))),
                  register_or_verify());
              python_registrations_[lib._resolve(name)].insert_or_assign(
                  dispatch,
                  std::make_shared<c10::SafePyObject>(
                      func.release().ptr(), getPyInterpreter()));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false)
      .def(
          "define",
          [](const py::object& self,
             const char* schema,
             const char* alias_analysis,
             const std::vector<at::Tag>& tags) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), tags, register_or_verify());
            // TODO: this is dumb, had to make a second copy
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                .name();
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "",
          py::arg("tags") = std::vector<at::Tag>())
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "");

  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* file,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? c10::nullopt
                : c10::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

  m.def("_dispatch_check_invariants", [](const char* name) {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
    } else {
      return op->checkInvariants();
    }
  });

  m.def("_dispatch_check_all_invariants", []() {
    c10::Dispatcher::singleton().checkInvariants();
  });

  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    return static_cast<bool>(op);
  });

  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <op_name, dispatch_key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      // Returns whether or not the kernel for this dispatch key is a
      // fallthrough kernel
      "_dispatch_kernel_for_dispatch_key_is_fallthrough",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        return op->isKernelFallthroughKernel(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table for this <op_name, dispatch_key> pair. For example, if
      // "op" has a `CompositeImplicitAutograd` kernel, then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will
      // return true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });

  m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
    auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();

    std::vector<std::string> states;
    states.reserve(danglingImpls.size());
    for (auto& danglingImpl : danglingImpls) {
      states.emplace_back(danglingImpl.dumpState());
    }

    return states;
  });

  m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
    auto op_names = c10::Dispatcher::singleton().getAllOpNames();

    std::vector<std::string> names;
    names.reserve(op_names.size());
    for (auto& op : op_names) {
      std::stringstream ss;
      ss << op.name;
      if (!op.overload_name.empty()) {
        ss << "." << op.overload_name;
      }
      names.emplace_back(ss.str());
    }

    return names;
  });

  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

  m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
    return at::isTensorSubclassLike(tensor);
  });

  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
    return c10::toString(k);
  });
  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
  m.def("_to_functionality_key", [](c10::DispatchKey k) {
    return c10::toFunctionalityKey(k);
  });
  // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
  //  AutogradCPU
  //  AutogradCUDA
  //  ...
  //  AutogradPrivateUse3
  m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
    std::vector<c10::DispatchKey> keys;
    if (c10::isPerBackendFunctionalityKey(key)) {
      auto ks = c10::DispatchKeySet(key) |
          c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
      for (auto k : ks) {
        keys.push_back(k);
      }
    } else {
      keys.push_back(key);
    }
    return keys;
  });
  m.def("_dispatch_num_backends", []() { return c10::num_backends; });

#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      // clang-format off
      DEF_ONE(Undefined)
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      // NestedTensor is not a backend key
      DEF_ONE(AutogradNestedTensor)
      DEF_ONE(AutogradOther)
      DEF_ONE(Autograd)
      DEF_ONE(Conjugate)
      DEF_ONE(ZeroTensor)
      DEF_ONE(Negative)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(Python)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(PreDispatch)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCPU)
      DEF_ONE(AutocastXPU)
      DEF_ONE(AutocastHPU)
      DEF_ONE(AutocastIPU)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)
  // clang-format on

#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

      // clang-format off
  C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
  // clang-format on

#undef DEF_MULTIPLE
#undef DEF_SINGLE
          ;

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("raw_repr", &c10::DispatchKeySet::raw_repr)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });

  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);

  m.attr("_additional_keys_to_prop_for_wrapper_tensors") =
      py::cast(at::functorch::kKeysToPropagateToWrapper);

  m.attr("_after_autograd_keyset") = py::cast(c10::after_autograd_keyset);
  m.attr("_after_ADInplaceOrView_keyset") =
      py::cast(c10::after_ADInplaceOrView_keyset);

  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
  });

  m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
  });

  m.def("_dispatch_keyset_full", []() {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
  });

  m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);

  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
    return c10::toString(keyset);
  });

  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
    return c10::getBackendKeySetFromAutograd(k);
  });

  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
    auto* impl = tensor.unsafeGetTensorImpl();
    return impl->key_set();
  });
  m.def("_dispatch_tls_local_include_set", []() {
    return c10::impl::tls_local_dispatch_key_set().included_;
  });
  m.def("_dispatch_tls_local_exclude_set", []() {
    return c10::impl::tls_local_dispatch_key_set().excluded_;
  });
  m.def("_functionalization_reapply_views_tls", []() {
    return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  });
  m.def(
      "_dispatch_is_included_in_alias",
      [](c10::DispatchKey a, c10::DispatchKey b) {
        return c10::isIncludedInAlias(a, b);
      });

  // DEPRECATED, please don't use this. Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");

  py_context_manager<
      c10::impl::ForceDispatchKeyGuard,
      c10::DispatchKeySet,
      c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
  py_context_manager<c10::impl::ForceDispatchKeyGuard>(
      m, "_PreserveDispatchKeyGuard");
  py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
      m, "_IncludeDispatchKeyGuard");
  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");
  py_context_manager<SetExcludeDispatchKeyGuard, c10::DispatchKey, bool>(
      m, "_SetExcludeDispatchKeyGuard");

  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");
  py_context_manager<at::AutoDispatchBelowADInplaceOrView>(
      m, "_AutoDispatchBelowADInplaceOrView");

  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll
  // print out the name of every operator that the Dispatcher knows of. This
  // can be useful to answer questions like "list all operators that do not
  // have a CPU kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << '\n';
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_parse_dispatch_key",
      [](const char* dispatch_key) -> std::optional<c10::DispatchKey> {
        try {
          return c10::parseDispatchKey(dispatch_key);
        } catch (const c10::Error& err) {
          return c10::nullopt;
        }
      });

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));
  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def(
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
    return c10::Dispatcher::singleton().getPyStub(
        c10::OperatorName(name, overload));
  });

  m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
    return at::functionalization::impl::replace_(a, b);
  });
  m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
    at::functionalization::impl::propagate_xla_data(a, b);
  });
  m.def("_commit_update", [](const at::Tensor& a) {
    return at::functionalization::impl::commit_update(a);
  });
  m.def("_unsafe_reset_storage", [](const at::Tensor& a) {
    return at::functionalization::impl::unsafe_reset_storage(a);
  });

  m.def("_dispatch_key_for_device", [](const std::string& device_type) {
    auto device = c10::Device(device_type);
    TORCH_CHECK(
        !device.has_index(),
        "Expected device_type string to not have a device index; got ",
        device_type);
    return c10::toString(
        c10::computeDispatchKey(c10::nullopt, c10::nullopt, device));
  });

  m.def("_are_functorch_transforms_active", []() {
    auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
    return (
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
  });

  m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
  });

  m.def("_get_constant_bool_symnode", [](int64_t data) {
    return c10::SymNode(
        c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
  });

  m.def("_non_sym_sizes", [](const at::Tensor& a) {
    return a.sizes(); // NB: NOT sym_size
  });

  m.def("_set_throw_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    // Otherwise, set (on the StorageImpl) that accessing (mutable) data_ptr
    // will throw.
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_throw_on_mutable_data_ptr();
  });

  // Invariant: you must ONLY call this with FakeTensors.
  m.def("_set_warn_deprecated_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_warn_deprecated_on_mutable_data_ptr();
  });

  m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
  m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);

  using c10::impl::TorchDispatchModeKey;
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
      .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
      .value("PROXY", TorchDispatchModeKey::PROXY)
      .value("FAKE", TorchDispatchModeKey::FAKE);
}

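// This is the target that PythonKernelHolder reaches through
// PyInterpreter::python_op_registration_trampoline: it looks up the Python
// callable registered for this (operator, dispatch key) pair and invokes it.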
// TODO: dedupe with the kernel
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack,
    bool with_keyset) {
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  const auto& func = python_registrations_[op.operator_name()][key];
  TORCH_INTERNAL_ASSERT(func != nullptr);
  auto* pyobj = func->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  auto callable = py::reinterpret_borrow<py::object>(pyobj);
  auto obj = with_keyset
      ? callable(keyset, *args_kwargs.first, **args_kwargs.second)
      : callable(*args_kwargs.first, **args_kwargs.second);
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}

} // namespace torch::impl::dispatch