Disable complex inputs to torch.round (#45330)

Summary:
- Related with https://github.com/pytorch/pytorch/issues/44612
- Disable complex inputs to `torch.round`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45330

Reviewed By: gchanan

Differential Revision: D23970781

Pulled By: anjali411

fbshipit-source-id: b8c9ac315ae0fc872701aa132367c3171fd56185
This commit is contained in:
kiyosora
2020-09-28 19:05:11 -07:00
committed by Facebook GitHub Bot
parent 0c8a6008ac
commit 8c66cd120b
3 changed files with 3 additions and 3 deletions

View File

@ -682,7 +682,7 @@ IMPLEMENT_COMPLEX_KERNEL(FLOATING, log10)
IMPLEMENT_FLOAT_KERNEL(FLOATING, log1p)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, log2)
IMPLEMENT_FLOAT_KERNEL(FLOATING, i0)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, round)
IMPLEMENT_FLOAT_KERNEL(FLOATING, round)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, sin)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, sqrt)
IMPLEMENT_COMPLEX_KERNEL(FLOATING, tan)

View File

@ -114,7 +114,7 @@ __host__ __device__ static inline c10::complex<double> nearbyint_wrapper(c10::co
}
void round_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, iter.dtype(), "round_cuda", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND(ScalarType::Half, iter.dtype(), "round_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
// We do not use std::round because we would like to round midway numbers to the nearest even integer.
return nearbyint_wrapper(a);

View File

@ -4798,7 +4798,7 @@ separate_complex_tests = ['view_as_real', 'real', 'imag', 'asin', 'acos'] # ['l
complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'roll', 'clone',
'repeat', 'expand', 'flip', 'fliplr', 'flipud', 'rot90', 'transpose',
'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_', 'round',
'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
'cosh', '__rmul__', 'sgn'] + separate_complex_tests