Added numpy conversion (#18505)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18505
ghimport-source-id: f3c9b9251e5793f9e192f587194ddfebb45facc1

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#18505 [WIP]Added numpy conversion**
* #18166 Bool Tensor for CUDA

Differential Revision: D14646403

fbshipit-source-id: 79d39d692c778ce1981c1d35b1c33e3d93111041
Committed by: Facebook Github Bot
Parent: 7349dbb7ce
Commit: 48f70ea0a2
@@ -25,7 +25,7 @@ _(double,Double,d) /* 7 */ \
 _(at::ComplexHalf,ComplexHalf,z) /* 8 */ \
 _(std::complex<float>,ComplexFloat,z) /* 9 */ \
 _(std::complex<double>,ComplexDouble,z) /* 10 */ \
-_(bool,Bool,i) /* 11 */
+_(bool,Bool,i) /* 11 */

 // If you want to support ComplexHalf for real, replace occurrences
 // of this macro with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX. But
@@ -193,19 +193,25 @@ static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
   if (isComplexType(a) || isComplexType(b)) {
     AT_ERROR("promoteTypes with complex numbers is not handled yet; figure out what the correct rules should be");
   }
+
+  // this matrix has to be consistent with AT_FORALL_SCALAR_TYPES_WITH_COMPLEX so that's why we have to add
+  // undefined as we are not sure what is the corrent values for the type promotions in complex type cases.
   static constexpr ScalarType _promoteTypesLookup
       [static_cast<int>(ScalarType::NumOptions)]
       [static_cast<int>(ScalarType::NumOptions)] = {
-    /* u1  i1  i2  i4  i8  f2  f4  f8  b1 */
-    /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, u1 },
-    /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, i1 },
-    /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, i2 },
-    /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, i4 },
-    /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, i8 },
-    /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, f2 },
-    /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, f4 },
-    /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, f8 },
-    /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, b1 },
+    /* u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1 */
+    /* u1 */ { u1, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, u1 },
+    /* i1 */ { i2, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, i1 },
+    /* i2 */ { i2, i2, i2, i4, i8, f2, f4, f8, ud, ud, ud, i2 },
+    /* i4 */ { i4, i4, i4, i4, i8, f2, f4, f8, ud, ud, ud, i4 },
+    /* i8 */ { i8, i8, i8, i8, i8, f2, f4, f8, ud, ud, ud, i8 },
+    /* f2 */ { f2, f2, f2, f2, f2, f2, f4, f8, ud, ud, ud, f2 },
+    /* f4 */ { f4, f4, f4, f4, f4, f4, f4, f8, ud, ud, ud, f4 },
+    /* f8 */ { f8, f8, f8, f8, f8, f8, f8, f8, ud, ud, ud, f8 },
+    /* c2 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+    /* c4 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+    /* c8 */ { ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud },
+    /* b1 */ { u1, i1, i2, i4, i8, f2, f4, f8, ud, ud, ud, b1 },
   };
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
 }
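For readers tracing the table above, here is a minimal Python sketch (illustrative only, not part of the diff) of the same symmetric-lookup idea, using the short codes from the comments (u1 = uint8, i1 = int8, i2 = int16, i8 = int64, f2 = half, b1 = bool):

```python
# Toy model of a type-promotion lookup: a symmetric table keyed by the short
# dtype codes used in the C++ comments above. Only a few entries are shown.
PROMOTE = {
    ("u1", "i1"): "i2",  # uint8 with int8 needs a wider signed type
    ("u1", "b1"): "u1",  # bool defers to the other operand's type
    ("i8", "f2"): "f2",  # integer with floating point promotes to float
    ("b1", "b1"): "b1",  # bool with bool stays bool
}

def promote(a, b):
    # The table is symmetric, so fall back to the reversed key.
    return PROMOTE.get((a, b)) or PROMOTE[(b, a)]

assert promote("i1", "u1") == "i2"
assert promote("b1", "u1") == "u1"
assert promote("i8", "f2") == "f2"
```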
@@ -414,6 +414,10 @@ class TestCase(expecttest.TestCase):
             self.assertEqual(x.item(), y, prec, message, allow_inf)
         elif isinstance(y, torch.Tensor) and isinstance(x, Number):
             self.assertEqual(x, y.item(), prec, message, allow_inf)
+        elif isinstance(x, torch.Tensor) and isinstance(y, numpy.bool_):
+            self.assertEqual(x.item(), y, prec, message, allow_inf)
+        elif isinstance(y, torch.Tensor) and isinstance(x, numpy.bool_):
+            self.assertEqual(x, y.item(), prec, message, allow_inf)
         elif isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
             def assertTensorsEqual(a, b):
                 super(TestCase, self).assertEqual(a.size(), b.size(), message)
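A short aside on why these branches are needed (a sketch under the assumption that `Number` in this helper is `numbers.Number`): `numpy.bool_` is not registered as a `numbers.Number`, so the pre-existing Number branches never match a numpy bool, and the comparison has to go through `Tensor.item()` explicitly:

```python
import numbers

import numpy as np
import torch

# numpy.bool_ is not registered as a numbers.Number, which is why assertEqual
# needs dedicated numpy.bool_ branches instead of relying on the Number path.
assert not isinstance(np.bool_(True), numbers.Number)

# Comparing through .item() works for a one-element bool tensor.
assert torch.tensor(True).item() == np.bool_(True)
```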
@@ -10001,6 +10001,23 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
         y[0][1] = 3
         self.assertTrue(x[0][1] == 3)

+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_to_numpy_bool(self):
+        x = torch.tensor([True, False], dtype=torch.bool)
+        self.assertEqual(x.dtype, torch.bool)
+
+        y = x.numpy()
+        self.assertEqual(y.dtype, np.bool)
+        for i in range(len(x)):
+            self.assertEqual(x[i], y[i])
+
+        x = torch.tensor([True], dtype=torch.bool)
+        self.assertEqual(x.dtype, torch.bool)
+
+        y = x.numpy()
+        self.assertEqual(y.dtype, np.bool)
+        self.assertEqual(x[0], y[0])
+
     def test_dlpack_conversion(self):
         x = torch.randn(1, 2, 3, 4).type('torch.FloatTensor')
         z = from_dlpack(to_dlpack(x))
@@ -10024,6 +10041,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             np.int8,
             np.uint8,
             np.longlong,
+            np.bool,
         ]
         for dtype in dtypes:
             array = np.array([1, 2, 3, 4], dtype=dtype)
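As a usage illustration (not part of the diff, and assuming a build where bool tensors are available, which is what this change enables), converting a boolean NumPy array to a tensor preserves the values and maps the dtype to `torch.bool`:

```python
import numpy as np
import torch

# A boolean NumPy array converts to a torch.bool tensor; torch.from_numpy
# generally shares memory with the source array, as it does for other dtypes.
array = np.array([True, False, True], dtype=np.bool_)
tensor = torch.from_numpy(array)

assert tensor.dtype == torch.bool
assert tensor[0].item() is True and tensor[1].item() is False
```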
@@ -10075,6 +10093,7 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             np.int16,
             np.int8,
             np.uint8,
+            np.bool,
         ]

         incorrect_byteorder = '>' if sys.byteorder == 'little' else '<'
@@ -10120,7 +10139,8 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
             np.int64,
             np.int32,
             np.int16,
-            np.uint8
+            np.uint8,
+            np.bool,
         ]
         for dtype in dtypes:
             self.assertEqual(dtype(42), torch.tensor(dtype(42)).item())
@@ -219,6 +219,15 @@ static int64_t dispatch_to_CLong(const Tensor & self) {
   return self.item<int64_t>();
 }

+static bool dispatch_to_Bool(const Tensor & self) {
+  AutoNoGIL no_gil;
+  OptionalDeviceGuard device_guard(device_of(self));
+  if (self.numel() != 1) {
+    throw ValueError("only one element tensors can be converted to Python scalars");
+  }
+  return self.item<bool>();
+}
+
 static PyObject * THPVariable_float_scalar(PyObject* self, PyObject* args) {
   HANDLE_TH_ERRORS
   jit::tracer::warn("Converting a tensor to a Python float", jit::tracer::WARN_PYTHON_DATAFLOW);
@@ -439,6 +448,8 @@ static PyObject * THPVariable_item(PyObject* self, PyObject* args)
     return wrap(dispatch_to_CDouble(self_));
   } else if (self_.is_complex()) {
     return wrap(dispatch_to_CComplexDouble(self_));
+  } else if (self_.scalar_type() == ScalarType::Bool) {
+    return wrap(dispatch_to_Bool(self_));
   } else {
     return wrap(dispatch_to_CLong(self_));
   }
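A small usage sketch (not part of the diff) of what the new `dispatch_to_Bool` path means on the Python side: `.item()` on a one-element bool tensor returns a Python `bool` rather than an integer, while multi-element tensors still reject scalar conversion:

```python
import torch

flag = torch.tensor([True]).item()
assert isinstance(flag, bool) and flag is True  # a real Python bool, not 1

# More than one element still refuses to convert to a Python scalar.
try:
    torch.tensor([True, False]).item()
except (ValueError, RuntimeError):
    pass
else:
    raise AssertionError("expected a scalar-conversion error")
```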
@@ -156,6 +156,7 @@ static int aten_to_dtype(const ScalarType scalar_type) {
     case kShort: return NPY_INT16;
     case kChar: return NPY_INT8;
     case kByte: return NPY_UINT8;
+    case kBool: return NPY_BOOL;
     default:
       throw ValueError("Got unsupported ScalarType ", toString(scalar_type));
   }
@@ -170,6 +171,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
     case NPY_INT16: return kShort;
     case NPY_INT8: return kChar;
     case NPY_UINT8: return kByte;
+    case NPY_BOOL: return kBool;
     default:
       // Workaround: MSVC does not support two switch cases that have the same value
       if (dtype == NPY_LONGLONG || dtype == NPY_INT64) {
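To tie the two mappings together, a round-trip sketch (illustrative, not part of the diff): `kBool -> NPY_BOOL` on the way out and `NPY_BOOL -> kBool` on the way back means a bool tensor keeps its dtype across a NumPy round trip:

```python
import numpy as np
import torch

t = torch.tensor([True, False], dtype=torch.bool)

# aten_to_dtype: kBool -> NPY_BOOL
assert t.numpy().dtype == np.bool_

# numpy_dtype_to_aten: NPY_BOOL -> kBool
assert torch.from_numpy(t.numpy()).dtype == torch.bool
```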