Generate autograd functions for NN / more refactors (#3136)

Generate autograd functions for NN and implement more derivatives in derivatives.yaml

A big refactor of gen_variable_type.py
This commit is contained in:
Sam Gross
2017-10-19 15:03:26 -04:00
committed by GitHub
parent 98e67448fa
commit f1f64c8d07
23 changed files with 1516 additions and 517 deletions

View File

@ -27,6 +27,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
: optional(false)
, keyword_only(keyword_only)
, size(0)
, default_scalar(0)
{
auto space = fmt.find(' ');
@ -35,6 +36,13 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
}
auto type_str = fmt.substr(0, space);
auto bracket = type_str.find('[');
if (bracket != std::string::npos) {
auto size_str = type_str.substr(bracket + 1, type_str.length() - bracket - 2);
size = atoi(size_str.c_str());
type_str = type_str.substr(0, bracket);
}
auto name_str = fmt.substr(space + 1);
type_ = type_map[type_str];
@ -55,12 +63,20 @@ FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
bool FunctionParameter::check(PyObject* obj) {
switch (type_) {
case ParameterType::TENSOR: return THPVariable_Check(obj);
case ParameterType::TENSOR: {
return THPVariable_Check(obj) || (optional && obj == Py_None);
}
case ParameterType::SCALAR: return THPUtils_checkDouble(obj);
case ParameterType::INT64: return THPUtils_checkLong(obj);
case ParameterType::DOUBLE: return THPUtils_checkDouble(obj);
case ParameterType::TENSOR_LIST: return PyTuple_Check(obj) || PyList_Check(obj);
case ParameterType::INT_LIST: return PyTuple_Check(obj) || PyList_Check(obj);
case ParameterType::INT_LIST: {
if (PyTuple_Check(obj) || PyList_Check(obj)) {
return true;
}
// if a size is specified (e.g. IntList[2]) we also allow passing a single int
return size > 0 && THPUtils_checkLong(obj);
}
case ParameterType::GENERATOR: return false;
case ParameterType::BOOL: return PyBool_Check(obj);
case ParameterType::STORAGE: return false;
@ -97,6 +113,10 @@ void FunctionParameter::set_default_str(const std::string& str) {
default_double = atof(str.c_str());
} else if (type_ == ParameterType::SCALAR) {
default_scalar = Scalar(atof(str.c_str()));
} else if (type_ == ParameterType::INT_LIST) {
if (str != "None") {
default_intlist.assign(size, std::stoi(str));
}
}
}