mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 19:24:55 +08:00
Add support for float[]? arguments in native_functions.yaml (#37175)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37175 ghstack-source-id: 106938114 Test Plan: Upcoming diffs use this for upsampling. Differential Revision: D21209994 fbshipit-source-id: 1a71c07e45e28772a2bbe450b68280dcc0fe2def
This commit is contained in:
committed by
Facebook GitHub Bot
parent
d04a2e4dae
commit
fb9e44f8dd
@ -25,6 +25,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
||||
{"complex", ParameterType::COMPLEX},
|
||||
{"TensorList", ParameterType::TENSOR_LIST},
|
||||
{"IntArrayRef", ParameterType::INT_LIST},
|
||||
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
|
||||
{"Generator", ParameterType::GENERATOR},
|
||||
{"bool", ParameterType::BOOL},
|
||||
{"Storage", ParameterType::STORAGE},
|
||||
@ -310,6 +311,7 @@ auto FunctionParameter::check(PyObject* obj, std::vector<py::handle> &overloaded
|
||||
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int
|
||||
return size > 0 && THPUtils_checkLong(obj);
|
||||
}
|
||||
case ParameterType::FLOAT_LIST: return (PyTuple_Check(obj) || PyList_Check(obj));
|
||||
case ParameterType::GENERATOR: return THPGenerator_Check(obj);
|
||||
case ParameterType::BOOL: return PyBool_Check(obj);
|
||||
case ParameterType::STORAGE: return isStorage(obj);
|
||||
@ -334,6 +336,7 @@ std::string FunctionParameter::type_name() const {
|
||||
case ParameterType::COMPLEX: return "complex";
|
||||
case ParameterType::TENSOR_LIST: return "tuple of Tensors";
|
||||
case ParameterType::INT_LIST: return "tuple of ints";
|
||||
case ParameterType::FLOAT_LIST: return "tuple of floats";
|
||||
case ParameterType::GENERATOR: return "torch.Generator";
|
||||
case ParameterType::BOOL: return "bool";
|
||||
case ParameterType::STORAGE: return "torch.Storage";
|
||||
@ -419,6 +422,10 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
||||
if (str != "None") {
|
||||
default_intlist = parse_intlist_args(str, size);
|
||||
}
|
||||
} else if (type_ == ParameterType::FLOAT_LIST) {
|
||||
if (str != "None") {
|
||||
throw std::runtime_error("Defaults not supported for float[]");
|
||||
}
|
||||
} else if (type_ == ParameterType::SCALARTYPE) {
|
||||
if (str == "None") {
|
||||
default_scalartype = at::ScalarType::Undefined;
|
||||
|
||||
Reference in New Issue
Block a user