mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix optional argument resolution in cwrap
This commit is contained in:
@ -63,24 +63,24 @@ static PyObject * THPTensor_(${name})(THPTensor *self, PyObject *args)
|
||||
calls add some overhead).
|
||||
"""
|
||||
impl = ''
|
||||
prev_arg_count = -1
|
||||
prev_option = None
|
||||
for option in sorted(self.options, key=lambda o: o.num_required_args()):
|
||||
num_args = option.num_required_args()
|
||||
if num_args > prev_arg_count:
|
||||
prev_num_args = prev_option.num_required_args() if prev_option else -1
|
||||
if num_args > prev_num_args:
|
||||
# Nothing to close if it's the first option
|
||||
if prev_arg_count != -1 and prev_arg_count < math.inf:
|
||||
if prev_num_args != -1 and prev_option.check_argcount():
|
||||
impl += ' }\n'
|
||||
if num_args < math.inf:
|
||||
if option.check_argcount():
|
||||
impl += Template(' if (_argcount == $numargs) {') \
|
||||
.substitute({'numargs': num_args})
|
||||
prev_arg_count = num_args
|
||||
else:
|
||||
impl += ' PyErr_Clear();'
|
||||
impl += '\n {'
|
||||
impl += option.generate()
|
||||
impl += ' }\n'
|
||||
impl += ' PyErr_Clear();'
|
||||
prev_option = option
|
||||
# Close last argcount block
|
||||
if prev_arg_count < math.inf:
|
||||
if prev_option.check_argcount():
|
||||
impl += ' }\n'
|
||||
return impl
|
||||
|
||||
@ -107,6 +107,7 @@ static PyObject * THPTensor_(${name})(THPTensor *self, PyObject *args)
|
||||
for option, optional_args in zip(self.options, self.optional_args):
|
||||
if not optional_args:
|
||||
resolved_options.append(option)
|
||||
continue
|
||||
# Generate options with all possible configurations of optional args
|
||||
for enabled_bits in product((True, False), repeat=len(optional_args)):
|
||||
new_option = option.copy()
|
||||
|
@ -128,6 +128,9 @@ class Option(object):
|
||||
"""
|
||||
return argcount(self)
|
||||
|
||||
def check_argcount(self):
|
||||
return True
|
||||
|
||||
def _library_state_macro(self, argstr):
|
||||
return 'LIBRARY_STATE' if argstr else 'LIBRARY_STATE_NOARGS'
|
||||
|
||||
@ -206,5 +209,11 @@ class LongArgsTHOption(THOption):
|
||||
arg_idx += 1
|
||||
return init
|
||||
|
||||
def check_argcount(self):
|
||||
return False
|
||||
|
||||
def num_required_args(self):
|
||||
return math.inf
|
||||
# TODO: this is an ugly hack
|
||||
# LONG_ARG options have to be sorted decreasingly w.r.t. number of arguments
|
||||
# (ones with larger counts are more specific)
|
||||
return 100000 - argcount(self)
|
||||
|
@ -34,11 +34,12 @@ int THPUtils_getCallable(PyObject *arg, PyObject **result) {
|
||||
}
|
||||
|
||||
THLongStorage * THPUtils_getLongStorage(PyObject *args, int ignore_first) {
|
||||
// TODO: error messages
|
||||
long value;
|
||||
|
||||
Py_ssize_t length = PyTuple_Size(args);
|
||||
if (length < ignore_first+1)
|
||||
return NULL;
|
||||
throw std::logic_error("Provided too few arguments");
|
||||
|
||||
// Maybe there's a LongStorage
|
||||
PyObject *first_arg = PyTuple_GET_ITEM(args, ignore_first);
|
||||
@ -53,7 +54,7 @@ THLongStorage * THPUtils_getLongStorage(PyObject *args, int ignore_first) {
|
||||
for (Py_ssize_t i = ignore_first; i < length; ++i) {
|
||||
PyObject *arg = PyTuple_GET_ITEM(args, i);
|
||||
if (!THPUtils_getLong(arg, &value))
|
||||
return NULL;
|
||||
throw std::invalid_argument("Expected a numeric argument");
|
||||
result->data[i-ignore_first] = value;
|
||||
}
|
||||
return result.release();
|
||||
|
Reference in New Issue
Block a user