torch.from_numpy for complex dtypes (#35531)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35531

Differential Revision: D20693581

Pulled By: anjali411

fbshipit-source-id: d53e26b4175452fa00b287efbfceea18104c1364
This commit is contained in:
anjali411
2020-03-27 14:38:26 -07:00
committed by Facebook GitHub Bot
parent f101949390
commit 96eec95ece
4 changed files with 28 additions and 14 deletions

View File

@ -113,6 +113,8 @@ int CaffeToNumpyType(const TypeMeta& meta) {
{TypeMeta::Id<bool>(), NPY_BOOL}, {TypeMeta::Id<bool>(), NPY_BOOL},
{TypeMeta::Id<double>(), NPY_DOUBLE}, {TypeMeta::Id<double>(), NPY_DOUBLE},
{TypeMeta::Id<float>(), NPY_FLOAT}, {TypeMeta::Id<float>(), NPY_FLOAT},
{TypeMeta::Id<std::complex<double>>(), NPY_COMPLEX128},
{TypeMeta::Id<std::complex<float>>(), NPY_COMPLEX64},
{TypeMeta::Id<at::Half>(), NPY_FLOAT16}, {TypeMeta::Id<at::Half>(), NPY_FLOAT16},
{TypeMeta::Id<int>(), NPY_INT}, {TypeMeta::Id<int>(), NPY_INT},
{TypeMeta::Id<int8_t>(), NPY_INT8}, {TypeMeta::Id<int8_t>(), NPY_INT8},

View File

@ -4527,6 +4527,8 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
np.double, np.double,
np.float, np.float,
np.float16, np.float16,
np.complex64,
np.complex128,
np.int64, np.int64,
np.int32, np.int32,
np.int16, np.int16,
@ -4535,6 +4537,11 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
np.longlong, np.longlong,
np.bool, np.bool,
] ]
complex_dtypes = [
np.complex64,
np.complex128,
]
for dtype in dtypes: for dtype in dtypes:
array = np.array([1, 2, 3, 4], dtype=dtype) array = np.array([1, 2, 3, 4], dtype=dtype)
tensor_from_array = torch.from_numpy(array) tensor_from_array = torch.from_numpy(array)
@ -4542,6 +4549,8 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
# implements `==` # implements `==`
for i in range(len(array)): for i in range(len(array)):
self.assertEqual(tensor_from_array[i], array[i]) self.assertEqual(tensor_from_array[i], array[i])
# ufunc 'remainder' not supported for complex dtypes
if dtype not in complex_dtypes:
# This is a special test case for Windows # This is a special test case for Windows
# https://github.com/pytorch/pytorch/issues/22615 # https://github.com/pytorch/pytorch/issues/22615
array2 = array % 2 array2 = array % 2
@ -4550,7 +4559,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
self.assertEqual(tensor_from_array2[i], array2[i]) self.assertEqual(tensor_from_array2[i], array2[i])
# Test unsupported type # Test unsupported type
array = np.array([1, 2, 3, 4], dtype=np.complex) array = np.array([1, 2, 3, 4], dtype=np.uint16)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
tensor_from_array = torch.from_numpy(array) tensor_from_array = torch.from_numpy(array)

View File

@ -2199,8 +2199,9 @@ the tensor will be reflected in the :attr:`ndarray` and vice versa. The returned
tensor is not resizable. tensor is not resizable.
It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``, It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``,
``numpy.float32``, ``numpy.float16``, ``numpy.int64``, ``numpy.int32``, ``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``,
``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, and ``numpy.bool``. ``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``,
and ``numpy.bool``.
Example:: Example::

View File

@ -190,11 +190,11 @@ at::Tensor tensor_from_numpy(PyObject* obj) {
int aten_to_numpy_dtype(const ScalarType scalar_type) { int aten_to_numpy_dtype(const ScalarType scalar_type) {
switch (scalar_type) { switch (scalar_type) {
case kComplexDouble: return NPY_COMPLEX128;
case kComplexFloat: return NPY_COMPLEX64;
case kDouble: return NPY_DOUBLE; case kDouble: return NPY_DOUBLE;
case kFloat: return NPY_FLOAT; case kFloat: return NPY_FLOAT;
case kHalf: return NPY_HALF; case kHalf: return NPY_HALF;
case kComplexDouble: return NPY_COMPLEX128;
case kComplexFloat: return NPY_COMPLEX64;
case kLong: return NPY_INT64; case kLong: return NPY_INT64;
case kInt: return NPY_INT32; case kInt: return NPY_INT32;
case kShort: return NPY_INT16; case kShort: return NPY_INT16;
@ -211,6 +211,8 @@ ScalarType numpy_dtype_to_aten(int dtype) {
case NPY_DOUBLE: return kDouble; case NPY_DOUBLE: return kDouble;
case NPY_FLOAT: return kFloat; case NPY_FLOAT: return kFloat;
case NPY_HALF: return kHalf; case NPY_HALF: return kHalf;
case NPY_COMPLEX64: return kComplexFloat;
case NPY_COMPLEX128: return kComplexDouble;
case NPY_INT16: return kShort; case NPY_INT16: return kShort;
case NPY_INT8: return kChar; case NPY_INT8: return kChar;
case NPY_UINT8: return kByte; case NPY_UINT8: return kByte;
@ -236,7 +238,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
if (!pytype) throw python_error(); if (!pytype) throw python_error();
throw TypeError( throw TypeError(
"can't convert np.ndarray of type %s. The only supported types are: " "can't convert np.ndarray of type %s. The only supported types are: "
"float64, float32, float16, int64, int32, int16, int8, uint8, and bool.", "float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
((PyTypeObject*)pytype.get())->tp_name); ((PyTypeObject*)pytype.get())->tp_name);
} }