fix sigmoid for torch.complex datatypes on CPU (#140391)
Fix https://github.com/pytorch/pytorch/issues/135777. The issue is caused by the lack of special handling, in the vectorized implementation of `reciprocal`, of the case where the real or imaginary part is 0/Inf/NaN. For correctness, I temporarily fall back the implementation of `reciprocal` to the scalar implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140391
Approved by: https://github.com/mingfeima, https://github.com/Skylion007
ghstack dependencies: #140358
Committed by: PyTorch MergeBot
Parent: 507bf65c6a
Commit: c922ccb7c4
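For context (an illustrative sketch, not part of the commit): the old vectorized path computes the reciprocal as conj(z) / |z|^2. When a lane holds, e.g., z = Inf + 0i, |z|^2 is Inf and the resulting Inf/Inf division produces NaN, whereas plain complex division can recover a finite result. A minimal standalone C++ sketch of that failure mode, using std::complex rather than the c10::complex/Vectorized types:

    #include <complex>
    #include <iostream>
    #include <limits>

    int main() {
      const std::complex<double> z(std::numeric_limits<double>::infinity(), 0.0);
      // Naive reciprocal, as in the old vectorized path:
      //   re = c / (c*c + d*d), im = -d / (c*c + d*d)
      const double abs2 = z.real() * z.real() + z.imag() * z.imag();       // Inf
      const std::complex<double> naive(z.real() / abs2, -z.imag() / abs2); // Inf/Inf -> NaN real part
      // Plain scalar complex division, as in the new fallback; this typically
      // recovers a finite value such as (0, -0), though the exact result is
      // compiler/library dependent.
      const std::complex<double> scalar = std::complex<double>(1.0) / z;
      std::cout << "naive:  " << naive << "\n";   // e.g. (nan,-0)
      std::cout << "scalar: " << scalar << "\n";  // e.g. (0,-0)
      return 0;
    }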
@@ -363,12 +363,19 @@ template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c

 // reciprocal. Implement this here so we can use multiplication.
 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm256_xor_pd(sign_mask, values); //c -d
-  return _mm256_div_pd(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m256d sign_mask = _mm256_setr_pd(0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm256_xor_pd(sign_mask, values); //c -d
+  // return _mm256_div_pd(c_d, abs_2_());
+  __at_align__ c10::complex<double> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<double>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }

 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
@@ -398,12 +398,19 @@ template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<c1

 // reciprocal. Implement this here so we can use multiplication.
 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm256_xor_ps(sign_mask, values); //c -d
-  return _mm256_div_ps(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m256 sign_mask = _mm256_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm256_xor_ps(sign_mask, values); //c -d
+  // return _mm256_div_ps(c_d, abs_2_());
+  __at_align__ c10::complex<float> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<float>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }

 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {
@@ -433,12 +433,19 @@ template <> Vectorized<c10::complex<double>> inline operator/(const Vectorized<c

 // reciprocal. Implement this here so we can use multiplication.
 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::reciprocal() const{
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm512_xor_pd(sign_mask, values); //c -d
-  return _mm512_div_pd(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m512d sign_mask = _mm512_setr_pd(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm512_xor_pd(sign_mask, values); //c -d
+  // return _mm512_div_pd(c_d, abs_2_());
+  __at_align__ c10::complex<double> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<double>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }

 inline Vectorized<c10::complex<double>> Vectorized<c10::complex<double>>::atan() const {
@@ -936,13 +936,20 @@ template <> Vectorized<c10::complex<float>> inline operator/(const Vectorized<c1

 // reciprocal. Implement this here so we can use multiplication.
 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::reciprocal() const {
-  //re + im*i = (a + bi) / (c + di)
-  //re = (ac + bd)/abs_2() = c/abs_2()
-  //im = (bc - ad)/abs_2() = d/abs_2()
-  const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
-                                          0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
-  auto c_d = _mm512_xor_ps(sign_mask, values); //c -d
-  return _mm512_div_ps(c_d, abs_2_());
+  // TODO: The vectorized implementation requires special handling for the case where real number/imag number is 0/Inf/NaN.
+  // //re + im*i = (a + bi) / (c + di)
+  // //re = (ac + bd)/abs_2() = c/abs_2()
+  // //im = (bc - ad)/abs_2() = d/abs_2()
+  // const __m512 sign_mask = _mm512_setr_ps(0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
+  //                                         0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0);
+  // auto c_d = _mm512_xor_ps(sign_mask, values); //c -d
+  // return _mm512_div_ps(c_d, abs_2_());
+  __at_align__ c10::complex<float> tmp[size()];
+  store(tmp);
+  for (const auto i : c10::irange(size())) {
+    tmp[i] = c10::complex<float>(1) / tmp[i];
+  }
+  return loadu(tmp);
 }

 inline Vectorized<c10::complex<float>> Vectorized<c10::complex<float>>::atan() const {
@@ -58,10 +58,7 @@ static void sigmoid_kernel(TensorIteratorBase& iter) {
           return (static_cast<scalar_t>(1) / (static_cast<scalar_t>(1) + std::exp((-a))));
         },
         [=](Vectorized<scalar_t> a) {
-          a = Vectorized<scalar_t>(static_cast<scalar_t>(0)) - a;
-          a = a.exp();
-          a = Vectorized<scalar_t>(static_cast<scalar_t>(1)) + a;
-          a = a.reciprocal();
+          a = (Vectorized<scalar_t>(static_cast<scalar_t>(1)) + a.neg().exp()).reciprocal();
           return a;
         });
   });
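The removed multi-statement vector path and the new one-liner compute the same expression, sigmoid(a) = 1 / (1 + exp(-a)), via Vectorized::reciprocal(), which the hunks above make Inf/NaN-safe for complex types; this hunk only condenses the vector lambda. A small scalar reference sketch (illustrative only, using std::complex rather than the scalar_t/Vectorized types; sigmoid_ref is a hypothetical helper name):

    #include <complex>
    #include <iostream>

    // Scalar reference for the expression used by both lambdas in the kernel:
    // sigmoid(a) = 1 / (1 + exp(-a)).
    template <typename T>
    std::complex<T> sigmoid_ref(const std::complex<T>& a) {
      return std::complex<T>(1) / (std::complex<T>(1) + std::exp(-a));
    }

    int main() {
      std::cout << sigmoid_ref(std::complex<double>(0.5, -0.25)) << "\n";
      return 0;
    }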
@@ -20081,9 +20081,9 @@ op_db: List[OpInfo] = [
            skips=(
                # Reference: https://github.com/pytorch/pytorch/issues/56012
                DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_extremal',
-                            dtypes=[torch.complex64, torch.cdouble]),
+                            dtypes=[torch.complex64, torch.cdouble], device_type='cuda'),
                DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs', 'test_reference_numerics_large',
-                            dtypes=[torch.chalf, torch.complex64, torch.cdouble])),
+                            dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')),
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            dtypesIfCUDA=all_types_and_complex_and(torch.complex32, torch.bool, torch.half, torch.bfloat16),
            supports_forward_ad=True,
@@ -22579,10 +22579,10 @@ python_ref_db = [
                # Reference: https://github.com/pytorch/pytorch/issues/56012
                DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
                             'test_reference_numerics_extremal',
-                            dtypes=[torch.complex64, torch.cdouble]),
+                            dtypes=[torch.complex64, torch.cdouble], device_type='cuda'),
                DecorateInfo(unittest.skip("Skipped!"), 'TestUnaryUfuncs',
                             'test_reference_numerics_large',
-                            dtypes=[torch.chalf, torch.complex64, torch.cdouble])
+                            dtypes=[torch.chalf, torch.complex64, torch.cdouble], device_type='cuda')
                ),
            ),
        ElementwiseUnaryPythonRefInfo(