mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[TorchBind] Support using lambda function as TorchBind constructor (#47819)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47819 Reviewed By: wanchaol Differential Revision: D24910065 Pulled By: gmagogsfm fbshipit-source-id: ad5b4f67b0367e44fe486d31a060d9ad1e0cf568
This commit is contained in:
committed by
Facebook GitHub Bot
parent
b6cb2caa68
commit
00a3add425
@ -33,6 +33,14 @@ struct Foo : torch::CustomClassHolder {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct LambdaInit : torch::CustomClassHolder {
|
||||||
|
int x, y;
|
||||||
|
LambdaInit(int x_, int y_) : x(x_), y(y_) {}
|
||||||
|
int64_t diff() {
|
||||||
|
return this->x - this->y;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
struct NoInit : torch::CustomClassHolder {
|
struct NoInit : torch::CustomClassHolder {
|
||||||
int64_t x;
|
int64_t x;
|
||||||
};
|
};
|
||||||
@ -202,6 +210,16 @@ TORCH_LIBRARY(_TorchScriptTesting, m) {
|
|||||||
.def("add", &Foo::add)
|
.def("add", &Foo::add)
|
||||||
.def("combine", &Foo::combine);
|
.def("combine", &Foo::combine);
|
||||||
|
|
||||||
|
m.class_<LambdaInit>("_LambdaInit")
|
||||||
|
.def(torch::init([](int64_t x, int64_t y, bool swap) {
|
||||||
|
if (swap) {
|
||||||
|
return c10::make_intrusive<LambdaInit>(y, x);
|
||||||
|
} else {
|
||||||
|
return c10::make_intrusive<LambdaInit>(x, y);
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.def("diff", &LambdaInit::diff);
|
||||||
|
|
||||||
m.class_<NoInit>("_NoInit").def(
|
m.class_<NoInit>("_NoInit").def(
|
||||||
"get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });
|
"get_x", [](const c10::intrusive_ptr<NoInit>& self) { return self->x; });
|
||||||
|
|
||||||
|
@ -338,3 +338,10 @@ class TestTorchbind(JitTestCase):
|
|||||||
foo = torch.classes._TorchScriptTesting._StackString(["test"])
|
foo = torch.classes._TorchScriptTesting._StackString(["test"])
|
||||||
with self.assertRaisesRegex(AttributeError, 'does not have a field'):
|
with self.assertRaisesRegex(AttributeError, 'does not have a field'):
|
||||||
foo.bar
|
foo.bar
|
||||||
|
|
||||||
|
def test_lambda_as_constructor(self):
|
||||||
|
obj_no_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, False)
|
||||||
|
self.assertEqual(obj_no_swap.diff(), 1)
|
||||||
|
|
||||||
|
obj_swap = torch.classes._TorchScriptTesting._LambdaInit(4, 3, True)
|
||||||
|
self.assertEqual(obj_swap.diff(), -1)
|
||||||
|
@ -27,6 +27,21 @@ detail::types<void, Types...> init() {
|
|||||||
return detail::types<void, Types...>{};
|
return detail::types<void, Types...>{};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Func, typename... ParameterTypeList>
|
||||||
|
struct InitLambda {
|
||||||
|
Func f;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Func>
|
||||||
|
decltype(auto) init(Func&& f) {
|
||||||
|
using InitTraits =
|
||||||
|
c10::guts::infer_function_traits_t<std::decay_t<Func>>;
|
||||||
|
using ParameterTypeList = typename InitTraits::parameter_types;
|
||||||
|
|
||||||
|
InitLambda<Func, ParameterTypeList> init{std::forward<Func>(f)};
|
||||||
|
return init;
|
||||||
|
}
|
||||||
|
|
||||||
/// Entry point for custom C++ class registration. To register a C++ class
|
/// Entry point for custom C++ class registration. To register a C++ class
|
||||||
/// in PyTorch, instantiate `torch::class_` with the desired class as the
|
/// in PyTorch, instantiate `torch::class_` with the desired class as the
|
||||||
/// template parameter. Typically, this instantiation should be done in
|
/// template parameter. Typically, this instantiation should be done in
|
||||||
@ -95,6 +110,24 @@ class class_ {
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Used in combination with torch::init([]lambda(){......})
|
||||||
|
template <typename Func, typename... ParameterTypes>
|
||||||
|
class_& def(
|
||||||
|
InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
|
||||||
|
std::string doc_string = "") {
|
||||||
|
auto init_lambda_wrapper = [func = std::move(init.f)](
|
||||||
|
c10::tagged_capsule<CurClass> self,
|
||||||
|
ParameterTypes... arg) {
|
||||||
|
c10::intrusive_ptr<CurClass> classObj =
|
||||||
|
at::guts::invoke(func, std::forward<ParameterTypes>(arg)...);
|
||||||
|
auto object = self.ivalue.toObject();
|
||||||
|
object->setSlot(0, c10::IValue::make_capsule(classObj));
|
||||||
|
};
|
||||||
|
defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string));
|
||||||
|
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
/// This is the normal method registration API. `name` is the name that
|
/// This is the normal method registration API. `name` is the name that
|
||||||
/// the method will be made accessible by in Python and TorchScript.
|
/// the method will be made accessible by in Python and TorchScript.
|
||||||
/// `f` is a callable object that defines the method. Typically `f`
|
/// `f` is a callable object that defines the method. Typically `f`
|
||||||
|
Reference in New Issue
Block a user