mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix numerical instability for norm (#129352)
Fixes #123645 When the reduce size is large, reducing directly may exceed the range that FP32 can represent, resulting in incorrect results. Reducing in group and using double as the intermediate cumulative type can avoid exceeding the representation range. Pull Request resolved: https://github.com/pytorch/pytorch/pull/129352 Approved by: https://github.com/jgong5, https://github.com/peterbell10
This commit is contained in:
@ -239,22 +239,38 @@ static void norm_kernel_tensor_iterator_impl(
|
||||
|
||||
using Vec = Vectorized<scalar_t>;
|
||||
using fVec = Vectorized<acc_t>;
|
||||
fVec acc_vec{acc_t(0)};
|
||||
acc_t buffer[fVec::size()];
|
||||
int64_t d = 0;
|
||||
for (; d < size - (size % Vec::size()); d += Vec::size()) {
|
||||
Vec data_vec = Vec::loadu(self_data + d);
|
||||
norm_two_reduce_step(acc_vec, data_vec);
|
||||
auto inner_reduction = [&buffer](scalar_t* inner_self_data, int64_t inner_size) -> acc_t {
|
||||
fVec acc_vec{acc_t(0)};
|
||||
int64_t d = 0;
|
||||
for (; d < inner_size - (inner_size % Vec::size()); d += Vec::size()) {
|
||||
Vec data_vec = Vec::loadu(inner_self_data + d);
|
||||
norm_two_reduce_step(acc_vec, data_vec);
|
||||
}
|
||||
acc_vec.store(buffer);
|
||||
for (int j = 1; j < fVec::size(); j++) {
|
||||
buffer[0] = buffer[0] + buffer[j];
|
||||
}
|
||||
for (; d < inner_size; d++) {
|
||||
acc_t data_val = acc_t(inner_self_data[d]);
|
||||
buffer[0] += data_val * data_val;
|
||||
}
|
||||
return buffer[0];
|
||||
};
|
||||
|
||||
// Use group reduction to avoid overflow.
|
||||
// See https://github.com/pytorch/pytorch/pull/123416
|
||||
int64_t group_size = 32768L;
|
||||
int64_t group_n = (size + group_size - 1) / group_size;
|
||||
scalar_t* inner_self_data = self_data;
|
||||
int64_t inner_size = group_size;
|
||||
double result = 0;
|
||||
for (int64_t g = 0; g < group_n; g++) {
|
||||
inner_size = (g * inner_size + group_size) > size ? (size - g * inner_size) : group_size;
|
||||
result += inner_reduction(inner_self_data, inner_size);
|
||||
inner_self_data += inner_size;
|
||||
}
|
||||
acc_vec.store(buffer);
|
||||
for (int j = 1; j < fVec::size(); j++) {
|
||||
buffer[0] = buffer[0] + buffer[j];
|
||||
}
|
||||
for (; d < size; d++) {
|
||||
acc_t data_val = acc_t(self_data[d]);
|
||||
buffer[0] += data_val * data_val;
|
||||
}
|
||||
result_data[0] = scalar_t(std::sqrt(buffer[0]));
|
||||
result_data[0] = scalar_t(std::sqrt(result));
|
||||
});
|
||||
});
|
||||
} else {
|
||||
|
@ -486,6 +486,17 @@ class TestReductions(TestCase):
|
||||
y2 = op(x2, dim=-1)
|
||||
self.assertEqual(y, y2)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float, torch.bfloat16)
|
||||
def test_reduction_lastdim_overflow(self, device, dtype):
|
||||
x1 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=torch.double)
|
||||
x2 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=dtype)
|
||||
ops = [torch.norm, torch.linalg.vector_norm]
|
||||
for op in ops:
|
||||
y1 = op(x1)
|
||||
y2 = op(x2)
|
||||
self.assertEqual(y1.to(dtype), y2)
|
||||
|
||||
@skipIfNoSciPy
|
||||
@dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
|
||||
def test_logsumexp(self, device, dtype):
|
||||
|
Reference in New Issue
Block a user