mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TorchBind] Support using lambda function as TorchBind constructor (#47819)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47819 Reviewed By: wanchaol Differential Revision: D24910065 Pulled By: gmagogsfm fbshipit-source-id: ad5b4f67b0367e44fe486d31a060d9ad1e0cf568
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b6cb2caa68
commit
00a3add425
@ -27,6 +27,21 @@ detail::types<void, Types...> init() {
|
||||
return detail::types<void, Types...>{};
|
||||
}
|
||||
|
||||
template <typename Func, typename... ParameterTypeList>
|
||||
struct InitLambda {
|
||||
Func f;
|
||||
};
|
||||
|
||||
template <typename Func>
|
||||
decltype(auto) init(Func&& f) {
|
||||
using InitTraits =
|
||||
c10::guts::infer_function_traits_t<std::decay_t<Func>>;
|
||||
using ParameterTypeList = typename InitTraits::parameter_types;
|
||||
|
||||
InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
|
||||
return init;
|
||||
}
|
||||
|
||||
/// Entry point for custom C++ class registration. To register a C++ class
|
||||
/// in PyTorch, instantiate `torch::class_` with the desired class as the
|
||||
/// template parameter. Typically, this instantiation should be done in
|
||||
@ -95,6 +110,24 @@ class class_ {
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Used in combination with torch::init([]lambda(){......})
|
||||
template <typename Func, typename... ParameterTypes>
|
||||
class_& def(
|
||||
InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
|
||||
std::string doc_string = "") {
|
||||
auto init_lambda_wrapper = [func = std::move(init.f)](
|
||||
c10::tagged_capsule<CurClass> self,
|
||||
ParameterTypes... arg) {
|
||||
c10::intrusive_ptr<CurClass> classObj =
|
||||
at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
|
||||
auto object = self.ivalue.toObject();
|
||||
object->setSlot(0, c10::IValue::make_capsule(classObj));
|
||||
};
|
||||
defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string));
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// This is the normal method registration API. `name` is the name that
|
||||
/// the method will be made accessible by in Python and TorchScript.
|
||||
/// `f` is a callable object that defines the method. Typically `f`
|
||||
|
Reference in New Issue
Block a user