mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
torch.from_numpy for complex dtypes (#35531)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35531 Differential Revision: D20693581 Pulled By: anjali411 fbshipit-source-id: d53e26b4175452fa00b287efbfceea18104c1364
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f101949390
commit
96eec95ece
@ -190,11 +190,11 @@ at::Tensor tensor_from_numpy(PyObject* obj) {
|
||||
|
||||
int aten_to_numpy_dtype(const ScalarType scalar_type) {
|
||||
switch (scalar_type) {
|
||||
case kComplexDouble: return NPY_COMPLEX128;
|
||||
case kComplexFloat: return NPY_COMPLEX64;
|
||||
case kDouble: return NPY_DOUBLE;
|
||||
case kFloat: return NPY_FLOAT;
|
||||
case kHalf: return NPY_HALF;
|
||||
case kComplexDouble: return NPY_COMPLEX128;
|
||||
case kComplexFloat: return NPY_COMPLEX64;
|
||||
case kLong: return NPY_INT64;
|
||||
case kInt: return NPY_INT32;
|
||||
case kShort: return NPY_INT16;
|
||||
@ -211,6 +211,8 @@ ScalarType numpy_dtype_to_aten(int dtype) {
|
||||
case NPY_DOUBLE: return kDouble;
|
||||
case NPY_FLOAT: return kFloat;
|
||||
case NPY_HALF: return kHalf;
|
||||
case NPY_COMPLEX64: return kComplexFloat;
|
||||
case NPY_COMPLEX128: return kComplexDouble;
|
||||
case NPY_INT16: return kShort;
|
||||
case NPY_INT8: return kChar;
|
||||
case NPY_UINT8: return kByte;
|
||||
@ -236,7 +238,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
|
||||
if (!pytype) throw python_error();
|
||||
throw TypeError(
|
||||
"can't convert np.ndarray of type %s. The only supported types are: "
|
||||
"float64, float32, float16, int64, int32, int16, int8, uint8, and bool.",
|
||||
"float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
|
||||
((PyTypeObject*)pytype.get())->tp_name);
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user