Modified assertEqual to handle complex tensors (#33773)

Summary:
- Modified assertEqual to handle complex tensors
- Added a test for torch.zeros with a complex dtype in test_torch.py
- Added complex dispatch for index_kernel and index_put_kernel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/33773

Differential Revision: D20135553

Pulled By: anjali411

fbshipit-source-id: f716604535c0447ecffa335b0fc843431397c988
anjali411
2020-02-28 08:40:20 -08:00
committed by Facebook Github Bot
parent 09046713cc
commit dece155335
5 changed files with 20 additions and 6 deletions
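
Roughly, the behaviour the kernel dispatch changes below enable from Python (a hedged sketch, not part of the commit; complex tensor construction from Python literals is assumed to work in this build):

import torch

# Advanced indexing and index_put_ on a complex tensor now reach the
# complex-enabled kernels added in this commit.
a = torch.zeros(2, 2, dtype=torch.complex64)
idx = (torch.tensor([0, 1]), torch.tensor([1, 0]))
a.index_put_(idx, torch.tensor([1 + 2j, 3 - 4j], dtype=torch.complex64))
print(a[idx])   # tensor([1.+2.j, 3.-4.j])
print(a)        # the remaining entries stay 0+0j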


@@ -93,7 +93,7 @@ void cpu_index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef
 }

 void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
     iter.dtype(), "index_cpu", [&] {
     cpu_index_kernel<scalar_t>(iter, index_size, index_stride, [](char* dst, char* src, int64_t offset) {
       *(scalar_t*)dst = *(scalar_t*)(src + offset);
@@ -103,7 +103,7 @@ void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {

 void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   // NOTE: duplicate indices are only supported if accumulate is true.
-  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
     iter.dtype(), "index_put", [&] {
     if (accumulate) {
       // TODO: investigate parallelization of the accumulate kernel. Unlike the non-accumulate case,
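
As the NOTE above says, duplicate indices are only supported on the accumulate path; a small sketch of what that path now allows for a complex CPU tensor (illustrative only, not from the PR):

import torch

# index_put_ with accumulate=True on a complex CPU tensor; the duplicated
# index 1 receives the sum of its two values.
t = torch.zeros(4, dtype=torch.complex64)
indices = (torch.tensor([1, 1, 3]),)
values = torch.tensor([1 + 1j, 2 + 2j, 5j], dtype=torch.complex64)
t.index_put_(indices, values, accumulate=True)
print(t)   # tensor([0.+0.j, 3.+3.j, 0.+0.j, 0.+5.j])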


@@ -88,7 +88,7 @@ void index_put_kernel_impl(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
 }

 static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {
-  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "index_cuda", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "index_cuda", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_kernel_impl<dtype>(iter, index_size, index_stride);
   });
@@ -97,7 +97,7 @@ static void index_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride) {

 static void index_put_kernel(TensorIterator& iter, IntArrayRef index_size, IntArrayRef index_stride, bool accumulate) {
   AT_ASSERTM(!accumulate, "index_put does not support accumulate=true");
-  AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "index_put", [&] {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, iter.dtype(), "index_put", [&] {
     using dtype = OpaqueType<sizeof(scalar_t)>;
     index_put_kernel_impl<dtype>(iter, index_size, index_stride);
   });
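
A matching sketch for the CUDA kernels (assumes a CUDA device is available; as the AT_ASSERTM above shows, accumulate=True is still rejected on CUDA):

import torch

# Complex indexing and non-accumulating index_put_ on a CUDA tensor now
# dispatch through the complex-enabled kernels.
if torch.cuda.is_available():
    t = torch.zeros(4, dtype=torch.complex64, device='cuda')
    idx = (torch.tensor([0, 2], device='cuda'),)
    t.index_put_(idx, torch.tensor([1 + 1j, 2 - 2j], dtype=torch.complex64, device='cuda'))
    print(t[idx[0]])   # tensor([1.+1.j, 2.-2.j], device='cuda:0')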


@@ -333,10 +333,13 @@ static inline ScalarType toUnderlying(ScalarType t) {
 static inline bool isSignedType(ScalarType t) {
   TORCH_CHECK(!isQIntType(t), "isSignedType not supported for quantized types");
 #define CASE_SIGNED(ctype, name) \
-  case ScalarType::name: \
+  case ScalarType::name:          \
     return std::numeric_limits<ctype>::is_signed;

   switch (t) {
+    case ScalarType::ComplexFloat: \
+    case ScalarType::ComplexDouble: \
+      return true; \
     AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
     default:
       AT_ERROR("Unknown ScalarType");
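
This ScalarType.h change matters because the updated assertEqual (below) calls diff.is_signed() on the difference tensor; for complex dtypes isSignedType previously fell through to the "Unknown ScalarType" error and now reports them as signed. A rough way to observe the effect from Python, assuming Tensor.is_signed() routes through isSignedType:

import torch

# Complex dtypes now report as signed instead of erroring out.
print(torch.zeros(1, dtype=torch.complex64).is_signed())    # True
print(torch.zeros(1, dtype=torch.complex128).is_signed())   # True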


@@ -988,6 +988,10 @@ class _TestTorchMixin(object):
         expected = torch.tensor([[0.]], dtype=torch.bfloat16)
         self.assertEqual(bfloat16Tensor, expected)
+
+        complexTensor = torch.zeros(2, 2, dtype=torch.complex64)
+        expected = torch.tensor([[0., 0.], [0., 0.]], dtype=torch.complex64)
+        self.assertEqual(complexTensor, expected)

     def test_zeros_out(self):
         shape = (3, 4)
         out = torch.zeros(shape)


@@ -817,7 +817,7 @@ class TestCase(expecttest.TestCase):
                     b = b.to(torch.int)

                 diff = a - b
-                if a.is_floating_point():
+                if a.dtype.is_complex or a.dtype.is_floating_point:
                     # check that NaNs are in the same locations
                     nan_mask = torch.isnan(a)
                     self.assertTrue(torch.equal(nan_mask, torch.isnan(b)), message)
@@ -829,8 +829,15 @@ class TestCase(expecttest.TestCase):
                         self.assertTrue(torch.equal(inf_sign, torch.isinf(b).sign()), message)
                         diff[inf_mask] = 0
                 # TODO: implement abs on CharTensor (int8)
+                # TODO: modify abs to return float/double for ComplexFloat/ComplexDouble
                 if diff.is_signed() and diff.dtype != torch.int8:
                     diff = diff.abs()
+                # if diff is complex, the imaginary component for diff will be 0
+                # from the previous step, hence converting it to float and double is fine.
+                if diff.dtype == torch.complex64:
+                    diff = diff.to(torch.float)
+                elif diff.dtype == torch.complex128:
+                    diff = diff.to(torch.double)
                 max_err = diff.max()
                 self.assertLessEqual(max_err, prec, message)
             super(TestCase, self).assertEqual(x.is_sparse, y.is_sparse, message)
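
For reference, a condensed, standalone restatement of the comparison logic above (a sketch only; max_abs_error is a made-up helper, not the real TestCase.assertEqual):

import torch

def max_abs_error(a, b):
    # Elementwise difference, reduced to a real-valued magnitude before max().
    diff = a - b
    if diff.is_signed() and diff.dtype != torch.int8:
        diff = diff.abs()
    # At the time of this commit abs() keeps the complex dtype with a zero
    # imaginary part (see the TODO above), so cast to the matching real dtype.
    if diff.dtype == torch.complex64:
        diff = diff.to(torch.float)
    elif diff.dtype == torch.complex128:
        diff = diff.to(torch.double)
    return diff.max()

a = torch.tensor([1 + 1j, 2 + 2.0j], dtype=torch.complex64)
b = torch.tensor([1 + 1j, 2 + 2.5j], dtype=torch.complex64)
print(max_abs_error(a, b))   # tensor(0.5000)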