mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
torch.from_numpy for complex dtypes (#35531)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35531 Differential Revision: D20693581 Pulled By: anjali411 fbshipit-source-id: d53e26b4175452fa00b287efbfceea18104c1364
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f101949390
commit
96eec95ece
@ -113,6 +113,8 @@ int CaffeToNumpyType(const TypeMeta& meta) {
|
||||
{TypeMeta::Id<bool>(), NPY_BOOL},
|
||||
{TypeMeta::Id<double>(), NPY_DOUBLE},
|
||||
{TypeMeta::Id<float>(), NPY_FLOAT},
|
||||
{TypeMeta::Id<std::complex<double>>(), NPY_COMPLEX128},
|
||||
{TypeMeta::Id<std::complex<float>>(), NPY_COMPLEX64},
|
||||
{TypeMeta::Id<at::Half>(), NPY_FLOAT16},
|
||||
{TypeMeta::Id<int>(), NPY_INT},
|
||||
{TypeMeta::Id<int8_t>(), NPY_INT8},
|
||||
|
@ -4527,6 +4527,8 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
np.double,
|
||||
np.float,
|
||||
np.float16,
|
||||
np.complex64,
|
||||
np.complex128,
|
||||
np.int64,
|
||||
np.int32,
|
||||
np.int16,
|
||||
@ -4535,6 +4537,11 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
np.longlong,
|
||||
np.bool,
|
||||
]
|
||||
complex_dtypes = [
|
||||
np.complex64,
|
||||
np.complex128,
|
||||
]
|
||||
|
||||
for dtype in dtypes:
|
||||
array = np.array([1, 2, 3, 4], dtype=dtype)
|
||||
tensor_from_array = torch.from_numpy(array)
|
||||
@ -4542,6 +4549,8 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
# implements `==`
|
||||
for i in range(len(array)):
|
||||
self.assertEqual(tensor_from_array[i], array[i])
|
||||
# ufunc 'remainder' not supported for complex dtypes
|
||||
if dtype not in complex_dtypes:
|
||||
# This is a special test case for Windows
|
||||
# https://github.com/pytorch/pytorch/issues/22615
|
||||
array2 = array % 2
|
||||
@ -4550,7 +4559,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
||||
self.assertEqual(tensor_from_array2[i], array2[i])
|
||||
|
||||
# Test unsupported type
|
||||
array = np.array([1, 2, 3, 4], dtype=np.complex)
|
||||
array = np.array([1, 2, 3, 4], dtype=np.uint16)
|
||||
with self.assertRaises(TypeError):
|
||||
tensor_from_array = torch.from_numpy(array)
|
||||
|
||||
|
@ -2199,8 +2199,9 @@ the tensor will be reflected in the :attr:`ndarray` and vice versa. The returned
|
||||
tensor is not resizable.
|
||||
|
||||
It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``,
|
||||
``numpy.float32``, ``numpy.float16``, ``numpy.int64``, ``numpy.int32``,
|
||||
``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, and ``numpy.bool``.
|
||||
``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``,
|
||||
``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``,
|
||||
and ``numpy.bool``.
|
||||
|
||||
Example::
|
||||
|
||||
|
@ -190,11 +190,11 @@ at::Tensor tensor_from_numpy(PyObject* obj) {
|
||||
|
||||
int aten_to_numpy_dtype(const ScalarType scalar_type) {
|
||||
switch (scalar_type) {
|
||||
case kComplexDouble: return NPY_COMPLEX128;
|
||||
case kComplexFloat: return NPY_COMPLEX64;
|
||||
case kDouble: return NPY_DOUBLE;
|
||||
case kFloat: return NPY_FLOAT;
|
||||
case kHalf: return NPY_HALF;
|
||||
case kComplexDouble: return NPY_COMPLEX128;
|
||||
case kComplexFloat: return NPY_COMPLEX64;
|
||||
case kLong: return NPY_INT64;
|
||||
case kInt: return NPY_INT32;
|
||||
case kShort: return NPY_INT16;
|
||||
@ -211,6 +211,8 @@ ScalarType numpy_dtype_to_aten(int dtype) {
|
||||
case NPY_DOUBLE: return kDouble;
|
||||
case NPY_FLOAT: return kFloat;
|
||||
case NPY_HALF: return kHalf;
|
||||
case NPY_COMPLEX64: return kComplexFloat;
|
||||
case NPY_COMPLEX128: return kComplexDouble;
|
||||
case NPY_INT16: return kShort;
|
||||
case NPY_INT8: return kChar;
|
||||
case NPY_UINT8: return kByte;
|
||||
@ -236,7 +238,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
|
||||
if (!pytype) throw python_error();
|
||||
throw TypeError(
|
||||
"can't convert np.ndarray of type %s. The only supported types are: "
|
||||
"float64, float32, float16, int64, int32, int16, int8, uint8, and bool.",
|
||||
"float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
|
||||
((PyTypeObject*)pytype.get())->tp_name);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user