[torch.fx] Add support for custom op (#43248)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43248

We add support for `__torch_function__` overrides for C++ custom ops. The logic is the same as for the other components, like torch.nn.Module.
Refactored some code a little bit to make it reusable.

Test Plan: buck test //caffe2/test:fx -- test_torch_custom_ops

Reviewed By: bradleyhd

Differential Revision: D23203204

fbshipit-source-id: c462a86e407e46c777171da32d7a40860acf061e
This commit is contained in:
Lu Fang
2020-09-02 16:06:42 -07:00
committed by Facebook GitHub Bot
parent 7a77d1c5c2
commit f15e27265f
4 changed files with 165 additions and 38 deletions

View File

@@ -1,4 +1,5 @@
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend_init.h>
@@ -899,13 +900,67 @@ void initJITBindings(PyObject* module) {
docstring << " " << op->schema() << "\n";
}
return py::cpp_function(
[operations](py::args args, py::kwargs kwargs) {
auto func = py::cpp_function(
[operations, symbol](py::args args, py::kwargs kwargs) {
std::vector<py::handle> overloaded_args;
size_t total_arg_num = args.size() + kwargs.size();
for (size_t i = 0; i < args.size(); ++i) {
is_tensor_and_append_overloaded(
args[i].ptr(), &overloaded_args);
is_tensor_list_and_append_overloaded(
args[i].ptr(),
&overloaded_args,
static_cast<int>(total_arg_num),
false /* throw_error */);
}
// NB: for kwargs, we cannot guarantee the order of appending
// is the same as the argument order in operator's schema.
// This is suboptimal, but should be fine. Later when we have
// better schema matching and argument parsing, we could
// match the operator in `operations` first, then the order will
// be guaranteed.
for (auto item : kwargs) {
is_tensor_and_append_overloaded(
item.second.ptr(), &overloaded_args);
is_tensor_list_and_append_overloaded(
item.second.ptr(),
&overloaded_args,
total_arg_num,
false /* throw_error */);
}
if (overloaded_args.size() > 0) {
std::vector<py::object> overloaded_types;
overloaded_types.reserve(overloaded_args.size());
for (auto& oarg : overloaded_args) {
overloaded_types.push_back(
py::reinterpret_borrow<py::object>(
(PyObject*)Py_TYPE(oarg.ptr())));
}
py::tuple py_types = py::cast(overloaded_types);
py::object ret;
std::string ns = symbol.ns().toUnqualString();
std::string method_name = symbol.toUnqualString();
auto self_func = py::module::import("torch")
.attr("ops")
.attr(ns.c_str())
.attr(method_name.c_str());
std::string module_name("torch.ops");
module_name.append(ns);
return pybind11::reinterpret_steal<py::object>(
handle_torch_function_no_python_arg_parser(
overloaded_args,
args.ptr(),
kwargs.ptr(),
method_name.c_str(),
self_func.ptr(),
module_name.c_str()));
}
return invokeOperatorFromPython(
operations, std::move(args), std::move(kwargs));
},
py::name(symbol.toUnqualString()),
py::doc(docstring.str().c_str()));
return func;
} catch (const c10::Error& error) {
throw std::runtime_error(error.what_without_backtrace());
}