torch.from_numpy for complex dtypes (#35531)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35531

Differential Revision: D20693581

Pulled By: anjali411

fbshipit-source-id: d53e26b4175452fa00b287efbfceea18104c1364
This commit is contained in:
anjali411
2020-03-27 14:38:26 -07:00
committed by Facebook GitHub Bot
parent f101949390
commit 96eec95ece
4 changed files with 28 additions and 14 deletions

View File

@ -190,11 +190,11 @@ at::Tensor tensor_from_numpy(PyObject* obj) {
int aten_to_numpy_dtype(const ScalarType scalar_type) {
switch (scalar_type) {
case kComplexDouble: return NPY_COMPLEX128;
case kComplexFloat: return NPY_COMPLEX64;
case kDouble: return NPY_DOUBLE;
case kFloat: return NPY_FLOAT;
case kHalf: return NPY_HALF;
case kComplexDouble: return NPY_COMPLEX128;
case kComplexFloat: return NPY_COMPLEX64;
case kLong: return NPY_INT64;
case kInt: return NPY_INT32;
case kShort: return NPY_INT16;
@ -211,6 +211,8 @@ ScalarType numpy_dtype_to_aten(int dtype) {
case NPY_DOUBLE: return kDouble;
case NPY_FLOAT: return kFloat;
case NPY_HALF: return kHalf;
case NPY_COMPLEX64: return kComplexFloat;
case NPY_COMPLEX128: return kComplexDouble;
case NPY_INT16: return kShort;
case NPY_INT8: return kChar;
case NPY_UINT8: return kByte;
@ -236,7 +238,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
if (!pytype) throw python_error();
throw TypeError(
"can't convert np.ndarray of type %s. The only supported types are: "
"float64, float32, float16, int64, int32, int16, int8, uint8, and bool.",
"float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
((PyTypeObject*)pytype.get())->tp_name);
}