mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ATen] Remove explicit casting of complex nansum during accumulation (#165494)
https://github.com/pytorch/pytorch/pull/164790 modifies aten to perform a different reduction order intra warp. However, this change exposed a large difference in a sum for complex32. Namely the case: ``` import torch a = torch.tensor([[ 4.82031250+7.34765625j, -3.37109375-1.9501953125j], [ 3.7832031250-2.43359375j, -6.07812500+5.32812500j]], dtype=torch.complex32, device='cuda:0') sum_out = torch.sum(a) nansum_out = torch.nansum(a) torch.testing.assert_close( sum_out, nansum_out, rtol=0, atol=0, ) ``` Here, the result of `sum` and `nansum` differed significantly by 1e-2. Further investigation showed that the explicit casting of b back to `arg_t` from `scalar_t` was the root cause. `arg_t` is the dtype of the accumulator, ComplexFloat, and `scalar_t` of the input dtype, ComplexHalf. When we cast in the reduction to the accumulator order, that means the input is still of ComplexHalf, which loses precision as it can store intermediate values. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165494 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
f58f301313
commit
712f54d453
@ -77,8 +77,8 @@ struct nansum_functor_complex {
|
||||
#if AT_USE_JITERATOR()
|
||||
void operator()(TensorIterator& iter) {
|
||||
std::string func = jiterator_stringify(
|
||||
arg_t combine(arg_t a, scalar_t b) {
|
||||
return a + (std::isnan(b) ? arg_t{0.} : arg_t{b});
|
||||
arg_t combine(arg_t a, arg_t b) {
|
||||
return a + (std::isnan(b) ? arg_t{0.} : b);
|
||||
}
|
||||
);
|
||||
jitted_gpu_reduce_kernel<nansum_name, scalar_t, scalar_t>(
|
||||
|
Reference in New Issue
Block a user