mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Implement Variable.new (#4080)
This commit is contained in:
@ -22,7 +22,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
||||
{"Generator", ParameterType::GENERATOR},
|
||||
{"bool", ParameterType::BOOL},
|
||||
{"Storage", ParameterType::STORAGE},
|
||||
{"PyObject*", ParameterType::PYOBJECT}
|
||||
{"PyObject*", ParameterType::PYOBJECT},
|
||||
};
|
||||
|
||||
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
|
||||
@ -85,7 +85,7 @@ bool FunctionParameter::check(PyObject* obj) {
|
||||
}
|
||||
case ParameterType::GENERATOR: return THPGenerator_Check(obj);
|
||||
case ParameterType::BOOL: return PyBool_Check(obj);
|
||||
case ParameterType::STORAGE: return false;
|
||||
case ParameterType::STORAGE: return isStorage(obj);
|
||||
case ParameterType::PYOBJECT: return true;
|
||||
default: throw std::runtime_error("unknown parameter type");
|
||||
}
|
||||
@ -140,6 +140,7 @@ FunctionSignature::FunctionSignature(const std::string& fmt)
|
||||
: min_args(0)
|
||||
, max_args(0)
|
||||
, max_pos_args(0)
|
||||
, hidden(false)
|
||||
, deprecated(false)
|
||||
{
|
||||
auto open_paren = fmt.find('(');
|
||||
@ -178,7 +179,11 @@ FunctionSignature::FunctionSignature(const std::string& fmt)
|
||||
}
|
||||
|
||||
if (fmt.substr(last_offset) == "|deprecated") {
|
||||
hidden = true;
|
||||
// TODO: raise warning when parsing deprecated signatures
|
||||
deprecated = true;
|
||||
} else if (fmt.substr(last_offset) == "|hidden") {
|
||||
hidden = true;
|
||||
}
|
||||
|
||||
max_args = params.size();
|
||||
@ -341,7 +346,8 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
|
||||
return false;
|
||||
} else if (param.check(obj)) {
|
||||
dst[i++] = obj;
|
||||
} else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd) {
|
||||
} else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
|
||||
THPUtils_checkLong(obj)) {
|
||||
// take all positional arguments as this parameter
|
||||
// e.g. permute(1, 2, 3) -> permute((1, 2, 3))
|
||||
dst[i++] = args;
|
||||
@ -365,7 +371,7 @@ bool FunctionSignature::parse(PyObject* args, PyObject* kwargs, PyObject* dst[],
|
||||
|
||||
if (!is_kwd) {
|
||||
arg_pos++;
|
||||
} else {
|
||||
} else if (obj) {
|
||||
remaining_kwargs--;
|
||||
}
|
||||
}
|
||||
@ -420,7 +426,7 @@ void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* pa
|
||||
std::vector<int> plausible_idxs;
|
||||
ssize_t i = 0;
|
||||
for (auto& signature : signatures_) {
|
||||
if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.deprecated) {
|
||||
if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) {
|
||||
plausible_idxs.push_back(i);
|
||||
}
|
||||
i++;
|
||||
@ -433,7 +439,7 @@ void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* pa
|
||||
|
||||
std::vector<std::string> options;
|
||||
for (auto& signature : signatures_) {
|
||||
if (!signature.deprecated) {
|
||||
if (!signature.hidden) {
|
||||
options.push_back(signature.toString());
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user