mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 14:15:01 +08:00
Bind 0-dim variables without requires grad to int64/double similar to how we do with Scalar. (#6637)
Note: - Only integral scalar types bind to int64 - Both integral and floating point scalar types bind to double (same rules as python numbers).
This commit is contained in:
committed by
Soumith Chintala
parent
639dd0e324
commit
30849eb668
@ -85,7 +85,8 @@ bool FunctionParameter::check(PyObject* obj) {
|
||||
case ParameterType::TENSOR: {
|
||||
return THPVariable_Check(obj);
|
||||
}
|
||||
case ParameterType::SCALAR: {
|
||||
case ParameterType::SCALAR:
|
||||
case ParameterType::DOUBLE: {
|
||||
// NOTE: we don't currently accept most NumPy types as Scalars. np.float64
|
||||
// is okay because it's a subclass of PyFloat. We may want to change this
|
||||
// in the future.
|
||||
@ -98,8 +99,16 @@ bool FunctionParameter::check(PyObject* obj) {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
case ParameterType::INT64: return THPUtils_checkLong(obj);
|
||||
case ParameterType::DOUBLE: return THPUtils_checkDouble(obj);
|
||||
case ParameterType::INT64: {
|
||||
if (THPUtils_checkLong(obj)) {
|
||||
return true;
|
||||
}
|
||||
if (THPVariable_Check(obj)) {
|
||||
auto& var = ((THPVariable*)obj)->cdata;
|
||||
return at::isIntegralType(var.type().scalarType()) && !var.requires_grad() && var.dim() == 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
case ParameterType::TENSOR_LIST: return PyTuple_Check(obj) || PyList_Check(obj);
|
||||
case ParameterType::INT_LIST: {
|
||||
if (PyTuple_Check(obj) || PyList_Check(obj)) {
|
||||
|
Reference in New Issue
Block a user