Files
pytorch/torch/csrc/utils/python_dispatch.cpp
PaliC 1b99c1859c [BE] Make PyObjectSlot use a global PyInterpreter and remove (#158427)
This PR is a bit more involved but effectively works to drastically simplify PyObjectSlot and PyInterpreter.
1) For PyObjectSlot we now use a global PyInterpreter, since there is only one. We then change all of the call sites to rely on this assumption.
2) We also remove the "tags" of the PyInterpreter by deprecating `PyInterpreterStatus`.

Note for the reviewer: unfortunately `functorch/csrc/dim/dim.cpp` needed to be linted, so the diff for that file is very large and hard to read. Fortunately, the only actual change in that file is the following, which just removes `getPyInterpreter()` from the `check_pyobj` call.

```
 mpy::handle handle_from_tensor(Arena& A, TensorRef t) {
-    // fast case: tensor is live in python
-    std::optional<PyObject*> mb_obj =
-        t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter(), /*ignore_hermetic_tls=*/false);
-    if (mb_obj.has_value() && !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
-        return *mb_obj;
-    }
-    return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
-}
-}
+  // fast case: tensor is live in python
+  std::optional<PyObject*> mb_obj =
+      t->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
+          /*ignore_hermetic_tls=*/false);
+  if (mb_obj.has_value() &&
+      !t->unsafeGetTensorImpl()->pyobj_slot()->owns_pyobj()) {
+    return *mb_obj;
+  }
+  return A.autorelease(mpy::object::checked_steal(THPVariable_Wrap(*t)));
+}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158427
Approved by: https://github.com/albanD
2025-07-30 17:29:43 +00:00

1035 lines
37 KiB
C++

#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/autocast_mode.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/functorch/BatchedTensorImpl.h>
#include <torch/library.h>
#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_new.h>
#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/inductor/aoti_eager/kernel_holder.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>
#include <iostream>
#include <utility>
namespace py = pybind11;
namespace torch::impl::dispatch {
// Registry of Python kernels installed through `_DispatchModule.impl`, keyed
// by operator name and then by the dispatch key they were registered for.
// python_op_registration_trampoline_impl reads from it.
//
// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it.
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_{};
// Converts the library-kind spelling used on the Python side ("DEF", "IMPL" or
// "FRAGMENT") into torch::Library::Kind. Any other spelling raises a
// c10::Error via TORCH_CHECK.
static torch::Library::Kind parseKind(const std::string& k) {
  if (k == "DEF") {
    return torch::Library::DEF;
  }
  if (k == "IMPL") {
    return torch::Library::IMPL;
  }
  // Last candidate: anything that is not FRAGMENT is unparseable.
  TORCH_CHECK(k == "FRAGMENT", "could not parse ", k);
  return torch::Library::FRAGMENT;
}
// Converts an alias-analysis spelling from Python into c10::AliasAnalysisKind.
// An empty string means "unspecified" and maps to FROM_SCHEMA. Unknown
// spellings raise a c10::Error via TORCH_CHECK.
static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  if (k.empty() || k == "FROM_SCHEMA") {
    return c10::AliasAnalysisKind::FROM_SCHEMA;
  }
  if (k == "CONSERVATIVE") {
    return c10::AliasAnalysisKind::CONSERVATIVE;
  }
  TORCH_CHECK(k == "PURE_FUNCTION", "could not parse ", k);
  return c10::AliasAnalysisKind::PURE_FUNCTION;
}
// Wraps `raw_f` in a torch::CppFunction.
//
// `key` is a dispatch-key name. An empty string produces a CppFunction with no
// dispatch key attached. Otherwise the name is parsed with
// c10::parseDispatchKey and the function is bound to that key via
// torch::dispatch.
template <typename Func>
inline static torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  const bool has_key = key[0] != '\0';
  if (!has_key) {
    return torch::CppFunction(std::forward<Func>(raw_f));
  }
  return torch::dispatch(
      c10::parseDispatchKey(key), std::forward<Func>(raw_f));
}
// Scope guard that puts the current thread into "hermetic PyObject" mode.
//
// On construction it first snapshots four pieces of thread-local state:
//   * the HermeticPyObjectTLS flag,
//   * whether DispatchKey::Python is excluded,
//   * whether DispatchKey::Python is included,
//   * whether DispatchKey::PythonTLSSnapshot is included.
// It then turns hermetic mode on, excludes Python, and drops both Python and
// PythonTLSSnapshot from the include set. The destructor writes every
// snapshotted value back. The guard can be neither copied nor moved.
struct EnableHermeticPyObject {
  EnableHermeticPyObject() {
    // The saved-state members below have already been initialized (in
    // declaration order) by the time this body runs.
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }

  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }

  EnableHermeticPyObject(const EnableHermeticPyObject&) = delete;
  EnableHermeticPyObject(EnableHermeticPyObject&&) = delete;
  EnableHermeticPyObject& operator=(const EnableHermeticPyObject&) = delete;
  EnableHermeticPyObject& operator=(EnableHermeticPyObject&&) = delete;

  // Previous HermeticPyObjectTLS state.
  bool old_ = c10::impl::HermeticPyObjectTLS::get_state();
  // Previous "Python is excluded" TLS bit.
  bool old_excluded_python_ =
      c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python);
  // Previous "Python is included" TLS bit.
  bool old_python_ =
      c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python);
  // Previous "PythonTLSSnapshot is included" TLS bit.
  bool old_python_snapshot_ = c10::impl::tls_is_dispatch_key_included(
      at::DispatchKey::PythonTLSSnapshot);
};
class PythonKernelHolder : public c10::OperatorKernel {
c10::SafePyObject func_;
c10::DispatchKey dispatch_key_;
// If "with_keyset", then we expect a keyset as the first arg.
bool with_keyset_;
// If "with_op", then we expect the op as first arg (or second if keyset)
bool with_op_;
public:
PythonKernelHolder(
py::object func,
c10::DispatchKey dispatch_key,
bool with_keyset = false,
bool with_op = false)
: func_(func.release().ptr(), getPyInterpreter()),
dispatch_key_(dispatch_key),
with_keyset_(with_keyset),
with_op_(with_op) {}
void operator()(
const c10::OperatorHandle& op,
c10::DispatchKeySet keyset,
torch::jit::Stack* stack) {
// Figure out if we can handle it hermetically, or if we have
// to double dispatch
// If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
if (mode_stack_len > 0) {
const auto& cur_torch_dispatch_mode_state =
c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
cur_torch_dispatch_mode_state->pyinterpreter()
->python_op_registration_trampoline(
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
return;
}
const auto& schema = op.schema();
const auto num_arguments = schema.arguments().size();
// Otherwise, find a PyInterpreter on a Tensor IF if has Python key (which
// means it's a nontrivial tensor subclass)
for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
if (ivalue.isTensor()) {
auto* interpreter =
ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
if (interpreter &&
ivalue.unsafeToTensorImpl()->key_set().has(
at::DispatchKey::Python)) {
(*interpreter)
->python_op_registration_trampoline(
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
return;
}
} else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
// NB: use toListRef as it doesn't induce refcount bumps
// (toTensorListRef is not a thing)
for (const auto& nv : ivalue.toListRef()) {
if (nv.isNone()) {
continue;
}
auto* interpreter =
nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
if (interpreter &&
nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
(*interpreter)
->python_op_registration_trampoline(
op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
return;
}
}
}
}
// Nothing requires the operator to be homed to a specific interpreter, so
// run it on the current interpreter
auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
py::gil_scoped_acquire g;
auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
auto func =
py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
auto obj = with_op_ ? with_keyset_
? func(
keyset,
torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: func(
torch::detail::getTorchApiFunction(op),
*args_kwargs.first,
**args_kwargs.second)
: with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second)
: func(*args_kwargs.first, **args_kwargs.second);
if (!obj) {
throw python_error();
}
pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}
};
// Mode passed to every torch::Library def/impl call made by these bindings.
// It is always REGISTER.
// @todo sahanp: AFAICT only REGISTER is used in the codebase, so this helper
// can be removed / simplified.
static torch::_RegisterOrVerify register_or_verify() {
  const auto mode = torch::_RegisterOrVerify::REGISTER;
  return mode;
}
// Calls `handle` from Python using the boxed calling convention.
//
// The Python args/kwargs are converted into an IValue stack according to the
// operator's schema. The GIL is then released for the duration of the
// dispatcher call, and the resulting stack is converted back into a Python
// object.
static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    const py::args& args,
    const py::kwargs& kwargs) {
  const auto& op_schema = handle.schema();
  auto stack = torch::jit::createStackForSchema(
      op_schema, args, kwargs, /*self=*/std::nullopt);
  {
    // Kernels that need Python (e.g. PythonKernelHolder) re-acquire the GIL
    // themselves, so it is not held across the call.
    py::gil_scoped_release release_gil;
    handle.callBoxed(stack);
  }
  return torch::jit::createPyObjectForStack(std::move(stack));
}
// RAII guard that sets whether dispatch key `k` is in the thread-local
// exclude set for the lifetime of the guard.
//
// Passing set_excluded=false explicitly *removes* the key from the TLS exclude
// set; passing true adds it. The destructor restores whatever exclusion state
// the key had when the guard was constructed. The guard is neither copyable
// nor movable.
class SetExcludeDispatchKeyGuard {
 public:
  SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded)
      : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) {
    c10::impl::tls_set_dispatch_key_excluded(k, set_excluded);
  }
  ~SetExcludeDispatchKeyGuard() {
    c10::impl::tls_set_dispatch_key_excluded(k, old);
  }
  SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete;
  // Assignment operators conventionally return a reference (even when
  // deleted); the original by-value return type was unconventional.
  SetExcludeDispatchKeyGuard& operator=(const SetExcludeDispatchKeyGuard&) =
      delete;
  SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete;
  SetExcludeDispatchKeyGuard& operator=(SetExcludeDispatchKeyGuard&&) = delete;

 private:
  // Key whose exclusion state is managed.
  const at::DispatchKey k;
  // Exclusion state of `k` before the guard was constructed.
  const bool old;
};
// Installs the dispatcher-related Python bindings on `module`:
//   * the _DispatchOperatorHandle, _DispatchModule, DispatchKey,
//     DispatchKeySet and _TorchDispatchModeKey types;
//   * dispatch-key TLS query/update helpers and RAII context managers;
//   * assorted dispatcher introspection, functionalization and
//     storage-flag utilities, exposed as `_dispatch_*` / `_*` functions.
void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema)
      .def("debug", &c10::OperatorHandle::debug)
      .def(
          "redispatch_boxed",
          [](const py::object& self,
             c10::DispatchKeySet keyset,
             py::args args,
             const py::kwargs& kwargs) {
            auto& handle = self.cast<c10::OperatorHandle&>();
            auto stack = torch::jit::createStackForSchema(
                handle.schema(),
                std::move(args),
                kwargs,
                /*self=*/std::nullopt);
            {
              // Don't hold the GIL across the dispatcher call.
              pybind11::gil_scoped_release no_gil_guard;
              handle.redispatchBoxed(keyset, &stack);
            }
            return torch::jit::createPyObjectForStack(std::move(stack));
          });

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);

  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      .def(
          "reset",
          [](const py::object& self) {
            self.cast<torch::Library&>().reset();
            return;
          },
          "")
      // Some of these APIs are only for testing and do not work in
      // multipy environment // codespell:ignore multipy
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from RegisterOperators() API
      // but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher. So instead we provide a bunch of precanned
      // functions for testing purposes. You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something
      //
      // Mangling scheme: args_rets. One character per.
      // t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            self.cast<torch::Library&>().def(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            self.cast<torch::Library&>().impl(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl_with_aoti_compile",
          [](const py::object& self,
             const char* ns,
             const char* op_name_with_overload,
             c10::DispatchKey dispatch) {
            HANDLE_TH_ERRORS
            std::string reg_op_name =
                std::string(ns).append("::").append(op_name_with_overload);
            auto& lib = self.cast<torch::Library&>();
            lib.impl(
                reg_op_name.c_str(),
                torch::dispatch(
                    dispatch,
                    CppFunction::makeFromBoxedFunctor(
                        std::make_unique<
                            torch::inductor::AOTIPythonKernelHolder>(
                            dispatch, ns, op_name_with_overload))),
                register_or_verify());
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("ns"),
          py::arg("op_name_with_overload"),
          py::arg("dispatch"))
      .def(
          "impl",
          [](const py::object& self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.impl(
                  name,
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
                  register_or_verify());
            } else {
              lib.impl(
                  name,
                  torch::dispatch(
                      dispatch,
                      CppFunction::makeFromBoxedFunctor(
                          std::make_unique<PythonKernelHolder>(
                              func, dispatch, with_keyset))),
                  register_or_verify());
              // Also record the callable so that
              // python_op_registration_trampoline_impl can find it.
              python_registrations_[lib._resolve(name)].insert_or_assign(
                  dispatch,
                  std::make_shared<c10::SafePyObject>(
                      func.release().ptr(), getPyInterpreter()));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false)
      .def(
          "define",
          [](const py::object& self,
             const char* schema,
             const char* alias_analysis,
             const std::vector<at::Tag>& tags) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            // Capture the name before the schema is moved into the library,
            // instead of parsing the schema string a second time.
            std::string op_name = parsed_schema.name();
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), tags, register_or_verify());
            return op_name;
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "",
          py::arg("tags") = std::vector<at::Tag>())
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "")
      .def(
          "fallback",
          [](const py::object& self,
             c10::DispatchKey dispatch,
             const py::object& func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.fallback(
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()));
            } else {
              lib.fallback(torch::dispatch(
                  dispatch,
                  CppFunction::makeFromBoxedFunctor(
                      std::make_unique<PythonKernelHolder>(
                          func, dispatch, with_keyset, /*with_op*/ true))));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false);

  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* /*file*/,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? std::nullopt
                : std::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround; `file` is ignored
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

  m.def("_dispatch_check_invariants", [](const char* name) {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    // Unknown operators are silently ignored.
    if (op) {
      op->checkInvariants();
    }
  });

  m.def("_dispatch_check_all_invariants", []() {
    c10::Dispatcher::singleton().checkInvariants();
  });

  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    return static_cast<bool>(op);
  });

  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <op_name, dispatch_key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      // Returns whether or not the kernel for this dispatch key is a
      // fallthrough kernel
      "_dispatch_kernel_for_dispatch_key_is_fallthrough",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        // findOp returns an empty optional for unknown operators;
        // dereferencing it without this check would be undefined behavior.
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->isKernelFallthroughKernel(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table, for this <op_name, dispatch_key> pair. For example, if
      // "op" has a `CompositeImplicitAutograd` kernel, Then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
      // true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });

  m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
    auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
    std::vector<std::string> states;
    states.reserve(danglingImpls.size());
    for (auto& danglingImpl : danglingImpls) {
      states.emplace_back(danglingImpl.dumpState());
    }
    return states;
  });

  m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
    auto op_names = c10::Dispatcher::singleton().getAllOpNames();
    std::vector<std::string> names;
    names.reserve(op_names.size());
    // Same "name[.overload]" formatting as
    // _dispatch_get_registrations_for_dispatch_key.
    for (auto& op : op_names) {
      names.emplace_back(
          op.name + (op.overload_name.empty() ? "" : "." + op.overload_name));
    }
    return names;
  });

  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

  m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
    return at::isTensorSubclassLike(tensor);
  });

  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
    return c10::toString(k);
  });
  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
  m.def("_to_functionality_key", [](c10::DispatchKey k) {
    return c10::toFunctionalityKey(k);
  });
  // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
  // AutogradCPU
  // AutogradCUDA
  // ...
  // AutogradPrivateUse3
  m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
    std::vector<c10::DispatchKey> keys;
    if (c10::isPerBackendFunctionalityKey(key)) {
      auto ks = c10::DispatchKeySet(key) |
          c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
      for (auto k : ks) {
        keys.push_back(k);
      }
    } else {
      keys.push_back(key);
    }
    return keys;
  });
  m.def("_dispatch_num_backends", []() { return c10::num_backends; });

#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      // clang-format off
      DEF_ONE(Undefined)
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      // NestedTensor is not a backend key
      DEF_ONE(AutogradNestedTensor)
      DEF_ONE(AutogradOther)
      DEF_ONE(Autograd)
      DEF_ONE(Conjugate)
      DEF_ONE(ZeroTensor)
      DEF_ONE(Negative)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(Python)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(PreDispatch)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCPU)
      DEF_ONE(AutocastMPS)
      DEF_ONE(AutocastXPU)
      DEF_ONE(AutocastHPU)
      DEF_ONE(AutocastIPU)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)
  // clang-format on

#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

      // clang-format off
      C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
  // clang-format on

#undef DEF_MULTIPLE
#undef DEF_SINGLE
      ;

#undef DEF_ONE

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("raw_repr", &c10::DispatchKeySet::raw_repr)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); })
      .def(
          "__eq__",
          [](c10::DispatchKeySet self, c10::DispatchKeySet other) {
            return self.raw_repr() == other.raw_repr();
          })
      .def(py::pickle(
          [](const c10::DispatchKeySet&
                 obj) { // __getstate__ : creates tuple of state
            return py::make_tuple(obj.raw_repr());
          },
          [](const py::tuple& t) { // __setstate__ : restores state from tuple
            TORCH_CHECK(
                t.size() == 1, "__setstate__ expected tuple with one element");
            return c10::DispatchKeySet::from_raw_repr(t[0].cast<uint64_t>());
          }))
      .def_static("from_raw_repr", &c10::DispatchKeySet::from_raw_repr);

  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);

  m.attr("_additional_keys_to_prop_for_wrapper_tensors") =
      py::cast(at::functorch::kKeysToPropagateToWrapper);

  m.attr("_after_autograd_keyset") = py::cast(c10::after_autograd_keyset);
  m.attr("_after_ADInplaceOrView_keyset") =
      py::cast(c10::after_ADInplaceOrView_keyset);

  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
  });

  m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
  });

  m.def("_dispatch_keyset_full", []() {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
  });

  m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);

  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
    return c10::toString(keyset);
  });

  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
    return c10::getBackendKeySetFromAutograd(k);
  });

  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
    auto* impl = tensor.unsafeGetTensorImpl();
    return impl->key_set();
  });
  m.def("_dispatch_tls_local_include_set", []() {
    return c10::impl::tls_local_dispatch_key_set().included_;
  });
  m.def("_dispatch_tls_local_exclude_set", []() {
    return c10::impl::tls_local_dispatch_key_set().excluded_;
  });
  m.def("_functionalization_reapply_views_tls", []() {
    return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  });
  m.def(
      "_dispatch_is_included_in_alias",
      [](c10::DispatchKey a, c10::DispatchKey b) {
        return c10::isIncludedInAlias(a, b);
      });

  // DEPRECATED, please don't use this. Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");
  py_context_manager<
      c10::impl::ForceDispatchKeyGuard,
      c10::DispatchKeySet,
      c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
  py_context_manager<c10::impl::ForceDispatchKeyGuard>(
      m, "_PreserveDispatchKeyGuard");
  py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
      m, "_IncludeDispatchKeyGuard");
  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");
  py_context_manager<SetExcludeDispatchKeyGuard, c10::DispatchKey, bool>(
      m, "_SetExcludeDispatchKeyGuard");

  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");
  py_context_manager<at::AutoDispatchBelowADInplaceOrView>(
      m, "_AutoDispatchBelowADInplaceOrView");

  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll print
  // out the name of every operator that the Dispatcher knows of. This can be
  // useful to answer questions like "list all operators that do not have a CPU
  // kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << '\n';
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  // Returns the parsed key, or None if `dispatch_key` is not a valid name.
  m.def(
      "_parse_dispatch_key",
      [](const char* dispatch_key) -> std::optional<c10::DispatchKey> {
        try {
          return c10::parseDispatchKey(dispatch_key);
        } catch (const c10::Error&) {
          return std::nullopt;
        }
      });

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
    return c10::Dispatcher::singleton().getPyStub(
        c10::OperatorName(name, overload));
  });

  m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
    return at::functionalization::impl::replace_(a, b);
  });
  m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
    at::functionalization::impl::propagate_xla_data(a, b);
  });
  m.def("_commit_update", [](const at::Tensor& a) {
    return at::functionalization::impl::commit_update(a);
  });
  m.def("_unsafe_reset_storage", [](const at::Tensor& a) {
    return at::functionalization::impl::unsafe_reset_storage(a);
  });

  m.def("_dispatch_key_for_device", [](const std::string& device_type) {
    auto device = c10::Device(device_type);
    TORCH_CHECK(
        !device.has_index(),
        "Expected device_type string to not have a device index; got ",
        device_type);
    return c10::toString(
        c10::computeDispatchKey(std::nullopt, std::nullopt, device));
  });

  m.def("_are_functorch_transforms_active", []() {
    auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
    return (
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
  });

  m.def("_autocast_supported_devices", []() {
    std::vector<std::string> result;
    for (const auto device_type : at::autocast::_AUTOCAST_SUPPORTED_DEVICES) {
      result.emplace_back(
          c10::DeviceTypeName(device_type, /*lower_case*/ true));
    }
    return result;
  });

  m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
  });

  m.def("_get_constant_bool_symnode", [](int64_t data) {
    return c10::SymNode(
        c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
  });

  m.def("_non_sym_sizes", [](const at::Tensor& a) {
    return a.sizes(); // NB: NOT sym_size
  });

  m.def("_set_throw_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    // Otherwise, set (on the StorageImpl) that accessing (mutable) data_ptr
    // will throw.
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_throw_on_mutable_data_ptr();
  });

  // Invariant: you must ONLY call this with FakeTensors.
  m.def("_set_warn_deprecated_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_warn_deprecated_on_mutable_data_ptr();
  });

  m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
  m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);

  using c10::impl::TorchDispatchModeKey;
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
      .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
      .value("PROXY", TorchDispatchModeKey::PROXY)
      .value("FAKE", TorchDispatchModeKey::FAKE);
}
// Runs the Python kernel recorded in python_registrations_ for (`op`, `key`).
//
// The kernel is looked up in python_registrations_, which is populated by
// `_DispatchModule.impl`. It is presumably the implementation behind
// PyInterpreter's python_op_registration_trampoline, which PythonKernelHolder
// calls; confirm in PyInterpreter.cpp.
//
// The function pops the operator's arguments off `stack` and calls the kernel
// as `func([keyset,] [op,] *args, **kwargs)` (leading arguments controlled by
// `with_keyset` / `with_op`) with the GIL held. It then pushes the result back
// onto `stack`. A missing registration is an internal error.
//
// TODO: dedupe with the kernel (PythonKernelHolder::operator()).
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack,
    bool with_keyset,
    bool with_op) {
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);

  // Look the kernel up with find() rather than operator[], which would
  // silently insert empty entries into the global registry on a miss.
  const auto op_it = python_registrations_.find(op.operator_name());
  TORCH_INTERNAL_ASSERT(
      op_it != python_registrations_.end(),
      "no Python kernel registered for operator ",
      op.operator_name());
  const auto key_it = op_it->second.find(key);
  TORCH_INTERNAL_ASSERT(
      key_it != op_it->second.end() && key_it->second != nullptr,
      "no Python kernel registered for operator ",
      op.operator_name(),
      " and dispatch key ",
      key);
  // Hold our own reference rather than a reference into the map: the Python
  // kernel may register new kernels while it runs, which can rehash the map.
  const std::shared_ptr<c10::SafePyObject> func = key_it->second;

  auto* pyobj = func->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  auto callable = py::reinterpret_borrow<py::object>(pyobj);
  auto obj = with_op ? with_keyset ? callable(
                                         keyset,
                                         torch::detail::getTorchApiFunction(op),
                                         *args_kwargs.first,
                                         **args_kwargs.second)
                                   : callable(
                                         torch::detail::getTorchApiFunction(op),
                                         *args_kwargs.first,
                                         **args_kwargs.second)
      : with_keyset ? callable(keyset, *args_kwargs.first, **args_kwargs.second)
                    : callable(*args_kwargs.first, **args_kwargs.second);
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}
} // namespace torch::impl::dispatch