https://github.com/pytorch/pytorch/pull/164790 modifies ATen to perform a different intra-warp reduction order. However, this change exposed a large discrepancy in `sum` for complex32, namely in this case:

```
import torch

a = torch.tensor([[ 4.82031250+7.34765625j,   -3.37109375-1.9501953125j],
                  [ 3.7832031250-2.43359375j, -6.07812500+5.32812500j]],
                 dtype=torch.complex32, device='cuda:0')

sum_out = torch.sum(a)
nansum_out = torch.nansum(a)

torch.testing.assert_close(
    sum_out,
    nansum_out,
    rtol=0,
    atol=0,
)
```

Here, the results of `sum` and `nansum` differed significantly, by about 1e-2. Further investigation showed that the root cause was the explicit cast of `b` from `scalar_t` back to `arg_t`. `arg_t` is the accumulator dtype (ComplexFloat) and `scalar_t` is the input dtype (ComplexHalf). Casting inside the reduction this way means the intermediate value is still held as ComplexHalf, which loses precision because ComplexHalf cannot represent the intermediate values exactly (a standalone sketch of this mechanism follows the PR links below).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165494
Approved by: https://github.com/ngimel
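
For intuition, here is a minimal sketch of the precision-loss mechanism, not the ATen kernel: it sums the repro values twice, once keeping the running value at accumulator precision and once rounding it back through ComplexHalf (two float16 components) after every step, which is effectively what the extra `scalar_t` cast did. The `vals`/`to_complex_half` names and the left-to-right loop order are illustrative assumptions (the real kernel does a warp-level tree reduction, so the exact error differs), and Python's double-precision complex stands in for the ComplexFloat accumulator here.

```
import torch

# Same values as the repro above, flattened.
vals = [4.82031250+7.34765625j, -3.37109375-1.9501953125j,
        3.7832031250-2.43359375j, -6.07812500+5.32812500j]

def to_complex_half(z):
    # ComplexHalf stores real/imag as float16, so round each part through half.
    return complex(torch.tensor(z.real).half().item(),
                   torch.tensor(z.imag).half().item())

# Keep the running value at accumulator (arg_t ~ ComplexFloat) precision.
acc_keep = 0j
for z in vals:
    acc_keep += z

# Round the running value back to ComplexHalf (scalar_t) after every step,
# emulating the extra cast inside the reduction.
acc_cast = 0j
for z in vals:
    acc_cast = to_complex_half(acc_cast + z)

print(acc_keep, acc_cast, abs(acc_keep - acc_cast))  # differs by a few 1e-3
```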