mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[torch.fx] Add support for custom op (#43248)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43248 We add the support of __torch_function__ override for C++ custom op. The logic is the same as the other components, like torch.nn.Module. Refactored some code a little bit to make it reusable. Test Plan: buck test //caffe2/test:fx -- test_torch_custom_ops Reviewed By: bradleyhd Differential Revision: D23203204 fbshipit-source-id: c462a86e407e46c777171da32d7a40860acf061e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7a77d1c5c2
commit
f15e27265f
@ -1,4 +1,5 @@
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/backends/backend_init.h>
|
||||
@ -899,13 +900,67 @@ void initJITBindings(PyObject* module) {
|
||||
docstring << " " << op->schema() << "\n";
|
||||
}
|
||||
|
||||
return py::cpp_function(
|
||||
[operations](py::args args, py::kwargs kwargs) {
|
||||
auto func = py::cpp_function(
|
||||
[operations, symbol](py::args args, py::kwargs kwargs) {
|
||||
std::vector<py::handle> overloaded_args;
|
||||
size_t total_arg_num = args.size() + kwargs.size();
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
is_tensor_and_append_overloaded(
|
||||
args[i].ptr(), &overloaded_args);
|
||||
is_tensor_list_and_append_overloaded(
|
||||
args[i].ptr(),
|
||||
&overloaded_args,
|
||||
static_cast<int>(total_arg_num),
|
||||
false /* throw_error */);
|
||||
}
|
||||
// NB: for kwargs, we cannot guarantee the order of appending
|
||||
// is the same as the argument order in operator's schema.
|
||||
// This is suboptimal, but should be fine. Later when we have
|
||||
// better schema matching and argument parsing, we could
|
||||
// match the operator in `operations` first, then the order will
|
||||
// be guaranteed.
|
||||
for (auto item : kwargs) {
|
||||
is_tensor_and_append_overloaded(
|
||||
item.second.ptr(), &overloaded_args);
|
||||
is_tensor_list_and_append_overloaded(
|
||||
item.second.ptr(),
|
||||
&overloaded_args,
|
||||
total_arg_num,
|
||||
false /* throw_error */);
|
||||
}
|
||||
if (overloaded_args.size() > 0) {
|
||||
std::vector<py::object> overloaded_types;
|
||||
overloaded_types.reserve(overloaded_args.size());
|
||||
for (auto& oarg : overloaded_args) {
|
||||
overloaded_types.push_back(
|
||||
py::reinterpret_borrow<py::object>(
|
||||
(PyObject*)Py_TYPE(oarg.ptr())));
|
||||
}
|
||||
py::tuple py_types = py::cast(overloaded_types);
|
||||
py::object ret;
|
||||
std::string ns = symbol.ns().toUnqualString();
|
||||
std::string method_name = symbol.toUnqualString();
|
||||
auto self_func = py::module::import("torch")
|
||||
.attr("ops")
|
||||
.attr(ns.c_str())
|
||||
.attr(method_name.c_str());
|
||||
std::string module_name("torch.ops");
|
||||
module_name.append(ns);
|
||||
return pybind11::reinterpret_steal<py::object>(
|
||||
handle_torch_function_no_python_arg_parser(
|
||||
overloaded_args,
|
||||
args.ptr(),
|
||||
kwargs.ptr(),
|
||||
method_name.c_str(),
|
||||
self_func.ptr(),
|
||||
module_name.c_str()));
|
||||
}
|
||||
return invokeOperatorFromPython(
|
||||
operations, std::move(args), std::move(kwargs));
|
||||
},
|
||||
py::name(symbol.toUnqualString()),
|
||||
py::doc(docstring.str().c_str()));
|
||||
return func;
|
||||
} catch (const c10::Error& error) {
|
||||
throw std::runtime_error(error.what_without_backtrace());
|
||||
}
|
||||
|
Reference in New Issue
Block a user