mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Accept more numpy scalars as doubles (#9659)
Summary: Allows mulitplication of e.g. numpy.float32 with tensors. This came up with #9468 If you want this and after the other patch is done, I'll add tests (but that would be conflicting, so I prefer to wait). Pull Request resolved: https://github.com/pytorch/pytorch/pull/9659 Differential Revision: D8948078 Pulled By: weiyangfb fbshipit-source-id: c7dcc57b63e2f100df837f70e1299395692f1a1b
This commit is contained in:
committed by
Facebook Github Bot
parent
8bd80a6b74
commit
267e1ec112
@ -10,6 +10,9 @@ PyObject* tensor_to_numpy(const at::Tensor& tensor) {
|
||||
at::Tensor tensor_from_numpy(PyObject* obj) {
|
||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
||||
}
|
||||
bool is_numpy_scalar(PyObject* obj) {
|
||||
throw std::runtime_error("PyTorch was compiled without NumPy support");
|
||||
}
|
||||
}}
|
||||
#else
|
||||
|
||||
@ -180,6 +183,11 @@ ScalarType numpy_dtype_to_aten(int dtype) {
|
||||
((PyTypeObject*)pytype.get())->tp_name);
|
||||
}
|
||||
|
||||
bool is_numpy_scalar(PyObject* obj) {
|
||||
return (PyArray_IsIntegerScalar(obj) ||
|
||||
PyArray_IsScalar(obj, Floating));
|
||||
}
|
||||
|
||||
}} // namespace torch::utils
|
||||
|
||||
#endif // USE_NUMPY
|
||||
|
||||
Reference in New Issue
Block a user