fix sigmoid for torch.complex datatypes on CPU (#140391)

Fix https://github.com/pytorch/pytorch/issues/135777.
The issue is caused by the lack of special handling in the vectorized implementation of `reciprocal` for the cases where the real or imaginary part is 0/Inf/NaN. For correctness, I temporarily fall back the `reciprocal` implementation to the scalar implementation.
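
To make the failure mode concrete, here is a small standalone illustration (not part of this PR; it uses `std::complex` instead of `c10::complex`, and the values are made up). The multiplication-based reciprocal `conj(z) / |z|^2` produces NaN when a component is Inf, which is exactly what `sigmoid` feeds into `reciprocal` once `exp(-a)` overflows, whereas plain complex division keeps the result finite:

// Standalone sketch of the bug, assuming an IEEE-conforming build
// (no -ffast-math); std::complex stands in for c10::complex here.
#include <complex>
#include <cstdio>
#include <limits>

int main() {
  const double inf = std::numeric_limits<double>::infinity();
  // What sigmoid hands to reciprocal() after exp(-a) overflows: 1 + exp(-a) == Inf + 0i.
  std::complex<double> z(inf, 0.0);

  // Multiplication-based formula used by the old vectorized path: 1/z = conj(z) / |z|^2.
  // The real part becomes Inf/Inf = NaN.
  double abs2 = z.real() * z.real() + z.imag() * z.imag();
  std::complex<double> formula(z.real() / abs2, -z.imag() / abs2);

  // Scalar complex division (what the fallback loop computes) handles the infinity.
  std::complex<double> scalar = std::complex<double>(1.0) / z;

  std::printf("formula: (%g, %g)\n", formula.real(), formula.imag());  // NaN real part
  std::printf("scalar : (%g, %g)\n", scalar.real(), scalar.imag());    // finite (zero) result
  return 0;
}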

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140391
Approved by: https://github.com/mingfeima, https://github.com/Skylion007
ghstack dependencies: #140358
commit c922ccb7c4
parent 507bf65c6a
Author: Sun, Jiayi
Date: 2025-01-19 18:53:42 -08:00
Committed by: PyTorch MergeBot

6 changed files with 58 additions and 33 deletions

@@ -363,12 +363,19 @@ template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c
 // reciprocal. Implement this here so we can use multiplication.
 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm256_xor_pd(sign_mask, values); //c -d
-  return _mm256_div_pd(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm256_xor_pd(sign_mask, values); //c -d
+  // return _mm256_div_pd(c_d, abs_2_());
+  __at_align__ c10::complex<double> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<double>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }
 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
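
The same scalar-fallback pattern is repeated in the remaining vec headers below. As a rough spot-check of the new behavior (not part of the PR), a sketch along these lines could compare each lane of `reciprocal()` against the scalar division it now delegates to. The include paths and the assumption that this compiles inside an ATen/PyTorch build are mine; only `loadu`, `store`, `reciprocal`, and `size` from the hunk above are relied on:

// Rough spot-check sketch; assumes an ATen/PyTorch build so these headers resolve.
#include <ATen/cpu/vec/vec.h>
#include <c10/util/complex.h>
#include <cstdio>
#include <limits>

int main() {
  using Vec = at::vec::Vectorized<c10::complex<double>>;
  const double inf = std::numeric_limits<double>::infinity();
  const double nan = std::numeric_limits<double>::quiet_NaN();

  // Cycle the special values that broke the SIMD formula through every lane.
  const c10::complex<double> specials[3] = {{inf, 0.0}, {0.0, 0.0}, {nan, 1.0}};
  __at_align__ c10::complex<double> in[Vec::size()];
  __at_align__ c10::complex<double> out[Vec::size()];
  for (int i = 0; i < Vec::size(); ++i) {
    in[i] = specials[i % 3];
  }

  Vec::loadu(in).reciprocal().store(out);

  for (int i = 0; i < Vec::size(); ++i) {
    // Scalar reference: the same division the fallback loop now performs.
    auto ref = c10::complex<double>(1) / in[i];
    std::printf("lane %d: got (%g, %g), want (%g, %g)\n",
                i, out[i].real(), out[i].imag(), ref.real(), ref.imag());
  }
  return 0;
}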

@@ -398,12 +398,19 @@ template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<c1
 // reciprocal. Implement this here so we can use multiplication.
 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm256_xor_ps(sign_mask, values); //c -d
-  return _mm256_div_ps(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm256_xor_ps(sign_mask, values); //c -d
+  // return _mm256_div_ps(c_d, abs_2_());
+  __at_align__ c10::complex<float> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<float>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }
 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {

@@ -433,12 +433,19 @@ template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c
 // reciprocal. Implement this here so we can use multiplication.
 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm512_xor_pd(sign_mask, values); //c -d
-  return _mm512_div_pd(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm512_xor_pd(sign_mask, values); //c -d
+  // return _mm512_div_pd(c_d, abs_2_());
+  __at_align__ c10::complex<double> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<double>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }
 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {

@@ -936,13 +936,20 @@ template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<c1
 // reciprocal. Implement this here so we can use multiplication.
 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
-                                          0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm512_xor_ps(sign_mask, values); //c -d
-  return _mm512_div_ps(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
+  //                                         0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm512_xor_ps(sign_mask, values); //c -d
+  // return _mm512_div_ps(c_d, abs_2_());
+  __at_align__ c10::complex<float> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<float>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }
 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {

@@ -58,10 +58,7 @@ static void sigmoid_kernel(TensorIteratorBase& iter) {
           return (static_cast<scalar_t>(1) / (static_cast<scalar_t>(1) + std::exp((-a))));
         },
         [=](Vectorized<scalar_t> a) {
-          a = Vectorized<scalar_t>(static_cast<scalar_t>(0)) - a;
-          a = a.exp();
-          a = Vectorized<scalar_t>(static_cast<scalar_t>(1)) + a;
-          a = a.reciprocal();
+          a = (Vectorized<scalar_t>(static_cast<scalar_t>(1)) + a.neg().exp()).reciprocal();
           return a;
         });
   });
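
For a concrete float32 example of what this hunk fixes (values illustrative): with a = -90 + 0i, exp(-a) = exp(90) overflows to Inf, so the intermediate 1 + exp(-a) is Inf + 0i. The old SIMD `reciprocal()` turned that into a NaN real part, while the scalar-division fallback returns 0, matching the scalar lambda directly above.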

@@ -20081,9 +20081,9 @@ op_db: List[OpInfo] = [
            skips=(
                # Reference: https://github.com/pytorch/pytorch/issues/56012
                DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
-                            dtypes=[torch.complex64, torch.cdouble]),
+                            dtypes=[torch.complex64, torch.cdouble], device_type='cuda'),
                DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
-                            dtypes=[torch.chalf, torch.complex64, torch.cdouble])),
+                            dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')),
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16),
            supports_forward_ad=True,
@@ -22579,10 +22579,10 @@ python_ref_db = [
                # Reference: https://github.com/pytorch/pytorch/issues/56012
                DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
                             'test_reference_numerics_extremal',
-                            dtypes=[torch.complex64, torch.cdouble]),
+                            dtypes=[torch.complex64, torch.cdouble], device_type='cuda'),
                DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
                             'test_reference_numerics_large',
-                            dtypes=[torch.chalf, torch.complex64, torch.cdouble])
+                            dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')
            ),
        ),
        ElementwiseUnaryPythonRefInfo(