[JIT] Add support for default argument values to Torchbind (#51253)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51253

**Summary**
This commit adds support to Torchbind for specifying default values for
arguments of custom class methods.

**Test Plan**
This commit adds a unit test to `test_torchbind.py` that exercises this
feature.

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D26131529

Pulled By: SplitInfinity

fbshipit-source-id: 68bc86b045dd2f03ba41e1a116081a6eae6ba9ff
This commit is contained in:
Meghan Lele
2021-02-17 11:18:49 -08:00
committed by Facebook GitHub Bot
parent 324c6aada1
commit cbede834d4
3 changed files with 167 additions and 11 deletions

View File

@ -18,6 +18,38 @@
namespace torch {
/// This struct is used to represent default values for arguments
/// when registering methods for custom classes.
/// static auto register_foo = torch::class_<Foo>("myclasses", "Foo")
/// .def("myMethod", &Foo::myMethod, {torch::arg("name") = name});
struct arg {
// Static method for representing a default value of None. This is meant to
// be used like so:
// torch::arg("name") = torch::arg::none
// and is identical to:
// torch::arg("name") = IValue()
static c10::IValue none() {
return c10::IValue();
}
// Explicit constructor.
explicit arg(std::string name) : name_(std::move(name)), value_(c10::nullopt) {}
// Assignment operator. This enables the pybind-like syntax of
// torch::arg("name") = value.
arg& operator=(const c10::IValue& rhs) {
value_ = rhs;
return *this;
}
// The name of the argument. This is copied to the schema; argument
// names cannot be extracted from the C++ declaration.
std::string name_;
// IValue's default constructor makes it None, which is not distinguishable from
// an actual, user-provided default value that is None. This boolean
// helps distinguish between the two cases.
c10::optional<c10::IValue> value_;
};
/// This function is used in conjunction with `class_::def()` to register
/// a constructor for a given C++ class type. For example,
/// `torch::init<int, std::string>()` would register a two-argument constructor
@ -98,15 +130,22 @@ class class_ {
/// `torch::init<int, std::string>()` would register a two-argument constructor
/// taking an `int` and a `std::string` as argument.
template <typename... Types>
class_& def(detail::types<void, Types...>, std::string doc_string = "") { // Used in combination with
// torch::init<...>()
class_& def(
detail::types<void, Types...>,
std::string doc_string = "",
std::initializer_list<arg> default_args = {}) { // Used in combination with
// torch::init<...>()
auto func = [](c10::tagged_capsule<CurClass> self, Types... args) {
auto classObj = c10::make_intrusive<CurClass>(args...);
auto object = self.ivalue.toObject();
object->setSlot(0, c10::IValue::make_capsule(std::move(classObj)));
};
defineMethod("__init__", std::move(func), std::move(doc_string));
defineMethod(
"__init__",
std::move(func),
std::move(doc_string),
std::move(default_args));
return *this;
}
@ -114,7 +153,8 @@ class class_ {
template <typename Func, typename... ParameterTypes>
class_& def(
InitLambda<Func, c10::guts::typelist::typelist<ParameterTypes...>> init,
std::string doc_string = "") {
std::string doc_string = "",
std::initializer_list<arg> default_args = {}) {
auto init_lambda_wrapper = [func = std::move(init.f)](
c10::tagged_capsule<CurClass> self,
ParameterTypes... arg) {
@ -123,7 +163,12 @@ class class_ {
auto object = self.ivalue.toObject();
object->setSlot(0, c10::IValue::make_capsule(classObj));
};
defineMethod("__init__", std::move(init_lambda_wrapper), std::move(doc_string));
defineMethod(
"__init__",
std::move(init_lambda_wrapper),
std::move(doc_string),
std::move(default_args));
return *this;
}
@ -147,9 +192,17 @@ class class_ {
/// // do something
/// })
template <typename Func>
class_& def(std::string name, Func f, std::string doc_string = "") {
class_& def(
std::string name,
Func f,
std::string doc_string = "",
std::initializer_list<arg> default_args = {}) {
auto wrapped_f = detail::wrap_func<CurClass, Func>(std::move(f));
defineMethod(std::move(name), std::move(wrapped_f), std::move(doc_string));
defineMethod(
std::move(name),
std::move(wrapped_f),
std::move(doc_string),
std::move(default_args));
return *this;
}
@ -287,11 +340,49 @@ class class_ {
private:
template <typename Func>
void defineMethod(std::string name, Func func, std::string doc_string = "") {
void defineMethod(
std::string name,
Func func,
std::string doc_string = "",
std::initializer_list<arg> default_args = {}) {
auto qualMethodName = qualClassName + "." + name;
auto schema = c10::inferFunctionSchemaSingleReturn<Func>(std::move(name), "");
auto wrapped_func = [func = std::move(func)](jit::Stack& stack) mutable -> void {
// If default values are provided for function arguments, there must be
// none (no default values) or default values for all function
// arguments, except for self. This is because argument names are not
// extracted by inferFunctionSchemaSingleReturn, and so there must be a
// torch::arg instance in default_args even for arguments that do not
// have an actual default value provided.
TORCH_CHECK(
default_args.size() == 0 ||
default_args.size() == schema.arguments().size() - 1,
"Default values must be specified for none or all arguments");
// If there are default args, copy the argument names and default values to the
// function schema.
if (default_args.size() > 0) {
const auto& old_args = schema.arguments();
std::vector<c10::Argument> new_args;
new_args.reserve(old_args.size());
std::vector<arg> default_args_v(default_args);
new_args.emplace_back(old_args[0]);
for (size_t i = 0; i < default_args_v.size(); ++i) {
// Skip self.
auto& arg = old_args[i+1];
new_args.emplace_back(c10::Argument(
std::move(default_args_v[i].name_),
arg.type(),
arg.N(),
default_args_v[i].value_.has_value() ? std::move(*default_args_v[i].value_) : c10::nullopt));
}
schema = schema.cloneWithArguments(new_args);
}
auto wrapped_func =
[func = std::move(func)](jit::Stack& stack) mutable -> void {
// TODO: we need to figure out how to profile calls to custom functions
// like this! Currently can't do it because the profiler stuff is in
// libtorch and not ATen