Fix numerical instability for norm (#129352)

Fixes #123645
When the reduction size is large, reducing directly may exceed the range that FP32 can represent, resulting in incorrect results.
Reducing in groups and using double as the intermediate accumulation type avoids exceeding the representable range.
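As a rough illustration only (not the actual ATen kernel touched by this PR), the sketch below contrasts a plain double-precision reference with a grouped reduction that combines per-group partial sums in double. The tensor shape matches the new test; grouped_vector_norm and its group_size parameter are made-up names for this sketch.

    import torch

    # Tensor from the new test: 1 * 32 * 224 * 224 * 160 ~= 2.57e8 elements
    # (needs a few GB of RAM). With a single FP32 running accumulator, the
    # +1.0 increments stop registering once the partial sum of squares grows
    # past 2**24, so the resulting norm comes out too small.
    x = torch.ones((1, 32, 224, 224, 160), dtype=torch.float)

    # Reference computed entirely in double precision.
    ref = torch.linalg.vector_norm(x.double())  # ~16028.14

    # Sketch of the grouped strategy: reduce each group separately (each
    # partial sum of squares stays small enough to be exact), then combine
    # the partial sums using double as the intermediate accumulation type.
    def grouped_vector_norm(t, group_size=1 << 20):
        flat = t.reshape(-1)
        partials = [
            (flat[i:i + group_size] ** 2).sum().double()
            for i in range(0, flat.numel(), group_size)
        ]
        return torch.stack(partials).sum().sqrt()

    print(ref, grouped_vector_norm(x))  # both close to 16028.14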

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129352
Approved by: https://github.com/jgong5, https://github.com/peterbell10
CaoE
2024-09-27 00:51:30 +00:00
committed by PyTorch MergeBot
parent adc77a9b7f
commit 66340e6751
2 changed files with 41 additions and 14 deletions


@@ -486,6 +486,17 @@ class TestReductions(TestCase):
        y2 = op(x2, dim=-1)
        self.assertEqual(y, y2)

    @onlyCPU
    @dtypes(torch.float, torch.bfloat16)
    def test_reduction_lastdim_overflow(self, device, dtype):
        x1 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=torch.double)
        x2 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=dtype)
        ops = [torch.norm, torch.linalg.vector_norm]
        for op in ops:
            y1 = op(x1)
            y2 = op(x2)
            self.assertEqual(y1.to(dtype), y2)

    @skipIfNoSciPy
    @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
    def test_logsumexp(self, device, dtype):