mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Remove default arguments before calling to __torch_dispatch__ (#61123)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61123 This applies the design pattern of removing explicit arguments when they coincide with the default arguments. This simplifies argument patterns that dispatch kernels receive and make it easier for us to maintain BC (as addition of a new default argument isn't immediately BC-breaking for dispatch implementors). There is an important extra API which I haven't implemented here yet, which is to take an incomplete sequence of arguments and fill out their defaults (in case the user did want normalization). I plan on adding that in a future PR. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: saketh-are Differential Revision: D29853616 Pulled By: ezyang fbshipit-source-id: 71c672cb3a7d4d01f838a1c7fcdb75a8ce7d058e
This commit is contained in:
committed by
Facebook GitHub Bot
parent
32d0c3e8ee
commit
e42360d56f
@ -1501,17 +1501,34 @@ void concrete_dispatch_fn(const c10::impl::PyInterpreter*, const c10::OperatorHa
|
||||
py::gil_scoped_acquire g;
|
||||
|
||||
std::vector<py::handle> overloaded_args;
|
||||
auto args = py::reinterpret_steal<py::object>(PyTuple_New(num_arguments));
|
||||
// TODO: actually populate kwargs sometimes? At the moment, every argument
|
||||
// just gets passed positionally
|
||||
py::dict kwargs;
|
||||
// For now, overloads get coalesced. Might be easier for users if they get
|
||||
// overload resolution but is more complicated (need to expose separate
|
||||
// functions per overload)
|
||||
py::handle torch_api_function = py::module::import("torch").attr("ops").attr(ns).attr(func_name);
|
||||
std::string module_name_str = "torch.ops." + ns_str;
|
||||
|
||||
for (int64_t idx = 0; idx < arguments.size(); idx++) {
|
||||
// Pre-scan for arguments that match defaults
|
||||
int64_t default_suffix_len = 0;
|
||||
for (int64_t idx = arguments.size() - 1; idx >= 0; idx--) {
|
||||
const auto& arg = schema.arguments()[idx];
|
||||
if (!arg.default_value().has_value()) {
|
||||
break;
|
||||
}
|
||||
const auto& default_ivalue = *arg.default_value();
|
||||
const auto& ivalue = arguments[idx];
|
||||
if (default_ivalue != ivalue) {
|
||||
break;
|
||||
}
|
||||
default_suffix_len++;
|
||||
}
|
||||
|
||||
auto args = py::reinterpret_steal<py::object>(PyTuple_New(num_arguments - default_suffix_len));
|
||||
|
||||
// TODO: actually populate kwargs sometimes? At the moment, every argument
|
||||
// just gets passed positionally
|
||||
py::dict kwargs;
|
||||
|
||||
for (int64_t idx = 0; idx < arguments.size() - default_suffix_len; idx++) {
|
||||
auto& ivalue = arguments[idx];
|
||||
// Search for Tensors (as they may have the torch functions we need)
|
||||
if (ivalue.isTensor()) {
|
||||
|
Reference in New Issue
Block a user