mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT] Fix stateful lambda stuff and simplify code in custom C++ binding API
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32658 Test Plan: Imported from OSS Differential Revision: D19584701 Pulled By: jamesr66a fbshipit-source-id: d556c7db2f32900eb1122348402789b59516a7d7
This commit is contained in:
committed by
Facebook Github Bot
parent
465ebd58ba
commit
0ea65d63cf
@ -81,31 +81,13 @@ class class_ {
|
||||
object->setSlot(0, capsule);
|
||||
};
|
||||
|
||||
defineMethod<void>("__init__", std::move(func));
|
||||
defineMethod("__init__", std::move(func));
|
||||
return *this;
|
||||
}
|
||||
template <
|
||||
typename Method,
|
||||
std::enable_if_t<
|
||||
std::is_member_function_pointer<std::decay_t<Method>>::value,
|
||||
bool> = false>
|
||||
class_& def(std::string name, Method&& m) {
|
||||
auto res = def_(
|
||||
std::move(name),
|
||||
std::forward<Method>(m),
|
||||
detail::args_t<std::remove_reference_t<decltype(m)>>{});
|
||||
return *this;
|
||||
}
|
||||
template <
|
||||
typename Func,
|
||||
std::enable_if_t<
|
||||
!std::is_member_function_pointer<std::decay_t<Func>>::value,
|
||||
bool> = false>
|
||||
class_& def(std::string name, Func&& f) {
|
||||
auto res = def_(
|
||||
std::move(name),
|
||||
std::forward<Func>(f),
|
||||
detail::args_t<std::remove_reference_t<decltype(&Func::operator())>>{});
|
||||
template <typename Func>
|
||||
class_& def(std::string name, Func f) {
|
||||
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
|
||||
defineMethod(std::move(name), std::move(wrapped_f));
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -139,7 +121,10 @@ class class_ {
|
||||
auto object = self.ivalue.toObject();
|
||||
object->setSlot(0, capsule);
|
||||
};
|
||||
defineMethod<void>("__setstate__", std::move(setstate_wrapper));
|
||||
defineMethod(
|
||||
"__setstate__",
|
||||
detail::wrap_func<CurClass, decltype(setstate_wrapper)>(
|
||||
std::move(setstate_wrapper)));
|
||||
|
||||
// type validation
|
||||
auto getstate_schema = classTypePtr->getMethod("__getstate__")->getSchema();
|
||||
@ -176,7 +161,7 @@ class class_ {
|
||||
}
|
||||
|
||||
private:
|
||||
template<typename R, typename Func>
|
||||
template <typename Func>
|
||||
void defineMethod(std::string name, Func func) {
|
||||
auto graph = std::make_shared<Graph>();
|
||||
auto qualFuncName = className + "::" + name;
|
||||
@ -213,45 +198,6 @@ class class_ {
|
||||
auto method = classCU()->create_function(qualClassName + "." + name, graph);
|
||||
classTypePtr->addMethod(method);
|
||||
}
|
||||
|
||||
template <
|
||||
typename Func,
|
||||
typename R,
|
||||
typename... Types,
|
||||
std::enable_if_t<
|
||||
std::is_member_function_pointer<std::decay_t<Func>>::value,
|
||||
bool> = false>
|
||||
class_& def_(std::string name, Func f, detail::types<R, Types...> funcInfo) {
|
||||
auto func = [f = std::move(f)](
|
||||
c10::intrusive_ptr<CurClass> cur, Types... args) {
|
||||
return at::guts::invoke(f, *cur, args...);
|
||||
};
|
||||
defineMethod<R>(std::move(name), std::move(func));
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <typename R, typename Head, typename... Tail>
|
||||
void assert_self_type(detail::types<R, Head, Tail...> funcInfo) {
|
||||
static_assert(
|
||||
std::is_same<std::decay_t<Head>, c10::intrusive_ptr<CurClass>>::value,
|
||||
"First argument of a registered lambda method must be an intrusive_ptr<> of the corresponding class.");
|
||||
}
|
||||
|
||||
template <
|
||||
typename Func,
|
||||
typename R,
|
||||
typename... Types,
|
||||
std::enable_if_t<
|
||||
!std::is_member_function_pointer<std::decay_t<Func>>::value,
|
||||
bool> = false>
|
||||
class_& def_(
|
||||
std::string name,
|
||||
Func&& f,
|
||||
detail::types<R, Types...> funcInfo) {
|
||||
assert_self_type(funcInfo);
|
||||
defineMethod<R>(std::move(name), std::forward<Func>(f));
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace jit
|
||||
|
@ -14,19 +14,51 @@ struct types {
|
||||
using type = types;
|
||||
};
|
||||
|
||||
template <class Sig>
|
||||
struct args;
|
||||
template <typename Method>
|
||||
struct WrapMethod;
|
||||
|
||||
// Method
|
||||
template <class R, class CurClass, class... Args>
|
||||
struct args<R (CurClass::*)(Args...)> : types<R, Args...> {};
|
||||
template <typename R, typename CurrClass, typename... Args>
|
||||
struct WrapMethod<R (CurrClass::*)(Args...)> {
|
||||
WrapMethod(R (CurrClass::*m)(Args...)) : m(std::move(m)) {}
|
||||
|
||||
// Const method
|
||||
template <class R, class CurClass, class... Args>
|
||||
struct args<R (CurClass::*)(Args...) const> : types<R, Args...> {};
|
||||
R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
|
||||
return c10::guts::invoke(m, *cur, args...);
|
||||
}
|
||||
|
||||
template <class Sig>
|
||||
using args_t = typename args<Sig>::type;
|
||||
R (CurrClass::*m)(Args...);
|
||||
};
|
||||
|
||||
template <typename R, typename CurrClass, typename... Args>
|
||||
struct WrapMethod<R (CurrClass::*)(Args...) const> {
|
||||
WrapMethod(R (CurrClass::*m)(Args...) const) : m(std::move(m)) {}
|
||||
|
||||
R operator()(c10::intrusive_ptr<CurrClass> cur, Args... args) {
|
||||
return c10::guts::invoke(m, *cur, args...);
|
||||
}
|
||||
|
||||
R (CurrClass::*m)(Args...) const;
|
||||
};
|
||||
|
||||
// Adapter for different callable types
|
||||
template <
|
||||
typename CurClass,
|
||||
typename Func,
|
||||
std::enable_if_t<
|
||||
std::is_member_function_pointer<std::decay_t<Func>>::value,
|
||||
bool> = false>
|
||||
WrapMethod<Func> wrap_func(Func f) {
|
||||
return WrapMethod<Func>(std::move(f));
|
||||
}
|
||||
|
||||
template <
|
||||
typename CurClass,
|
||||
typename Func,
|
||||
std::enable_if_t<
|
||||
!std::is_member_function_pointer<std::decay_t<Func>>::value,
|
||||
bool> = false>
|
||||
Func wrap_func(Func f) {
|
||||
return f;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
|
Reference in New Issue
Block a user