#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <ATen/functorch/BatchedTensorImpl.h>
#include <torch/library.h>

#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_new.h>

#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/inductor/aoti_eager/kernel_holder.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>

#include <iostream>
#include <utility>

namespace py = pybind11;

namespace torch::impl::dispatch {

// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_;

static torch::Library::Kind parseKind(const std::string& k) {
  static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
      {"DEF", torch::Library::DEF},
      {"IMPL", torch::Library::IMPL},
      {"FRAGMENT", torch::Library::FRAGMENT},
  };
  auto it = kind_map.find(k);
  TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
  return it->second;
}
static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
      {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
      {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
      {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
      {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
  };
  auto it = key_map.find(k);
  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
  return it->second;
}

template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  if (key[0] != '\0') {
    return torch::dispatch(
        c10::parseDispatchKey(key), std::forward<Func>(raw_f));
  } else {
    torch::CppFunction f(std::forward<Func>(raw_f));
    return f;
  }
}

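// Illustrative only: dispatch_str("CPU", fn) produces a CppFunction bound to
// DispatchKey::CPU, while dispatch_str("", fn) leaves the dispatch key unset,
// i.e. a catch-all registration.
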
struct EnableHermeticPyObject {
  EnableHermeticPyObject()
      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
        old_excluded_python_(
            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
        old_python_(
            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
            at::DispatchKey::PythonTLSSnapshot)) {
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }
  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }
  bool old_;
  bool old_excluded_python_;
  bool old_python_;
  bool old_python_snapshot_;
};

class PythonKernelHolder : public c10::OperatorKernel {
  c10::SafePyObject func_;
  c10::DispatchKey dispatch_key_;
  // If "with_keyset", then we expect a keyset as the first arg.
  bool with_keyset_;
  // If "with_op", then we expect the op as first arg (or second if keyset)
  bool with_op_;

 public:
  PythonKernelHolder(
      py::object func,
      c10::DispatchKey dispatch_key,
      bool with_keyset = false,
      bool with_op = false)
      : func_(func.release().ptr(), getPyInterpreter()),
        dispatch_key_(dispatch_key),
        with_keyset_(with_keyset),
        with_op_(with_op) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // Figure out if we can handle it hermetically, or if we have
    // to double dispatch

    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    if (mode_stack_len > 0) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      cur_torch_dispatch_mode_state->pyinterpreter()
          ->python_op_registration_trampoline(
              op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
      return;
    }

    const auto& schema = op.schema();
    const auto num_arguments = schema.arguments().size();

    // Otherwise, find a PyInterpreter on a Tensor if it has the Python key
    // (which means it's a nontrivial tensor subclass)
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
      if (ivalue.isTensor()) {
        auto* interpreter =
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
        if (interpreter &&
            ivalue.unsafeToTensorImpl()->key_set().has(
                at::DispatchKey::Python)) {
          (*interpreter)
              ->python_op_registration_trampoline(
                  op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
          return;
        }
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          if (nv.isNone()) {
            continue;
          }
          auto* interpreter =
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
          if (interpreter &&
              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
            (*interpreter)
                ->python_op_registration_trampoline(
                    op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
            return;
          }
        }
      }
    }

    // Nothing requires the operator to be homed to a specific interpreter, so
    // run it on the current interpreter

    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
    py::gil_scoped_acquire g;
    // Jan 2024: We're slated to get rid of multipy, so stop forcing hermetic
    // mode unconditionally in all situations when you're using multipy.
    // Eventually just delete this entirely. (Note that you may break multipy
    // anyway this way with dispatcher registered functions that require
    // hermetic to be off.)
#if defined(USE_DEPLOY)
    EnableHermeticPyObject g2;
#endif
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
    auto func =
        py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
    auto obj = with_op_ ? with_keyset_
            ? func(
                  keyset,
                  torch::detail::getTorchApiFunction(op),
                  *args_kwargs.first,
                  **args_kwargs.second)
            : func(
                  torch::detail::getTorchApiFunction(op),
                  *args_kwargs.first,
                  **args_kwargs.second)
        : with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second)
                       : func(*args_kwargs.first, **args_kwargs.second);
    if (!obj) {
      throw python_error();
    }
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
  }
};

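// Sketch of the Python-side counterpart that ends up wrapped in a
// PythonKernelHolder (via the public torch.library API; the names "mylib"
// and "foo" are made up for illustration):
//
//   lib = torch.library.Library("mylib", "FRAGMENT")
//   lib.define("foo(Tensor x) -> Tensor")
//   lib.impl("foo", lambda x: x.clone(), "CPU")
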
static torch::_RegisterOrVerify register_or_verify() {
  if (isMainPyInterpreter()) {
    return torch::_RegisterOrVerify::REGISTER;
  } else {
    return torch::_RegisterOrVerify::VERIFY;
  }
}

static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    const py::args& args,
    const py::kwargs& kwargs) {
  auto stack = torch::jit::createStackForSchema(
      handle.schema(),
      args,
      kwargs,
      /*self=*/std::nullopt);
  {
    pybind11::gil_scoped_release no_gil_guard;
    handle.callBoxed(stack);
  }
  return torch::jit::createPyObjectForStack(std::move(stack));
}

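// Illustrative use from Python, via the bindings defined later in this file:
//
//   op = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
//   out = torch._C._dispatch_call_boxed(op, x, y)
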
// A small RAII guard that lets you explicitly *remove* a key from the TLS
// exclude set.
class SetExcludeDispatchKeyGuard {
 public:
  SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded)
      : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) {
    c10::impl::tls_set_dispatch_key_excluded(k, set_excluded);
  }
  ~SetExcludeDispatchKeyGuard() {
    c10::impl::tls_set_dispatch_key_excluded(k, old);
  }
  SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete;
  SetExcludeDispatchKeyGuard operator=(const SetExcludeDispatchKeyGuard&) =
      delete;
  SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete;
  SetExcludeDispatchKeyGuard operator=(SetExcludeDispatchKeyGuard&&) = delete;

 private:
  at::DispatchKey k;
  bool old;
};

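// Exposed below as the _SetExcludeDispatchKeyGuard context manager; a sketch
// of its use from Python:
//
//   with torch._C._SetExcludeDispatchKeyGuard(
//           torch._C.DispatchKey.Python, False):
//       ...  # the Python key is guaranteed not excluded in this block
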
void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema)
      .def("debug", &c10::OperatorHandle::debug)
      .def(
          "redispatch_boxed",
          [](const py::object& self,
             c10::DispatchKeySet keyset,
             py::args args,
             const py::kwargs& kwargs) {
            auto& handle = self.cast<c10::OperatorHandle&>();
            auto stack = torch::jit::createStackForSchema(
                handle.schema(),
                std::move(args),
                kwargs,
                /*self=*/std::nullopt);
            {
              pybind11::gil_scoped_release no_gil_guard;
              handle.redispatchBoxed(keyset, &stack);
            }
            return torch::jit::createPyObjectForStack(std::move(stack));
          });

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);

  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      .def(
          "reset",
          [](const py::object& self) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().reset();
            return;
          },
          "")
      // Some of these APIs are only for testing and do not work in a multipy
      // environment
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from the RegisterOperators()
      // API, but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher. So instead we provide a bunch of precanned
      // functions for testing purposes. You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something
      //
      // Mangling scheme: args_rets. One character per argument/return.
      // t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().impl(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl_with_aoti_compile",
          [](const py::object& self,
             const char* ns,
             const char* op_name_with_overload,
             c10::DispatchKey dispatch) {
            HANDLE_TH_ERRORS
            std::string reg_op_name =
                std::string(ns).append("::").append(op_name_with_overload);

            auto& lib = self.cast<torch::Library&>();
            lib.impl(
                reg_op_name.c_str(),
                torch::dispatch(
                    dispatch,
                    CppFunction::makeFromBoxedFunctor(
                        std::make_unique<
                            torch::inductor::AOTIPythonKernelHolder>(
                            dispatch, ns, op_name_with_overload))),
                register_or_verify());
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("ns"),
          py::arg("op_name_with_overload"),
          py::arg("dispatch"))
      .def(
          "impl",
          [](const py::object& self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.impl(
                  name,
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
                  register_or_verify());
            } else {
              lib.impl(
                  name,
                  torch::dispatch(
                      dispatch,
                      CppFunction::makeFromBoxedFunctor(
                          std::make_unique<PythonKernelHolder>(
                              func, dispatch, with_keyset))),
                  register_or_verify());
              python_registrations_[lib._resolve(name)].insert_or_assign(
                  dispatch,
                  std::make_shared<c10::SafePyObject>(
                      func.release().ptr(), getPyInterpreter()));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false)
      .def(
          "define",
          [](const py::object& self,
             const char* schema,
             const char* alias_analysis,
             const std::vector<at::Tag>& tags) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), tags, register_or_verify());
            // TODO: this is dumb, had to make a second copy
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                .name();
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "",
          py::arg("tags") = std::vector<at::Tag>())
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "")
      .def(
          "fallback",
          [](const py::object& self,
             c10::DispatchKey dispatch,
             const py::object& func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.fallback(
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()));
            } else {
              lib.fallback(torch::dispatch(
                  dispatch,
                  CppFunction::makeFromBoxedFunctor(
                      std::make_unique<PythonKernelHolder>(
                          func, dispatch, with_keyset, /*with_op*/ true))));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false);

  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* file,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? std::nullopt
                : std::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

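  // Sketch of how this binding is reached from Python (torch.library.Library
  // calls it under the hood; shown here only for orientation):
  //
  //   lib = torch._C._dispatch_library("FRAGMENT", "mylib", "")
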
  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

  m.def("_dispatch_check_invariants", [](const char* name) {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (op) {
      op->checkInvariants();
    }
  });

  m.def("_dispatch_check_all_invariants", []() {
    c10::Dispatcher::singleton().checkInvariants();
  });

  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    return static_cast<bool>(op);
  });

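  // Illustrative debugging calls from Python, using the bindings above:
  //
  //   torch._C._dispatch_dump("aten::add")        # raw registration state
  //   torch._C._dispatch_dump_table("aten::add")  # computed dispatch table
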
  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <op_name, dispatch_key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      // Returns whether or not the kernel for this dispatch key is a
      // fallthrough kernel
      "_dispatch_kernel_for_dispatch_key_is_fallthrough",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        return op->isKernelFallthroughKernel(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table for this <op_name, dispatch_key> pair. For example, if
      // "op" has a `CompositeImplicitAutograd` kernel, then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will
      // return true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });

m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
|
|
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
|
|
|
|
std::vector<std::string> states;
|
|
states.reserve(danglingImpls.size());
|
|
for (auto& danglingImpl : danglingImpls) {
|
|
states.emplace_back(danglingImpl.dumpState());
|
|
}
|
|
|
|
return states;
|
|
});
|
|
|
|
m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
|
|
auto op_names = c10::Dispatcher::singleton().getAllOpNames();
|
|
|
|
std::vector<std::string> names;
|
|
names.reserve(op_names.size());
|
|
for (auto& op : op_names) {
|
|
std::stringstream ss;
|
|
ss << op.name;
|
|
if (!op.overload_name.empty()) {
|
|
ss << "." << op.overload_name;
|
|
}
|
|
names.emplace_back(ss.str());
|
|
}
|
|
|
|
return names;
|
|
});
|
|
|
|
  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

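  // Illustrative round trip from Python:
  //
  //   k = torch._C.DispatchKey.Python
  //   torch._C._dispatch_tls_set_dispatch_key_excluded(k, True)
  //   assert torch._C._dispatch_tls_is_dispatch_key_excluded(k)
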
m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
|
|
return at::isTensorSubclassLike(tensor);
|
|
});
|
|
|
|
m.def("_dispatch_key_name", [](c10::DispatchKey k) {
|
|
return c10::toString(k);
|
|
});
|
|
m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
|
|
m.def("_to_functionality_key", [](c10::DispatchKey k) {
|
|
return c10::toFunctionalityKey(k);
|
|
});
|
|
// E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
|
|
// AutogradCPU
|
|
// AutogradCUDA
|
|
// ...
|
|
// AutogradPrivateUse3
|
|
m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
|
|
std::vector<c10::DispatchKey> keys;
|
|
if (c10::isPerBackendFunctionalityKey(key)) {
|
|
auto ks = c10::DispatchKeySet(key) |
|
|
c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
|
|
for (auto k : ks) {
|
|
keys.push_back(k);
|
|
}
|
|
} else {
|
|
keys.push_back(key);
|
|
}
|
|
return keys;
|
|
});
|
|
m.def("_dispatch_num_backends", []() { return c10::num_backends; });
|
|
|
|
#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      // clang-format off
      DEF_ONE(Undefined)
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      // NestedTensor is not a backend key
      DEF_ONE(AutogradNestedTensor)
      DEF_ONE(AutogradOther)
      DEF_ONE(Autograd)
      DEF_ONE(Conjugate)
      DEF_ONE(ZeroTensor)
      DEF_ONE(Negative)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(Python)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(PreDispatch)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCPU)
      DEF_ONE(AutocastXPU)
      DEF_ONE(AutocastHPU)
      DEF_ONE(AutocastIPU)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)
      // clang-format on

#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

      // clang-format off
      C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
      // clang-format on

#undef DEF_MULTIPLE
#undef DEF_SINGLE
      ;

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("raw_repr", &c10::DispatchKeySet::raw_repr)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });

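  // Illustrative use from Python:
  //
  //   ks = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
  //   ks = ks.add(torch._C.DispatchKey.AutogradCPU)
  //   assert ks.has(torch._C.DispatchKey.CPU)
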
m.attr("_dispatch_autogradother_backends") =
|
|
py::cast(c10::autogradother_backends);
|
|
|
|
m.attr("_additional_keys_to_prop_for_wrapper_tensors") =
|
|
py::cast(at::functorch::kKeysToPropagateToWrapper);
|
|
|
|
m.attr("_after_autograd_keyset") = py::cast(c10::after_autograd_keyset);
|
|
m.attr("_after_ADInplaceOrView_keyset") =
|
|
py::cast(c10::after_ADInplaceOrView_keyset);
|
|
|
|
m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
|
|
return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
|
|
});
|
|
|
|
m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
|
|
return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
|
|
});
|
|
|
|
m.def("_dispatch_keyset_full", []() {
|
|
return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
|
|
});
|
|
|
|
m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);
|
|
|
|
m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
|
|
return c10::toString(keyset);
|
|
});
|
|
|
|
m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
|
|
return c10::getBackendKeySetFromAutograd(k);
|
|
});
|
|
|
|
m.def("_dispatch_keys", [](const at::Tensor& tensor) {
|
|
auto* impl = tensor.unsafeGetTensorImpl();
|
|
return impl->key_set();
|
|
});
|
|
m.def("_dispatch_tls_local_include_set", []() {
|
|
return c10::impl::tls_local_dispatch_key_set().included_;
|
|
});
|
|
m.def("_dispatch_tls_local_exclude_set", []() {
|
|
return c10::impl::tls_local_dispatch_key_set().excluded_;
|
|
});
|
|
m.def("_functionalization_reapply_views_tls", []() {
|
|
return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
|
|
});
|
|
m.def(
|
|
"_dispatch_is_included_in_alias",
|
|
[](c10::DispatchKey a, c10::DispatchKey b) {
|
|
return c10::isIncludedInAlias(a, b);
|
|
});
|
|
|
|
  // DEPRECATED, please don't use this. Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");

  py_context_manager<
      c10::impl::ForceDispatchKeyGuard,
      c10::DispatchKeySet,
      c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
  py_context_manager<c10::impl::ForceDispatchKeyGuard>(
      m, "_PreserveDispatchKeyGuard");
  py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
      m, "_IncludeDispatchKeyGuard");
  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");
  py_context_manager<SetExcludeDispatchKeyGuard, c10::DispatchKey, bool>(
      m, "_SetExcludeDispatchKeyGuard");

  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");
  py_context_manager<at::AutoDispatchBelowADInplaceOrView>(
      m, "_AutoDispatchBelowADInplaceOrView");

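  // Each of the above becomes a Python context manager; a sketch of use:
  //
  //   ks = torch._C.DispatchKeySet(torch._C.DispatchKey.Autograd)
  //   with torch._C._ExcludeDispatchKeyGuard(ks):
  //       ...  # Autograd keys are excluded while inside this block
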
  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll
  // print out the name of every operator that the Dispatcher knows of. This
  // can be useful to answer questions like "list all operators that do not
  // have a CPU kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << '\n';
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_parse_dispatch_key",
      [](const char* dispatch_key) -> std::optional<c10::DispatchKey> {
        try {
          return c10::parseDispatchKey(dispatch_key);
        } catch (const c10::Error&) {
          return std::nullopt;
        }
      });

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));
  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def(
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
    return c10::Dispatcher::singleton().getPyStub(
        c10::OperatorName(name, overload));
  });

m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
|
|
return at::functionalization::impl::replace_(a, b);
|
|
});
|
|
m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
|
|
at::functionalization::impl::propagate_xla_data(a, b);
|
|
});
|
|
m.def("_commit_update", [](const at::Tensor& a) {
|
|
return at::functionalization::impl::commit_update(a);
|
|
});
|
|
m.def("_unsafe_reset_storage", [](const at::Tensor& a) {
|
|
return at::functionalization::impl::unsafe_reset_storage(a);
|
|
});
|
|
|
|
m.def("_dispatch_key_for_device", [](const std::string& device_type) {
|
|
auto device = c10::Device(device_type);
|
|
TORCH_CHECK(
|
|
!device.has_index(),
|
|
"Expected device_type string to not have a device index; got ",
|
|
device_type);
|
|
return c10::toString(
|
|
c10::computeDispatchKey(std::nullopt, std::nullopt, device));
|
|
});
|
|
|
|
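  // Illustrative: torch._C._dispatch_key_for_device("cuda") returns "CUDA",
  // while "cuda:0" raises because the string carries a device index.
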
m.def("_are_functorch_transforms_active", []() {
|
|
auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
|
|
return (
|
|
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
|
|
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
|
|
});
|
|
|
|
m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
|
|
return c10::SymInt(c10::SymNode(
|
|
c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
|
|
});
|
|
|
|
m.def("_get_constant_bool_symnode", [](int64_t data) {
|
|
return c10::SymNode(
|
|
c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
|
|
});
|
|
|
|
m.def("_non_sym_sizes", [](const at::Tensor& a) {
|
|
return a.sizes(); // NB: NOT sym_size
|
|
});
|
|
|
|
m.def("_set_throw_on_mutable_data_ptr", [](const at::Tensor& t) {
|
|
if (!t.unsafeGetTensorImpl()->has_storage()) {
|
|
// If the Tensor doesn't have a storage, then accessing .data_ptr()
|
|
// will already raise an error.
|
|
return;
|
|
}
|
|
// Otherwise, set (on the StorageImpl) that accessing (mutable) data_ptr
|
|
// will throw.
|
|
t.unsafeGetTensorImpl()
|
|
->storage()
|
|
.unsafeGetStorageImpl()
|
|
->set_throw_on_mutable_data_ptr();
|
|
});
|
|
|
|
  // Invariant: you must ONLY call this with FakeTensors.
  m.def("_set_warn_deprecated_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_warn_deprecated_on_mutable_data_ptr();
  });

m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
|
|
m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);
|
|
|
|
using c10::impl::TorchDispatchModeKey;
|
|
py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
|
|
.value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
|
|
.value("PROXY", TorchDispatchModeKey::PROXY)
|
|
.value("FAKE", TorchDispatchModeKey::FAKE);
|
|
}
|
|
|
|
// TODO: dedupe with the kernel
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack,
    bool with_keyset,
    bool with_op) {
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  const auto& func = python_registrations_[op.operator_name()][key];
  TORCH_INTERNAL_ASSERT(func != nullptr);
  auto* pyobj = func->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  auto callable = py::reinterpret_borrow<py::object>(pyobj);
  auto obj = with_op ? with_keyset
          ? callable(
                keyset,
                torch::detail::getTorchApiFunction(op),
                *args_kwargs.first,
                **args_kwargs.second)
          : callable(
                torch::detail::getTorchApiFunction(op),
                *args_kwargs.first,
                **args_kwargs.second)
      : with_keyset ? callable(keyset, *args_kwargs.first, **args_kwargs.second)
                    : callable(*args_kwargs.first, **args_kwargs.second);
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}

} // namespace torch::impl::dispatch