mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-06 00:54:56 +08:00
Generate autograd functions for NN / more refactors (#3136)
Generate autograd functions for NN and implement more derivatives in derivatives.yaml A big refactor of gen_variable_type.py
This commit is contained in:
@ -27,6 +27,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
||||
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
|
||||
: optional(false)
|
||||
, keyword_only(keyword_only)
|
||||
, size(0)
|
||||
, default_scalar(0)
|
||||
{
|
||||
auto space = fmt.find(' ');
|
||||
@ -35,6 +36,13 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
|
||||
}
|
||||
|
||||
auto type_str = fmt.substr(0, space);
|
||||
auto bracket = type_str.find('[');
|
||||
if (bracket != std::string::npos) {
|
||||
auto size_str = type_str.substr(bracket + 1, type_str.length() - bracket - 2);
|
||||
size = atoi(size_str.c_str());
|
||||
type_str = type_str.substr(0, bracket);
|
||||
}
|
||||
|
||||
auto name_str = fmt.substr(space + 1);
|
||||
type_ = type_map[type_str];
|
||||
|
||||
@ -55,12 +63,20 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
|
||||
|
||||
bool FunctionParameter::check(PyObject* obj) {
|
||||
switch (type_) {
|
||||
case ParameterType::TENSOR: return THPVariable_Check(obj);
|
||||
case ParameterType::TENSOR: {
|
||||
return THPVariable_Check(obj) || (optional && obj == Py_None);
|
||||
}
|
||||
case ParameterType::SCALAR: return THPUtils_checkDouble(obj);
|
||||
case ParameterType::INT64: return THPUtils_checkLong(obj);
|
||||
case ParameterType::DOUBLE: return THPUtils_checkDouble(obj);
|
||||
case ParameterType::TENSOR_LIST: return PyTuple_Check(obj) || PyList_Check(obj);
|
||||
case ParameterType::INT_LIST: return PyTuple_Check(obj) || PyList_Check(obj);
|
||||
case ParameterType::INT_LIST: {
|
||||
if (PyTuple_Check(obj) || PyList_Check(obj)) {
|
||||
return true;
|
||||
}
|
||||
// if a size is specified (e.g. IntList[2]) we also allow passing a single int
|
||||
return size > 0 && THPUtils_checkLong(obj);
|
||||
}
|
||||
case ParameterType::GENERATOR: return false;
|
||||
case ParameterType::BOOL: return PyBool_Check(obj);
|
||||
case ParameterType::STORAGE: return false;
|
||||
@ -97,6 +113,10 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
||||
default_double = atof(str.c_str());
|
||||
} else if (type_ == ParameterType::SCALAR) {
|
||||
default_scalar = Scalar(atof(str.c_str()));
|
||||
} else if (type_ == ParameterType::INT_LIST) {
|
||||
if (str != "None") {
|
||||
default_intlist.assign(size, std::stoi(str));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user