Fix numerical instability for norm (#129352)

Fixes #123645
When the reduce size is large, accumulating directly in FP32 may exceed the range that FP32 can accurately represent, producing incorrect results.
Reducing in groups and using double as the intermediate accumulation type avoids exceeding that range.
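A minimal standalone sketch (not the PyTorch kernel) of both the failure mode and the grouped fix; the element count mirrors the tensor shape used in the new test, and all names are illustrative:

// Naive float accumulation stalls once the running sum reaches 2^24,
// because adding 1.0f to 16777216.0f rounds back to 16777216.0f.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t size = 256901120;  // numel of a (1, 32, 224, 224, 160) tensor

  // Naive: one float accumulator over the whole reduction.
  float naive = 0.0f;
  for (int64_t i = 0; i < size; i++) {
    naive += 1.0f * 1.0f;  // sum of squares of ones
  }

  // Grouped: reduce fixed-size chunks in float, accumulate partials in double.
  const int64_t group_size = 32768;
  double result = 0.0;
  for (int64_t begin = 0; begin < size; begin += group_size) {
    const int64_t end = std::min(begin + group_size, size);
    float partial = 0.0f;
    for (int64_t i = begin; i < end; i++) {
      partial += 1.0f * 1.0f;  // each chunk sum stays far below 2^24
    }
    result += partial;
  }

  std::printf("naive   norm: %.2f\n", std::sqrt((double)naive));  // 4096.00
  std::printf("grouped norm: %.2f\n", std::sqrt(result));         // ~16028.14
  return 0;
}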

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129352
Approved by: https://github.com/jgong5, https://github.com/peterbell10
Author: CaoE
Date: 2024-09-27 00:51:30 +00:00
Committed by: PyTorch MergeBot
Parent: adc77a9b7f
Commit: 66340e6751

2 changed files with 41 additions and 14 deletions

aten/src/ATen/native/cpu/ReduceOpsKernel.cpp

@@ -239,22 +239,38 @@ static void norm_kernel_tensor_iterator_impl(
         using Vec = Vectorized<scalar_t>;
         using fVec = Vectorized<acc_t>;
-        fVec acc_vec{acc_t(0)};
         acc_t buffer[fVec::size()];
-        int64_t d = 0;
-        for (; d < size - (size % Vec::size()); d += Vec::size()) {
-          Vec data_vec = Vec::loadu(self_data + d);
-          norm_two_reduce_step(acc_vec, data_vec);
+        auto inner_reduction = [&buffer](scalar_t* inner_self_data, int64_t inner_size) -> acc_t {
+          fVec acc_vec{acc_t(0)};
+          int64_t d = 0;
+          for (; d < inner_size - (inner_size % Vec::size()); d += Vec::size()) {
+            Vec data_vec = Vec::loadu(inner_self_data + d);
+            norm_two_reduce_step(acc_vec, data_vec);
+          }
+          acc_vec.store(buffer);
+          for (int j = 1; j < fVec::size(); j++) {
+            buffer[0] = buffer[0] + buffer[j];
+          }
+          for (; d < inner_size; d++) {
+            acc_t data_val = acc_t(inner_self_data[d]);
+            buffer[0] += data_val * data_val;
+          }
+          return buffer[0];
+        };
+        // Use group reduction to avoid overflow.
+        // See https://github.com/pytorch/pytorch/pull/123416
+        int64_t group_size = 32768L;
+        int64_t group_n = (size + group_size - 1) / group_size;
+        scalar_t* inner_self_data = self_data;
+        int64_t inner_size = group_size;
+        double result = 0;
+        for (int64_t g = 0; g < group_n; g++) {
+          inner_size = (g * inner_size + group_size) > size ? (size - g * inner_size) : group_size;
+          result += inner_reduction(inner_self_data, inner_size);
+          inner_self_data += inner_size;
         }
-        acc_vec.store(buffer);
-        for (int j = 1; j < fVec::size(); j++) {
-          buffer[0] = buffer[0] + buffer[j];
-        }
-        for (; d < size; d++) {
-          acc_t data_val = acc_t(self_data[d]);
-          buffer[0] += data_val * data_val;
-        }
-        result_data[0] = scalar_t(std::sqrt(buffer[0]));
+        result_data[0] = scalar_t(std::sqrt(result));
       });
     });
   } else {
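The boundary arithmetic in the new group loop is easy to misread: g * inner_size only computes the group's starting offset because inner_size still equals group_size until the final iteration. A standalone sketch of the partitioning, with an assumed reduction length of 100000:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t size = 100000;       // assumed reduction length, for illustration
  const int64_t group_size = 32768;
  // Ceiling division: number of groups needed to cover `size` elements.
  const int64_t group_n = (size + group_size - 1) / group_size;

  int64_t offset = 0;
  int64_t inner_size = group_size;
  for (int64_t g = 0; g < group_n; g++) {
    // inner_size shrinks only on the last group, so g * inner_size is the
    // start offset of group g for every iteration that reaches this line.
    inner_size = (g * inner_size + group_size) > size ? (size - g * inner_size)
                                                      : group_size;
    std::printf("group %lld covers [%lld, %lld)\n", (long long)g,
                (long long)offset, (long long)(offset + inner_size));
    offset += inner_size;
  }
  // Prints three full groups of 32768 elements and a final group of 1696.
  return 0;
}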

test/test_reductions.py

@@ -486,6 +486,17 @@ class TestReductions(TestCase):
             y2 = op(x2, dim=-1)
             self.assertEqual(y, y2)
 
+    @onlyCPU
+    @dtypes(torch.float, torch.bfloat16)
+    def test_reduction_lastdim_overflow(self, device, dtype):
+        x1 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=torch.double)
+        x2 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=dtype)
+        ops = [torch.norm, torch.linalg.vector_norm]
+        for op in ops:
+            y1 = op(x1)
+            y2 = op(x2)
+            self.assertEqual(y1.to(dtype), y2)
+
     @skipIfNoSciPy
     @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
     def test_logsumexp(self, device, dtype):