Add support for float[]? arguments in native_functions.yaml (#37175)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37175

ghstack-source-id: 106938114

Test Plan: Upcoming diffs use this for upsampling.

Differential Revision: D21209994

fbshipit-source-id: 1a71c07e45e28772a2bbe450b68280dcc0fe2def
This commit is contained in:
David Reiss
2020-07-13 11:46:52 -07:00
committed by Facebook GitHub Bot
parent d04a2e4dae
commit fb9e44f8dd
15 changed files with 163 additions and 3 deletions

View File

@ -25,6 +25,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
{"complex", ParameterType::COMPLEX},
{"TensorList", ParameterType::TENSOR_LIST},
{"IntArrayRef", ParameterType::INT_LIST},
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
{"Generator", ParameterType::GENERATOR},
{"bool", ParameterType::BOOL},
{"Storage", ParameterType::STORAGE},
@ -310,6 +311,7 @@ auto FunctionParameter::check(PyObject* obj, std::vector<py::handle> &overloaded
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int
return size > 0 && THPUtils_checkLong(obj);
}
case ParameterType::FLOAT_LIST: return (PyTuple_Check(obj) || PyList_Check(obj));
case ParameterType::GENERATOR: return THPGenerator_Check(obj);
case ParameterType::BOOL: return PyBool_Check(obj);
case ParameterType::STORAGE: return isStorage(obj);
@ -334,6 +336,7 @@ std::string FunctionParameter::type_name() const {
case ParameterType::COMPLEX: return "complex";
case ParameterType::TENSOR_LIST: return "tuple of Tensors";
case ParameterType::INT_LIST: return "tuple of ints";
case ParameterType::FLOAT_LIST: return "tuple of floats";
case ParameterType::GENERATOR: return "torch.Generator";
case ParameterType::BOOL: return "bool";
case ParameterType::STORAGE: return "torch.Storage";
@ -419,6 +422,10 @@ void FunctionParameter::set_default_str(const std::string& str) {
if (str != "None") {
default_intlist = parse_intlist_args(str, size);
}
} else if (type_ == ParameterType::FLOAT_LIST) {
if (str != "None") {
throw std::runtime_error("Defaults not supported for float[]");
}
} else if (type_ == ParameterType::SCALARTYPE) {
if (str == "None") {
default_scalartype = at::ScalarType::Undefined;