Added numpy conversion (#18505)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18505
ghimport-source-id: f3c9b9251e5793f9e192f587194ddfebb45facc1

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18505 [WIP]Added numpy conversion**
* #18166 Bool Tensor for CUDA

Differential Revision: D14646403

fbshipit-source-id: 79d39d692c778ce1981c1d35b1c33e3d93111041
This commit is contained in:
Iurii Zdebskyi
2019-04-03 07:22:38 -07:00
committed by Facebook Github Bot
parent 7349dbb7ce
commit 48f70ea0a2
5 changed files with 55 additions and 12 deletions

View File

@ -25,7 +25,7 @@ _(double,Double,d) /* 7 */ \
_(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
_(std::complex<float>,ComplexFloat,z) /* 9 */ \
_(std::complex<double>,ComplexDouble,z) /* 10 */ \
_(bool,Bool,i) /* 11 */
_(bool,Bool,i) /* 11 */
// If you want to support ComplexHalf for real, replace occurrences
// of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
@ -193,19 +193,25 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
if (isComplexType(a) || isComplexType(b)) {
AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
}
// this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX so that's why we have to add
// undefined as we are not sure what is the corrent values for the type promotions in complex type cases.
static constexpr ScalarType _promoteTypesLookup
[static_cast<int>(ScalarType::NumOptions)]
[static_cast<int>(ScalarType::NumOptions)] = {
/* u1 i1 i2 i4 i8 f2 f4 f8 b1 */
/* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, u1 },
/* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, i1 },
/* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, i2 },
/* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, i4 },
/* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, i8 },
/* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, f2 },
/* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, f4 },
/* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, f8 },
/* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, b1 },
/* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 */
/* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1 },
/* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1 },
/* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2 },
/* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4 },
/* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8 },
/* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2 },
/* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4 },
/* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8 },
/* c2 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
/* c4 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
/* c8 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
/* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1 },
};
return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
}

View File

@ -414,6 +414,10 @@ class TestCase(expecttest.TestCase):
self.assertEqual(x.item(), y, prec, message, allow_inf)
elif isinstance(y, torch.Tensor) and isinstance(x, Number):
self.assertEqual(x, y.item(), prec, message, allow_inf)
elif isinstance(x, torch.Tensor) and isinstance(y, numpy.bool_):
self.assertEqual(x.item(), y, prec, message, allow_inf)
elif isinstance(y, torch.Tensor) and isinstance(x, numpy.bool_):
self.assertEqual(x, y.item(), prec, message, allow_inf)
elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
def assertTensorsEqual(a, b):
super(TestCase, self).assertEqual(a.size(), b.size(), message)

View File

@ -10001,6 +10001,23 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
y[0][1] = 3
self.assertTrue(x[0][1] == 3)
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_to_numpy_bool(self):
x = torch.tensor([True, False], dtype=torch.bool)
self.assertEqual(x.dtype, torch.bool)
y = x.numpy()
self.assertEqual(y.dtype, np.bool)
for i in range(len(x)):
self.assertEqual(x[i], y[i])
x = torch.tensor([True], dtype=torch.bool)
self.assertEqual(x.dtype, torch.bool)
y = x.numpy()
self.assertEqual(y.dtype, np.bool)
self.assertEqual(x[0], y[0])
def test_dlpack_conversion(self):
x = torch.randn(1, 2, 3, 4).type('torch.FloatTensor')
z = from_dlpack(to_dlpack(x))
@ -10024,6 +10041,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
np.int8,
np.uint8,
np.longlong,
np.bool,
]
for dtype in dtypes:
array = np.array([1, 2, 3, 4], dtype=dtype)
@ -10075,6 +10093,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
np.int16,
np.int8,
np.uint8,
np.bool,
]
incorrect_byteorder = '>' if sys.byteorder == 'little' else '<'
@ -10120,7 +10139,8 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
np.int64,
np.int32,
np.int16,
np.uint8
np.uint8,
np.bool,
]
for dtype in dtypes:
self.assertEqual(dtype(42), torch.tensor(dtype(42)).item())

View File

@ -219,6 +219,15 @@ static int64_t dispatch_to_CLong(const Tensor & self) {
return self.item<int64_t>();
}
static bool dispatch_to_Bool(const Tensor & self) {
AutoNoGIL no_gil;
OptionalDeviceGuard device_guard(device_of(self));
if (self.numel() != 1) {
throw ValueError("only one element tensors can be converted to Python scalars");
}
return self.item<bool>();
}
static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
HANDLE_TH_ERRORS
jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
@ -439,6 +448,8 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* args)
return wrap(dispatch_to_CDouble(self_));
} else if (self_.is_complex()) {
return wrap(dispatch_to_CComplexDouble(self_));
} else if (self_.scalar_type() == ScalarType::Bool) {
return wrap(dispatch_to_Bool(self_));
} else {
return wrap(dispatch_to_CLong(self_));
}

View File

@ -156,6 +156,7 @@ static int aten_to_dtype(const ScalarType scalar_type) {
case kShort: return NPY_INT16;
case kChar: return NPY_INT8;
case kByte: return NPY_UINT8;
case kBool: return NPY_BOOL;
default:
throw ValueError("Got unsupported ScalarType ", toString(scalar_type));
}
@ -170,6 +171,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
case NPY_INT16: return kShort;
case NPY_INT8: return kChar;
case NPY_UINT8: return kByte;
case NPY_BOOL: return kBool;
default:
// Workaround: MSVC does not support two switch cases that have the same value
if (dtype == NPY_LONGLONG || dtype == NPY_INT64) {