diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
index 8521fb26b4ce..ba46dbef9db9 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
@@ -363,12 +363,19 @@ template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<
 }
 
 Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm256_xor_pd(sign_mask, values);    //c       -d
-  return _mm256_div_pd(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm256_xor_pd(sign_mask, values);    //c       -d
+  // return _mm256_div_pd(c_d, abs_2_());
+  __at_align__ c10::complex<double> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<double>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }
 
 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
index 806972d8be00..c715b3d1dd23 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
@@ -398,12 +398,19 @@ template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<
 }
 
 Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm256_xor_ps(sign_mask, values);    //c       -d
-  return _mm256_div_ps(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm256_xor_ps(sign_mask, values);    //c       -d
+  // return _mm256_div_ps(c_d, abs_2_());
+  __at_align__ c10::complex<float> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<float>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }
 
 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
index e93e4e6420ad..5ac00b45b1fb 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
@@ -433,12 +433,19 @@ template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<
 }
 
 Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm512_xor_pd(sign_mask, values);    //c       -d
-  return _mm512_div_pd(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm512_xor_pd(sign_mask, values);    //c       -d
+  // return _mm512_div_pd(c_d, abs_2_());
+  __at_align__ c10::complex<double> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<double>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }
 
 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
index 5c3a4333a513..74fdfbef8d0f 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
@@ -936,13 +936,20 @@ template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<
 }
 
 Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
-                                          0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm512_xor_ps(sign_mask, values);    //c       -d
-  return _mm512_div_ps(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
+  //                                         0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm512_xor_ps(sign_mask, values);    //c       -d
+  // return _mm512_div_ps(c_d, abs_2_());
+  __at_align__ c10::complex<float> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<float>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }
 
 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {
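Note on the four reciprocal() changes above: the closed-form reciprocal conj(z) / |z|^2 is only valid for finite, nonzero lanes. Once the real or imaginary part is 0, Inf, or NaN, the formula runs into 0/0 or Inf/Inf and the lane comes out as NaN, which is what the TODO comments refer to and why each kernel now round-trips through per-element scalar division. A minimal sketch of the failure mode in plain Python (naive_reciprocal is an illustrative stand-in for the retired formula, not ATen code):

    # naive_reciprocal mirrors the retired formula: re = c/|z|^2, im = -d/|z|^2.
    def naive_reciprocal(z: complex) -> complex:
        abs2 = z.real * z.real + z.imag * z.imag
        return complex(z.real / abs2, -z.imag / abs2)

    z = complex(float("inf"), 0.0)
    print(naive_reciprocal(z))  # (nan-0j): inf/inf poisons the real part
    print(1 / z)                # 0j: scalar complex division copes with the Inf

The scalar loop is the straightforward fallback; the commented-out intrinsics are kept for whoever restores a vectorized version with explicit special-case handling, as the TODO suggests.
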
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index a90406836cf4..23154b636add 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -58,10 +58,7 @@ static void sigmoid_kernel(TensorIteratorBase& iter) {
           return (static_cast<scalar_t>(1) / (static_cast<scalar_t>(1) + std::exp((-a))));
         },
         [=](Vectorized<scalar_t> a) {
-          a = Vectorized<scalar_t>(static_cast<scalar_t>(0)) - a;
-          a = a.exp();
-          a = Vectorized<scalar_t>(static_cast<scalar_t>(1)) + a;
-          a = a.reciprocal();
+          a = (Vectorized<scalar_t>(static_cast<scalar_t>(1)) + a.neg().exp()).reciprocal();
           return a;
         });
   });
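The sigmoid_kernel change above touches only the vectorized lambda: it still computes sigmoid(a) = 1 / (1 + exp(-a)), now composed in a single expression, with neg() in place of subtracting from a zero vector. A scalar Python sketch of the correspondence (both helpers are illustrative, not PyTorch code):

    import math

    def sigmoid_unfused(a):
        t = 0.0 - a         # a = Vectorized(0) - a
        t = math.exp(t)     # a = a.exp()
        t = 1.0 + t         # a = Vectorized(1) + a
        return 1.0 / t      # a = a.reciprocal()

    def sigmoid_fused(a):
        return 1.0 / (1.0 + math.exp(-a))   # (1 + a.neg().exp()).reciprocal()

    assert all(math.isclose(sigmoid_unfused(x), sigmoid_fused(x))
               for x in (-3.0, 0.0, 5.0))
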
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 7045fe7ce9d2..fb1f8e23f612 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -20081,9 +20081,9 @@ op_db: List[OpInfo] = [
                    skips=(
                        # Reference: https://github.com/pytorch/pytorch/issues/56012
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
-                                    dtypes=[torch.complex64, torch.cdouble]),
+                                    dtypes=[torch.complex64, torch.cdouble], device_type='cuda'),
                        DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
-                                    dtypes=[torch.chalf, torch.complex64, torch.cdouble])),
+                                    dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')),
                    dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
                    dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16),
                    supports_forward_ad=True,
@@ -22579,10 +22579,10 @@ python_ref_db = [
            # Reference: https://github.com/pytorch/pytorch/issues/56012
            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
-                        dtypes=[torch.complex64, torch.cdouble]),
+                        dtypes=[torch.complex64, torch.cdouble], device_type='cuda'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
-                        dtypes=[torch.chalf, torch.complex64, torch.cdouble])
+                        dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')
        ),
    ),
    ElementwiseUnaryPythonRefInfo(
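With device_type='cuda' added to these DecorateInfo entries, the skips referencing issue 56012 apply only on CUDA, so the CPU variants of test_reference_numerics_extremal and test_reference_numerics_large run again and cover the kernels changed above. A hypothetical spot check outside the test suite (not TestUnaryUfuncs code), comparing the CPU kernel against plain Python complex division:

    import torch

    x = torch.tensor([complex(float("inf"), 0.0), complex(1.0, -2.0)],
                     dtype=torch.cdouble)
    ref = torch.tensor([1 / z for z in x.tolist()], dtype=torch.cdouble)
    print(torch.reciprocal(x))  # CPU complex reciprocal
    print(ref)                  # plain Python complex division, for comparison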