[TorchBind] Support using lambda function as TorchBind constructor (#47819)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47819

Reviewed By: wanchaol

Differential Revision: D24910065

Pulled By: gmagogsfm

fbshipit-source-id: ad5b4f67b0367e44fe486d31a060d9ad1e0cf568
This commit is contained in:
Yanan Cao
2020-11-12 09:27:21 -08:00
committed by Facebook GitHub Bot
parent b6cb2caa68
commit 00a3add425
3 changed files with 58 additions and 0 deletions

View File

@ -33,6 +33,14 @@ struct Foo : torch::CustomClassHolder {
}
};
struct LambdaInit : torch::CustomClassHolder {
int x, y;
LambdaInit(int x_, int y_) : x(x_), y(y_) {}
int64_t diff() {
return this->x - this->y;
}
};
struct NoInit : torch::CustomClassHolder {
int64_t x;
};
@ -202,6 +210,16 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
.def("add", &Foo::add)
.def("combine", &Foo::combine);
m.class_<LambdaInit>("_LambdaInit")
.def(torch::init([](int64_t x, int64_t y, bool swap) {
if (swap) {
return c10::make_intrusive<LambdaInit>(y, x);
} else {
return c10::make_intrusive<LambdaInit>(x, y);
}
}))
.def("diff", &LambdaInit::diff);
m.class_<NoInit>("_NoInit").def(
"get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });

View File

@ -338,3 +338,10 @@ class TestTorchbind(JitTestCase):
foo = torch.classes._TorchScriptTesting._StackString(["test"])
with self.assertRaisesRegex(AttributeError, 'does not have a field'):
foo.bar
def test_lambda_as_constructor(self):
obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
self.assertEqual(obj_no_swap.diff(), 1)
obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
self.assertEqual(obj_swap.diff(), -1)

View File

@ -27,6 +27,21 @@ detail::types<void, Types...> init() {
return detail::types<void, Types...>{};
}
template <typename Func, typename... ParameterTypeList>
struct InitLambda {
Func f;
};
template <typename Func>
decltype(auto) init(Func&& f) {
using InitTraits =
c10::guts::infer_function_traits_t<std::decay_t<Func>>;
using ParameterTypeList = typename InitTraits::parameter_types;
InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
return init;
}
/// Entry point for custom C++ class registration. To register a C++ class
/// in PyTorch, instantiate `torch::class_` with the desired class as the
/// template parameter. Typically, this instantiation should be done in
@ -95,6 +110,24 @@ class class_ {
return *this;
}
// Used in combination with torch::init([]lambda(){......})
template <typename Func, typename... ParameterTypes>
class_& def(
InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
std::string doc_string = "") {
auto init_lambda_wrapper = [func = std::move(init.f)](
c10::tagged_capsule<CurClass> self,
ParameterTypes... arg) {
c10::intrusive_ptr<CurClass> classObj =
at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
auto object = self.ivalue.toObject();
object->setSlot(0, c10::IValue::make_capsule(classObj));
};
defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string));
return *this;
}
/// This is the normal method registration API. `name` is the name that
/// the method will be made accessible by in Python and TorchScript.
/// `f` is a callable object that defines the method. Typically `f`