Remove default arguments before calling to __torch_dispatch__ (#61123)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61123

This applies the design pattern of removing explicit arguments when they
coincide with the default arguments.  This simplifies the argument patterns
that dispatch kernels receive and makes it easier for us to maintain BC,
since the addition of a new default argument isn't immediately BC-breaking
for dispatch implementors.
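
The trimming itself amounts to scanning the positional arguments from the
back and dropping the longest trailing run that compares equal to the
schema defaults.  A minimal Python sketch of the idea (the helper and its
argument encoding are illustrative, not the actual PyTorch internals):

```python
def trim_trailing_defaults(args, defaults):
    """Drop the longest trailing run of arguments equal to their defaults.

    `defaults` has one entry per schema parameter, in order:
    (True, value) for a defaulted parameter, (False, None) for a
    required one.
    """
    suffix_len = 0
    # Walk backwards; stop at the first argument that is required
    # or that differs from its default.
    for arg, (has_default, default) in zip(reversed(args), reversed(defaults)):
        if not has_default or arg != default:
            break
        suffix_len += 1
    return args[: len(args) - suffix_len]
```

Note that only a *trailing* run can be dropped: an argument that matches
its default but is followed by a non-default one must stay, because
everything is passed positionally.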

There is an important extra API which I haven't implemented here yet,
which is to take an incomplete sequence of arguments and fill out their
defaults (in case the user did want normalization).  I plan on adding
that in a future PR.
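
That normalization API is not part of this PR; as a rough sketch of what
it could look like, here is the inverse of the trimming above (a
hypothetical helper, not an actual torch API):

```python
def fill_defaults(args, defaults):
    """Extend a possibly-trimmed positional argument list with defaults.

    `defaults` has one entry per schema parameter, in order:
    (True, value) for a defaulted parameter, (False, None) for a
    required one.
    """
    if len(args) > len(defaults):
        raise TypeError("more arguments than schema parameters")
    out = list(args)
    # Every parameter past the supplied prefix must have a default.
    for has_default, value in defaults[len(args):]:
        if not has_default:
            raise TypeError("missing required argument")
        out.append(value)
    return out
```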

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: saketh-are

Differential Revision: D29853616

Pulled By: ezyang

fbshipit-source-id: 71c672cb3a7d4d01f838a1c7fcdb75a8ce7d058e
Author: Edward Yang
Date: 2021-07-23 10:35:50 -07:00
Committed by: Facebook GitHub Bot
Parent: 32d0c3e8ee
Commit: e42360d56f
2 changed files with 27 additions and 10 deletions


@@ -1501,17 +1501,34 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
   py::gil_scoped_acquire g;
   std::vector<py::handle> overloaded_args;
-  auto args = py::reinterpret_steal<py::object>(PyTuple_New(num_arguments));
-  // TODO: actually populate kwargs sometimes? At the moment, every argument
-  // just gets passed positionally
-  py::dict kwargs;
   // For now, overloads get coalesced. Might be easier for users if they get
   // overload resolution but is more complicated (need to expose separate
   // functions per overload)
   py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name);
   std::string module_name_str = "torch.ops." + ns_str;
-  for (int64_t idx = 0; idx < arguments.size(); idx++) {
+  // Pre-scan for arguments that match defaults
+  int64_t default_suffix_len = 0;
+  for (int64_t idx = arguments.size() - 1; idx >= 0; idx--) {
+    const auto& arg = schema.arguments()[idx];
+    if (!arg.default_value().has_value()) {
+      break;
+    }
+    const auto& default_ivalue = *arg.default_value();
+    const auto& ivalue = arguments[idx];
+    if (default_ivalue != ivalue) {
+      break;
+    }
+    default_suffix_len++;
+  }
+  auto args = py::reinterpret_steal<py::object>(PyTuple_New(num_arguments - default_suffix_len));
+  // TODO: actually populate kwargs sometimes? At the moment, every argument
+  // just gets passed positionally
+  py::dict kwargs;
+  for (int64_t idx = 0; idx < arguments.size() - default_suffix_len; idx++) {
     auto& ivalue = arguments[idx];
     // Search for Tensors (as they may have the torch functions we need)
     if (ivalue.isTensor()) {