mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[JIT] Add support for default argument values to Torchbind (#51253)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51253 **Summary** This commit adds support to Torchbind for specifying default values for arguments of custom class methods. **Test Plan** This commit adds a unit test to `test_torchbind.py` that exercises this feature. Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D26131529 Pulled By: SplitInfinity fbshipit-source-id: 68bc86b045dd2f03ba41e1a116081a6eae6ba9ff
This commit is contained in:
committed by
Facebook GitHub Bot
parent
324c6aada1
commit
cbede834d4
@ -18,6 +18,38 @@
|
||||
|
||||
namespace torch {
|
||||
|
||||
/// This struct is used to represent default values for arguments
|
||||
/// when registering methods for custom classes.
|
||||
/// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
|
||||
/// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name});
|
||||
struct arg {
|
||||
// Static method for representing a default value of None. This is meant to
|
||||
// be used like so:
|
||||
// torch::arg("name") = torch::arg::none
|
||||
// and is identical to:
|
||||
// torch::arg("name") = IValue()
|
||||
static c10::IValue none() {
|
||||
return c10::IValue();
|
||||
}
|
||||
|
||||
// Explicit constructor.
|
||||
explicit arg(std::string name) : name_(std::move(name)), value_(c10::nullopt) {}
|
||||
// Assignment operator. This enables the pybind-like syntax of
|
||||
// torch::arg("name") = value.
|
||||
arg& operator=(const c10::IValue& rhs) {
|
||||
value_ = rhs;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// The name of the argument. This is copied to the schema; argument
|
||||
// names cannot be extracted from the C++ declaration.
|
||||
std::string name_;
|
||||
// IValue's default constructor makes it None, which is not distinguishable from
|
||||
// an actual, user-provided default value that is None. This boolean
|
||||
// helps distinguish between the two cases.
|
||||
c10::optional<c10::IValue> value_;
|
||||
};
|
||||
|
||||
/// This function is used in conjunction with `class_::def()` to register
|
||||
/// a constructor for a given C++ class type. For example,
|
||||
/// `torch::init<int, std::string>()` would register a two-argument constructor
|
||||
@ -98,15 +130,22 @@ class class_ {
|
||||
/// `torch::init<int, std::string>()` would register a two-argument constructor
|
||||
/// taking an `int` and a `std::string` as argument.
|
||||
template <typename... Types>
|
||||
class_& def(detail::types<void, Types...>, std::string doc_string = "") { // Used in combination with
|
||||
// torch::init<...>()
|
||||
class_& def(
|
||||
detail::types<void, Types...>,
|
||||
std::string doc_string = "",
|
||||
std::initializer_list<arg> default_args = {}) { // Used in combination with
|
||||
// torch::init<...>()
|
||||
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
|
||||
auto classObj = c10::make_intrusive<CurClass>(args...);
|
||||
auto object = self.ivalue.toObject();
|
||||
object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
|
||||
};
|
||||
|
||||
defineMethod("__init__", std::move(func), std::move(doc_string));
|
||||
defineMethod(
|
||||
"__init__",
|
||||
std::move(func),
|
||||
std::move(doc_string),
|
||||
std::move(default_args));
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -114,7 +153,8 @@ class class_ {
|
||||
template <typename Func, typename... ParameterTypes>
|
||||
class_& def(
|
||||
InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
|
||||
std::string doc_string = "") {
|
||||
std::string doc_string = "",
|
||||
std::initializer_list<arg> default_args = {}) {
|
||||
auto init_lambda_wrapper = [func = std::move(init.f)](
|
||||
c10::tagged_capsule<CurClass> self,
|
||||
ParameterTypes... arg) {
|
||||
@ -123,7 +163,12 @@ class class_ {
|
||||
auto object = self.ivalue.toObject();
|
||||
object->setSlot(0, c10::IValue::make_capsule(classObj));
|
||||
};
|
||||
defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string));
|
||||
|
||||
defineMethod(
|
||||
"__init__",
|
||||
std::move(init_lambda_wrapper),
|
||||
std::move(doc_string),
|
||||
std::move(default_args));
|
||||
|
||||
return *this;
|
||||
}
|
||||
@ -147,9 +192,17 @@ class class_ {
|
||||
/// // do something
|
||||
/// })
|
||||
template <typename Func>
|
||||
class_& def(std::string name, Func f, std::string doc_string = "") {
|
||||
class_& def(
|
||||
std::string name,
|
||||
Func f,
|
||||
std::string doc_string = "",
|
||||
std::initializer_list<arg> default_args = {}) {
|
||||
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
|
||||
defineMethod(std::move(name), std::move(wrapped_f), std::move(doc_string));
|
||||
defineMethod(
|
||||
std::move(name),
|
||||
std::move(wrapped_f),
|
||||
std::move(doc_string),
|
||||
std::move(default_args));
|
||||
return *this;
|
||||
}
|
||||
|
||||
@ -287,11 +340,49 @@ class class_ {
|
||||
|
||||
private:
|
||||
template <typename Func>
|
||||
void defineMethod(std::string name, Func func, std::string doc_string = "") {
|
||||
void defineMethod(
|
||||
std::string name,
|
||||
Func func,
|
||||
std::string doc_string = "",
|
||||
std::initializer_list<arg> default_args = {}) {
|
||||
auto qualMethodName = qualClassName + "." + name;
|
||||
auto schema = c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
|
||||
|
||||
auto wrapped_func = [func = std::move(func)](jit::Stack& stack) mutable -> void {
|
||||
// If default values are provided for function arguments, there must be
|
||||
// none (no default values) or default values for all function
|
||||
// arguments, except for self. This is because argument names are not
|
||||
// extracted by inferFunctionSchemaSingleReturn, and so there must be a
|
||||
// torch::arg instance in default_args even for arguments that do not
|
||||
// have an actual default value provided.
|
||||
TORCH_CHECK(
|
||||
default_args.size() == 0 ||
|
||||
default_args.size() == schema.arguments().size() - 1,
|
||||
"Default values must be specified for none or all arguments");
|
||||
|
||||
// If there are default args, copy the argument names and default values to the
|
||||
// function schema.
|
||||
if (default_args.size() > 0) {
|
||||
const auto& old_args = schema.arguments();
|
||||
std::vector<c10::Argument> new_args;
|
||||
new_args.reserve(old_args.size());
|
||||
std::vector<arg> default_args_v(default_args);
|
||||
|
||||
new_args.emplace_back(old_args[0]);
|
||||
for (size_t i = 0; i < default_args_v.size(); ++i) {
|
||||
// Skip self.
|
||||
auto& arg = old_args[i+1];
|
||||
new_args.emplace_back(c10::Argument(
|
||||
std::move(default_args_v[i].name_),
|
||||
arg.type(),
|
||||
arg.N(),
|
||||
default_args_v[i].value_.has_value() ? std::move(*default_args_v[i].value_) : c10::nullopt));
|
||||
}
|
||||
|
||||
schema = schema.cloneWithArguments(new_args);
|
||||
}
|
||||
|
||||
auto wrapped_func =
|
||||
[func = std::move(func)](jit::Stack& stack) mutable -> void {
|
||||
// TODO: we need to figure out how to profile calls to custom functions
|
||||
// like this! Currently can't do it because the profiler stuff is in
|
||||
// libtorch and not ATen
|
||||
|
Reference in New Issue
Block a user