mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Back out "Revert D21089648: Put TORCH_LIBRARY in torch/library.h; add custom class API"
Summary: Original commit changeset: 636e8a11afc6 Test Plan: export to OSS Reviewed By: malfet Differential Revision: D21170502 fbshipit-source-id: e8f35f103c4924aedbcaaf868475008d24bdeeab
This commit is contained in:
committed by
Facebook GitHub Bot
parent
3b832ee2bf
commit
a894fff265
@ -1,7 +1,7 @@
|
||||
#include <torch/csrc/utils/python_dispatch.h>
|
||||
#include <torch/csrc/jit/frontend/function_schema_parser.h>
|
||||
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <torch/library.h>
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include <pybind11/operators.h>
|
||||
@ -45,12 +45,12 @@ c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
|
||||
|
||||
|
||||
template <typename Func>
|
||||
inline c10::CppFunction dispatch_str(const char* key, Func&& raw_f) {
|
||||
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
|
||||
auto mb_key = parseDispatchKey(key);
|
||||
if (mb_key) {
|
||||
return c10::dispatch(*mb_key, std::move(raw_f));
|
||||
return torch::dispatch(*mb_key, std::forward<Func>(raw_f));
|
||||
} else {
|
||||
c10::CppFunction f(std::forward<Func>(raw_f));
|
||||
torch::CppFunction f(std::forward<Func>(raw_f));
|
||||
return f;
|
||||
}
|
||||
}
|
||||
@ -62,16 +62,16 @@ void initDispatchBindings(PyObject* module) {
|
||||
.def("schema", &c10::OperatorHandle::schema);
|
||||
|
||||
// TODO: figure out how to do chaining
|
||||
py::class_<c10::Library>(m, "_DispatchModule")
|
||||
py::class_<torch::Library>(m, "_DispatchModule")
|
||||
.def("def_", [](py::object self, const char* schema, const char* alias) {
|
||||
self.cast<c10::Library&>().def(torch::schema(schema, parseAliasAnalysisKind(alias)));
|
||||
self.cast<torch::Library&>().def(torch::schema(schema, parseAliasAnalysisKind(alias)));
|
||||
return self;
|
||||
}, "", py::arg("schema"), py::arg("alias") = "")
|
||||
// Simulated "legacy" def where alias analysis kind is not set.
|
||||
// Ordinarily this can only be exercised from RegisterOperators() API
|
||||
// but I am not going to bind that here
|
||||
.def("def_legacy", [](py::object self, const char* schema) {
|
||||
self.cast<c10::Library&>().def(torch::jit::parseSchema(schema));
|
||||
self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
|
||||
return self;
|
||||
}, "", py::arg("schema"))
|
||||
// We can't conveniently turn Python functions into valid functions
|
||||
@ -83,7 +83,7 @@ void initDispatchBindings(PyObject* module) {
|
||||
// Mangling scheme: args_rets. One character per.
|
||||
// t = Tensor
|
||||
.def("def_name_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
|
||||
self.cast<c10::Library&>().def(
|
||||
self.cast<torch::Library&>().def(
|
||||
name,
|
||||
dispatch_str(dispatch, [](const at::Tensor& a) {
|
||||
return a;
|
||||
@ -94,7 +94,7 @@ void initDispatchBindings(PyObject* module) {
|
||||
py::arg("dispatch") = "",
|
||||
py::arg("debug") = "default_def_name_t_t")
|
||||
.def("def_schema_t_t", [](py::object self, const char* schema, const char* dispatch, const char* alias, const char* debug) {
|
||||
self.cast<c10::Library&>().def(
|
||||
self.cast<torch::Library&>().def(
|
||||
torch::schema(schema, parseAliasAnalysisKind(alias)),
|
||||
dispatch_str(dispatch, [](const at::Tensor& a) {
|
||||
return a;
|
||||
@ -108,7 +108,7 @@ void initDispatchBindings(PyObject* module) {
|
||||
// TODO: maybe consider deduplicating the definitions here, it's getting
|
||||
// pretty long
|
||||
.def("impl_t_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
|
||||
self.cast<c10::Library&>().impl(
|
||||
self.cast<torch::Library&>().impl(
|
||||
name,
|
||||
dispatch_str(dispatch, [](const at::Tensor& a) {
|
||||
return a;
|
||||
@ -119,7 +119,7 @@ void initDispatchBindings(PyObject* module) {
|
||||
py::arg("dispatch") = "",
|
||||
py::arg("debug") = "impl_t_t")
|
||||
.def("impl_tt_t", [](py::object self, const char* name, const char* dispatch, const char* debug) {
|
||||
self.cast<c10::Library&>().impl(
|
||||
self.cast<torch::Library&>().impl(
|
||||
name,
|
||||
dispatch_str(dispatch, [](const at::Tensor& a, const at::Tensor& b) {
|
||||
return a;
|
||||
@ -133,9 +133,9 @@ void initDispatchBindings(PyObject* module) {
|
||||
// This is a wee bit dodgy right now, but the "underlying" API is much
|
||||
// easier to test than the high level (using TORCH_LIBRARY, e.g.)
|
||||
if (name.empty()) {
|
||||
return std::make_unique<c10::Library>(c10::Library::FRAGMENT, "_", c10::DispatchKey::CatchAll, "/dev/null", 0);
|
||||
return std::make_unique<torch::Library>(torch::Library::FRAGMENT, "_", c10::DispatchKey::CatchAll, "/dev/null", 0);
|
||||
} else {
|
||||
return std::make_unique<c10::Library>(c10::Library::FRAGMENT, name, c10::nullopt, "/dev/null", 0);
|
||||
return std::make_unique<torch::Library>(torch::Library::FRAGMENT, name, c10::nullopt, "/dev/null", 0);
|
||||
}
|
||||
});
|
||||
|
||||
|
Reference in New Issue
Block a user