Bind 0-dim variables without requires grad to int64/double similar to how we do with Scalar. (#6637)

Note:
- Only integral scalar types bind to int64
- Both integral and floating point scalar types bind to double (same rules as python numbers).
This commit is contained in:
gchanan
2018-04-17 09:54:49 -04:00
committed by Soumith Chintala
parent 639dd0e324
commit 30849eb668
2 changed files with 36 additions and 3 deletions

View File

@ -85,7 +85,8 @@ bool FunctionParameter::check(PyObject* obj) {
case ParameterType::TENSOR: {
return THPVariable_Check(obj);
}
case ParameterType::SCALAR: {
case ParameterType::SCALAR:
case ParameterType::DOUBLE: {
// NOTE: we don't currently accept most NumPy types as Scalars. np.float64
// is okay because it's a subclass of PyFloat. We may want to change this
// in the future.
@ -98,8 +99,16 @@ bool FunctionParameter::check(PyObject* obj) {
}
return false;
}
case ParameterType::INT64: return THPUtils_checkLong(obj);
case ParameterType::DOUBLE: return THPUtils_checkDouble(obj);
case ParameterType::INT64: {
if (THPUtils_checkLong(obj)) {
return true;
}
if (THPVariable_Check(obj)) {
auto& var = ((THPVariable*)obj)->cdata;
return at::isIntegralType(var.type().scalarType()) && !var.requires_grad() && var.dim() == 0;
}
return false;
}
case ParameterType::TENSOR_LIST: return PyTuple_Check(obj) || PyList_Check(obj);
case ParameterType::INT_LIST: {
if (PyTuple_Check(obj) || PyList_Check(obj)) {