[JIT] Fix stateful lambda stuff and simplify code in custom C++ binding API

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32658

Test Plan: Imported from OSS

Differential Revision: D19584701

Pulled By: jamesr66a

fbshipit-source-id: d556c7db2f32900eb1122348402789b59516a7d7
This commit is contained in:
James Reed
2020-01-28 10:58:28 -08:00
committed by Facebook Github Bot
parent 465ebd58ba
commit 0ea65d63cf
2 changed files with 52 additions and 74 deletions

View File

@ -81,31 +81,13 @@ class class_ {
object->setSlot(0, capsule);
};
defineMethod<void>("__init__", std::move(func));
defineMethod("__init__", std::move(func));
return *this;
}
template <
typename Method,
std::enable_if_t<
std::is_member_function_pointer<std::decay_t<Method>>::value,
bool> = false>
class_& def(std::string name, Method&& m) {
auto res = def_(
std::move(name),
std::forward<Method>(m),
detail::args_t<std::remove_reference_t<decltype(m)>>{});
return *this;
}
template <
typename Func,
std::enable_if_t<
!std::is_member_function_pointer<std::decay_t<Func>>::value,
bool> = false>
class_& def(std::string name, Func&& f) {
auto res = def_(
std::move(name),
std::forward<Func>(f),
detail::args_t<std::remove_reference_t<decltype(&Func::operator())>>{});
template <typename Func>
class_& def(std::string name, Func f) {
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
defineMethod(std::move(name), std::move(wrapped_f));
return *this;
}
@ -139,7 +121,10 @@ class class_ {
auto object = self.ivalue.toObject();
object->setSlot(0, capsule);
};
defineMethod<void>("__setstate__", std::move(setstate_wrapper));
defineMethod(
"__setstate__",
detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
std::move(setstate_wrapper)));
// type validation
auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema();
@ -176,7 +161,7 @@ class class_ {
}
private:
template<typename R, typename Func>
template <typename Func>
void defineMethod(std::string name, Func func) {
auto graph = std::make_shared<Graph>();
auto qualFuncName = className + "::" + name;
@ -213,45 +198,6 @@ class class_ {
auto method = classCU()->create_function(qualClassName + "." + name, graph);
classTypePtr->addMethod(method);
}
template <
typename Func,
typename R,
typename... Types,
std::enable_if_t<
std::is_member_function_pointer<std::decay_t<Func>>::value,
bool> = false>
class_& def_(std::string name, Func f, detail::types<R, Types...> funcInfo) {
auto func = [f = std::move(f)](
c10::intrusive_ptr<CurClass> cur, Types... args) {
return at::guts::invoke(f, *cur, args...);
};
defineMethod<R>(std::move(name), std::move(func));
return *this;
}
template <typename R, typename Head, typename... Tail>
void assert_self_type(detail::types<R, Head, Tail...> funcInfo) {
static_assert(
std::is_same<std::decay_t<Head>, c10::intrusive_ptr<CurClass>>::value,
"First argument of a registered lambda method must be an intrusive_ptr<> of the corresponding class.");
}
template <
typename Func,
typename R,
typename... Types,
std::enable_if_t<
!std::is_member_function_pointer<std::decay_t<Func>>::value,
bool> = false>
class_& def_(
std::string name,
Func&& f,
detail::types<R, Types...> funcInfo) {
assert_self_type(funcInfo);
defineMethod<R>(std::move(name), std::forward<Func>(f));
return *this;
}
};
} // namespace jit

View File

@ -14,19 +14,51 @@ struct types {
using type = types;
};
template <class Sig>
struct args;
template <typename Method>
struct WrapMethod;
// Method
template <class R, class CurClass, class... Args>
struct args<R (CurClass::*)(Args...)> : types<R, Args...> {};
template <typename R, typename CurrClass, typename... Args>
struct WrapMethod<R (CurrClass::*)(Args...)> {
WrapMethod(R (CurrClass::*m)(Args...)) : m(std::move(m)) {}
// Const method
template <class R, class CurClass, class... Args>
struct args<R (CurClass::*)(Args...) const> : types<R, Args...> {};
R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
return c10::guts::invoke(m, *cur, args...);
}
template <class Sig>
using args_t = typename args<Sig>::type;
R (CurrClass::*m)(Args...);
};
template <typename R, typename CurrClass, typename... Args>
struct WrapMethod<R (CurrClass::*)(Args...) const> {
WrapMethod(R (CurrClass::*m)(Args...) const) : m(std::move(m)) {}
R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
return c10::guts::invoke(m, *cur, args...);
}
R (CurrClass::*m)(Args...) const;
};
// Adapter for different callable types
template <
typename CurClass,
typename Func,
std::enable_if_t<
std::is_member_function_pointer<std::decay_t<Func>>::value,
bool> = false>
WrapMethod<Func> wrap_func(Func f) {
return WrapMethod<Func>(std::move(f));
}
template <
typename CurClass,
typename Func,
std::enable_if_t<
!std::is_member_function_pointer<std::decay_t<Func>>::value,
bool> = false>
Func wrap_func(Func f) {
return f;
}
} // namespace detail