mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[TorchBind] Support using lambda function as TorchBind constructor (#47819)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47819 Reviewed By: wanchaol Differential Revision: D24910065 Pulled By: gmagogsfm fbshipit-source-id: ad5b4f67b0367e44fe486d31a060d9ad1e0cf568
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b6cb2caa68
commit
00a3add425
@ -33,6 +33,14 @@ struct Foo : torch::CustomClassHolder {
|
||||
}
|
||||
};
|
||||
|
||||
struct LambdaInit : torch::CustomClassHolder {
|
||||
int x, y;
|
||||
LambdaInit(int x_, int y_) : x(x_), y(y_) {}
|
||||
int64_t diff() {
|
||||
return this->x - this->y;
|
||||
}
|
||||
};
|
||||
|
||||
struct NoInit : torch::CustomClassHolder {
|
||||
int64_t x;
|
||||
};
|
||||
@ -202,6 +210,16 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
||||
.def("add", &Foo::add)
|
||||
.def("combine", &Foo::combine);
|
||||
|
||||
m.class_<LambdaInit>("_LambdaInit")
|
||||
.def(torch::init([](int64_t x, int64_t y, bool swap) {
|
||||
if (swap) {
|
||||
return c10::make_intrusive<LambdaInit>(y, x);
|
||||
} else {
|
||||
return c10::make_intrusive<LambdaInit>(x, y);
|
||||
}
|
||||
}))
|
||||
.def("diff", &LambdaInit::diff);
|
||||
|
||||
m.class_<NoInit>("_NoInit").def(
|
||||
"get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });
|
||||
|
||||
|
@ -338,3 +338,10 @@ class TestTorchbind(JitTestCase):
|
||||
foo = torch.classes._TorchScriptTesting._StackString(["test"])
|
||||
with self.assertRaisesRegex(AttributeError, 'does not have a field'):
|
||||
foo.bar
|
||||
|
||||
def test_lambda_as_constructor(self):
|
||||
obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
|
||||
self.assertEqual(obj_no_swap.diff(), 1)
|
||||
|
||||
obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
|
||||
self.assertEqual(obj_swap.diff(), -1)
|
||||
|
@ -27,6 +27,21 @@ detail::types<void, Types...> init() {
|
||||
return detail::types<void, Types...>{};
|
||||
}
|
||||
|
||||
template <typename Func, typename... ParameterTypeList>
|
||||
struct InitLambda {
|
||||
Func f;
|
||||
};
|
||||
|
||||
template <typename Func>
|
||||
decltype(auto) init(Func&& f) {
|
||||
using InitTraits =
|
||||
c10::guts::infer_function_traits_t<std::decay_t<Func>>;
|
||||
using ParameterTypeList = typename InitTraits::parameter_types;
|
||||
|
||||
InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
|
||||
return init;
|
||||
}
|
||||
|
||||
/// Entry point for custom C++ class registration. To register a C++ class
|
||||
/// in PyTorch, instantiate `torch::class_` with the desired class as the
|
||||
/// template parameter. Typically, this instantiation should be done in
|
||||
@ -95,6 +110,24 @@ class class_ {
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Used in combination with torch::init([]lambda(){......})
|
||||
template <typename Func, typename... ParameterTypes>
|
||||
class_& def(
|
||||
InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
|
||||
std::string doc_string = "") {
|
||||
auto init_lambda_wrapper = [func = std::move(init.f)](
|
||||
c10::tagged_capsule<CurClass> self,
|
||||
ParameterTypes... arg) {
|
||||
c10::intrusive_ptr<CurClass> classObj =
|
||||
at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
|
||||
auto object = self.ivalue.toObject();
|
||||
object->setSlot(0, c10::IValue::make_capsule(classObj));
|
||||
};
|
||||
defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string));
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// This is the normal method registration API. `name` is the name that
|
||||
/// the method will be made accessible by in Python and TorchScript.
|
||||
/// `f` is a callable object that defines the method. Typically `f`
|
||||
|
Reference in New Issue
Block a user