We discussed in a composability meeting a few weeks ago that `pre_autograd` should probably be renamed to `pre_dispatch`.

One question in this PR was: should I re-use a dispatch key? Or should I create a new dispatch key (that yet again corresponds to "top of the dispatcher")? ~~For now, I ended up sticking our proxy mode on the mode stack corresponding to `PythonTLSSnapshot`, because it was simple and it works. It looks like one of the functorch dispatch keys has higher priority though, so it's possible that functorch will end up running first. Open to options, but we can consider adding a new dispatch key later if that becomes a problem.~~ Update: I added a dedicated dispatch key, `PreDispatch`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101818
Approved by: https://github.com/ezyang, https://github.com/Neilblaze, https://github.com/albanD, https://github.com/zou3519
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>

#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>

#include <iostream>

namespace py = pybind11;

namespace torch {
namespace impl {
namespace dispatch {

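// Rough overview: this file provides the Python bindings for the dispatcher
// (exposed from Python as torch._C._dispatch_* plus the DispatchKey and
// DispatchKeySet types), and the machinery for registering Python callables
// as boxed dispatcher kernels (see PythonKernelHolder below).
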
// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_;

static torch::Library::Kind parseKind(const std::string& k) {
  static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
      {"DEF", torch::Library::DEF},
      {"IMPL", torch::Library::IMPL},
      {"FRAGMENT", torch::Library::FRAGMENT},
  };
  auto it = kind_map.find(k);
  TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
  return it->second;
}

static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
      {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
      {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
      {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
      {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
  };
  auto it = key_map.find(k);
  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
  return it->second;
}

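// Helper: wrap a C++ function for registration under an optionally empty
// dispatch key string. Illustratively, dispatch_str("CPU", fn) behaves like
// torch::dispatch(c10::DispatchKey::CPU, fn), while dispatch_str("", fn)
// produces a catch-all kernel with no key.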
template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  auto mb_key = std::string(key).empty()
      ? c10::nullopt
      : c10::make_optional(c10::parseDispatchKey(key));
  if (mb_key) {
    return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
  } else {
    torch::CppFunction f(std::forward<Func>(raw_f));
    return f;
  }
}

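// RAII guard for calling into Python kernels: roughly, it turns on hermetic
// PyObject mode and masks the Python/PythonTLSSnapshot dispatch keys so that
// re-entrant dispatches do not bounce back into Python; all four saved TLS
// values are restored on destruction.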
struct EnableHermeticPyObject {
  EnableHermeticPyObject()
      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
        old_excluded_python_(
            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
        old_python_(
            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
            at::DispatchKey::PythonTLSSnapshot)) {
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }
  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }
  bool old_;
  bool old_excluded_python_;
  bool old_python_;
  bool old_python_snapshot_;
};

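// Boxed kernel that forwards an operator call to a registered Python
// callable. operator() resolves where to run in three steps: an active torch
// dispatch mode wins; otherwise a tensor argument carrying the Python key
// dispatches to its own interpreter; otherwise the call runs hermetically on
// the current interpreter.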
class PythonKernelHolder : public c10::OperatorKernel {
  c10::SafePyObject func_;
  c10::DispatchKey dispatch_key_;

 public:
  PythonKernelHolder(py::object func, c10::DispatchKey dispatch_key)
      : func_(func.release().ptr(), getPyInterpreter()),
        dispatch_key_(dispatch_key) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // Figure out if we can handle it hermetically, or if we have
    // to double dispatch

    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    if (mode_stack_len > 0) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      cur_torch_dispatch_mode_state->pyinterpreter()
          ->python_op_registration_trampoline(op, dispatch_key_, stack);
      return;
    }

    const auto& schema = op.schema();
    const auto num_arguments = schema.arguments().size();

    // Otherwise, find a PyInterpreter on a Tensor if it has the Python key
    // (which means it's a nontrivial tensor subclass)
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
      if (ivalue.isTensor()) {
        auto* interpreter =
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
        if (interpreter &&
            ivalue.unsafeToTensorImpl()->key_set().has(
                at::DispatchKey::Python)) {
          (*interpreter)
              ->python_op_registration_trampoline(op, dispatch_key_, stack);
          return;
        }
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          if (nv.isNone()) {
            continue;
          }
          auto* interpreter =
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
          if (interpreter &&
              nv.unsafeToTensorImpl()->key_set().has(
                  at::DispatchKey::Python)) {
            (*interpreter)
                ->python_op_registration_trampoline(op, dispatch_key_, stack);
            return;
          }
        }
      }
    }

    // Nothing requires the operator to be homed to a specific interpreter, so
    // run it on the current interpreter

    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
    py::gil_scoped_acquire g;
    EnableHermeticPyObject g2;
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
    auto obj = py::reinterpret_steal<py::object>(PyObject_Call(
        func_.ptr(getPyInterpreter()),
        args_kwargs.first.ptr(),
        args_kwargs.second.ptr()));
    if (!obj) {
      throw python_error();
    }
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
  }
};

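// Only the main interpreter performs real registrations; other interpreters
// (e.g. under multipy) merely verify against what is already registered.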
static torch::_RegisterOrVerify register_or_verify() {
  if (isMainPyInterpreter()) {
    return torch::_RegisterOrVerify::REGISTER;
  } else {
    return torch::_RegisterOrVerify::VERIFY;
  }
}

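// Exposed to Python as torch._C._dispatch_call_boxed. Illustrative use:
//   handle = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
//   out = torch._C._dispatch_call_boxed(handle, x, y)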
static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    py::args args,
    py::kwargs kwargs) {
  auto stack = torch::jit::createStackForSchema(
      handle.schema(),
      args,
      kwargs,
      /*self=*/c10::nullopt);
  {
    pybind11::gil_scoped_release no_gil_guard;
    handle.callBoxed(stack);
  }
  return torch::jit::createPyObjectForStack(std::move(stack));
}

void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema);

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);

  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      // Some of these APIs are only for testing and do not work in a multipy
      // environment
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from the RegisterOperators()
      // API, but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher. So instead we provide a bunch of precanned
      // functions for testing purposes. You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something
      //
      // Mangling scheme: args_rets. One character per.
      // t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().impl(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl",
          [](py::object self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            lib.impl(
                name,
                torch::dispatch(
                    dispatch,
                    CppFunction::makeFromBoxedFunctor(
                        std::make_unique<PythonKernelHolder>(func, dispatch))),
                register_or_verify());
            python_registrations_[lib._resolve(name)].insert_or_assign(
                dispatch,
                std::make_shared<c10::SafePyObject>(
                    func.release().ptr(), getPyInterpreter()));
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"))
      .def(
          "define",
          [](py::object self, const char* schema, const char* alias_analysis) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), {}, register_or_verify());
            // TODO: this is dumb, had to make a second copy
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                .name();
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "")
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "");

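  // Illustrative Python flow (hypothetical op name), mirroring what
  // torch.library builds on top of these bindings:
  //   lib = torch._C._dispatch_library("FRAGMENT", "aten", "")
  //   lib.define("my_op(Tensor x) -> Tensor")
  //   lib.impl("my_op", torch._C.DispatchKey.CPU, my_python_fn)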
  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* file,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? c10::nullopt
                : c10::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

m.def("_dispatch_check_invariants", [](const char* name) {
|
|
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
|
|
if (!op) {
|
|
} else {
|
|
return op->checkInvariants();
|
|
}
|
|
});
|
|
|
|
m.def("_dispatch_check_all_invariants", []() {
|
|
c10::Dispatcher::singleton().checkInvariants();
|
|
});
|
|
|
|
m.def("_dispatch_has_kernel", [](const char* name) -> bool {
|
|
auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
|
|
return static_cast<bool>(op);
|
|
});
|
|
|
|
  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <op_name, dispatch_key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table for this <op_name, dispatch_key> pair. For example, if
      // "op" has a `CompositeImplicitAutograd` kernel, then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will
      // return true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });

m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
|
|
auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();
|
|
|
|
std::vector<std::string> states;
|
|
states.reserve(danglingImpls.size());
|
|
for (auto& danglingImpl : danglingImpls) {
|
|
states.emplace_back(danglingImpl.dumpState());
|
|
}
|
|
|
|
return states;
|
|
});
|
|
|
|
m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
|
|
auto op_names = c10::Dispatcher::singleton().getAllOpNames();
|
|
|
|
std::vector<std::string> names;
|
|
names.reserve(op_names.size());
|
|
for (auto& op : op_names) {
|
|
std::stringstream ss;
|
|
ss << op.name;
|
|
if (!op.overload_name.empty()) {
|
|
ss << "." << op.overload_name;
|
|
}
|
|
names.emplace_back(ss.str());
|
|
}
|
|
|
|
return names;
|
|
});
|
|
|
|
  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
|
|
return at::isTensorSubclassLike(tensor);
|
|
});
|
|
|
|
m.def("_dispatch_key_name", [](c10::DispatchKey k) {
|
|
return c10::toString(k);
|
|
});
|
|
m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
|
|
m.def("_to_functionality_key", [](c10::DispatchKey k) {
|
|
return c10::toFunctionalityKey(k);
|
|
});
|
|
// E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
|
|
// AutogradCPU
|
|
// AutogradCUDA
|
|
// ...
|
|
// AutogradPrivateUse3
|
|
m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
|
|
std::vector<c10::DispatchKey> keys;
|
|
if (c10::isPerBackendFunctionalityKey(key)) {
|
|
auto ks = c10::DispatchKeySet(key) |
|
|
c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
|
|
for (auto k : ks) {
|
|
keys.push_back(k);
|
|
}
|
|
} else {
|
|
keys.push_back(key);
|
|
}
|
|
return keys;
|
|
});
|
|
m.def("_dispatch_num_backends", []() { return c10::num_backends; });
|
|
|
|
#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      // clang-format off
      DEF_ONE(Undefined)
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      DEF_ONE(AutogradOther)
      DEF_ONE(Autograd)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(Python)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(PreDispatch)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCPU)
      DEF_ONE(AutocastXPU)
      DEF_ONE(AutocastHPU)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)
      // clang-format on

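// The remaining per-backend values are generated. Illustratively,
// DEF_MULTIPLE(Dense, ) should expand to the values "Dense",
// "StartOfDenseBackends", one value per backend component ("CPU", "CUDA",
// ...), and "EndOfDenseBackends".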
#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

      // clang-format off
      C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
      // clang-format on

#undef DEF_MULTIPLE
#undef DEF_SINGLE
      ;

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });

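  // Illustrative use from Python:
  //   ks = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
  //   ks = ks.add(torch._C.DispatchKey.AutogradCPU)
  //   assert ks.has(torch._C.DispatchKey.CPU)
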
m.attr("_dispatch_autogradother_backends") =
|
|
py::cast(c10::autogradother_backends);
|
|
|
|
m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
|
|
return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
|
|
});
|
|
|
|
m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
|
|
return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
|
|
});
|
|
|
|
m.def("_dispatch_keyset_full", []() {
|
|
return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
|
|
});
|
|
|
|
m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);
|
|
|
|
m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
|
|
return c10::toString(keyset);
|
|
});
|
|
|
|
m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
|
|
return c10::getBackendKeySetFromAutograd(k);
|
|
});
|
|
|
|
m.def("_dispatch_keys", [](const at::Tensor& tensor) {
|
|
auto* impl = tensor.unsafeGetTensorImpl();
|
|
return impl->key_set();
|
|
});
|
|
m.def("_dispatch_tls_local_include_set", []() {
|
|
return c10::impl::tls_local_dispatch_key_set().included_;
|
|
});
|
|
m.def("_dispatch_tls_local_exclude_set", []() {
|
|
return c10::impl::tls_local_dispatch_key_set().excluded_;
|
|
});
|
|
m.def("_functionalization_reapply_views_tls", []() {
|
|
return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
|
|
});
|
|
m.def(
|
|
"_dispatch_is_included_in_alias",
|
|
[](c10::DispatchKey a, c10::DispatchKey b) {
|
|
return c10::isIncludedInAlias(a, b);
|
|
});
|
|
|
|
  // DEPRECATED, please don't use this. Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");

  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");

  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");

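  // Illustrative use from Python:
  //   ks = torch._C.DispatchKeySet(torch._C.DispatchKey.Autograd)
  //   with torch._C._ExcludeDispatchKeyGuard(ks):
  //       ...  # ops dispatched here skip the excluded keys
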
  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll
  // print out the name of every operator that the Dispatcher knows of. This
  // can be useful to answer questions like "list all operators that do not
  // have a CPU kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << std::endl;
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? c10::nullopt
            : c10::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));
  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def(
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });

m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
|
|
return at::functionalization::impl::replace_(a, b);
|
|
});
|
|
m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
|
|
at::functionalization::impl::propagate_xla_data(a, b);
|
|
});
|
|
m.def("_commit_update", [](const at::Tensor& a) {
|
|
return at::functionalization::impl::commit_update(a);
|
|
});
|
|
|
|
m.def("_are_functorch_transforms_active", []() {
|
|
auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
|
|
return (
|
|
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
|
|
include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
|
|
});
|
|
}
|
|
|
|
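// Reached via PyInterpreter::python_op_registration_trampoline from
// PythonKernelHolder above: looks up the Python callable registered for
// (op, key) in python_registrations_ and invokes it on this interpreter.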
// TODO: dedupe with the kernel
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    torch::jit::Stack* stack) {
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  const auto& func = python_registrations_[op.operator_name()][key];
  TORCH_INTERNAL_ASSERT(func != nullptr);
  auto* pyobj = func->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  auto obj = py::reinterpret_steal<py::object>(
      PyObject_Call(pyobj, args_kwargs.first.ptr(), args_kwargs.second.ptr()));
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}

} // namespace dispatch
} // namespace impl
} // namespace torch