Fix optional argument resolution in cwrap

This commit is contained in:
Adam Paszke
2016-07-19 10:37:40 -04:00
parent c574295012
commit 1d763810ba
3 changed files with 22 additions and 11 deletions

View File

@ -63,24 +63,24 @@ static PyObject * THPTensor_(${name})(THPTensor *self, PyObject *args)
calls add some overhead).
"""
impl = ''
prev_arg_count = -1
prev_option = None
for option in sorted(self.options, key=lambda o: o.num_required_args()):
num_args = option.num_required_args()
if num_args > prev_arg_count:
prev_num_args = prev_option.num_required_args() if prev_option else -1
if num_args > prev_num_args:
# Nothing to close if it's the first option
if prev_arg_count != -1 and prev_arg_count < math.inf:
if prev_num_args != -1 and prev_option.check_argcount():
impl += ' }\n'
if num_args < math.inf:
if option.check_argcount():
impl += Template(' if (_argcount == $numargs) {') \
.substitute({'numargs': num_args})
prev_arg_count = num_args
else:
impl += ' PyErr_Clear();'
impl += '\n {'
impl += option.generate()
impl += ' }\n'
impl += ' PyErr_Clear();'
prev_option = option
# Close last argcount block
if prev_arg_count < math.inf:
if prev_option.check_argcount():
impl += ' }\n'
return impl
@ -107,6 +107,7 @@ static PyObject * THPTensor_(${name})(THPTensor *self, PyObject *args)
for option, optional_args in zip(self.options, self.optional_args):
if not optional_args:
resolved_options.append(option)
continue
# Generate options with all possible configurations of optional args
for enabled_bits in product((True, False), repeat=len(optional_args)):
new_option = option.copy()

View File

@ -128,6 +128,9 @@ class Option(object):
"""
return argcount(self)
def check_argcount(self):
return True
def _library_state_macro(self, argstr):
return 'LIBRARY_STATE' if argstr else 'LIBRARY_STATE_NOARGS'
@ -206,5 +209,11 @@ class LongArgsTHOption(THOption):
arg_idx += 1
return init
def check_argcount(self):
return False
def num_required_args(self):
return math.inf
# TODO: this is an ugly hack
# LONG_ARG options have to be sorted decreasingly w.r.t. number of arguments
# (ones with larger counts are more specific)
return 100000 - argcount(self)

View File

@ -34,11 +34,12 @@ int THPUtils_getCallable(PyObject *arg, PyObject **result) {
}
THLongStorage * THPUtils_getLongStorage(PyObject *args, int ignore_first) {
// TODO: error messages
long value;
Py_ssize_t length = PyTuple_Size(args);
if (length < ignore_first+1)
return NULL;
throw std::logic_error("Provided too few arguments");
// Maybe there's a LongStorage
PyObject *first_arg = PyTuple_GET_ITEM(args, ignore_first);
@ -53,7 +54,7 @@ THLongStorage * THPUtils_getLongStorage(PyObject *args, int ignore_first) {
for (Py_ssize_t i = ignore_first; i < length; ++i) {
PyObject *arg = PyTuple_GET_ITEM(args, i);
if (!THPUtils_getLong(arg, &value))
return NULL;
throw std::invalid_argument("Expected a numeric argument");
result->data[i-ignore_first] = value;
}
return result.release();