Revert "Fix numerical instability for norm (#129352)"

This reverts commit 66340e67515cd3592bda6bdd9bfe2ffa22fe7413.

Reverted https://github.com/pytorch/pytorch/pull/129352 on behalf of https://github.com/atalman due to Breaks Internal CI ([comment](https://github.com/pytorch/pytorch/pull/129352#issuecomment-2379989485))
Author: PyTorch MergeBot
Date:   2024-09-27 20:18:47 +00:00
Parent: d55eef5c59
Commit: f21b471978

2 changed files with 14 additions and 41 deletions
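For background: the change being reverted (#129352, following up on #123416) made the CPU last-dim L2-norm kernel accumulate in fixed-size groups and combine the per-group partial sums in double, because a single low-precision accumulator stops growing once each new addend falls below its ULP. Below is a minimal standalone demonstration of that failure mode, using the element count from the test removed in this revert; the demo and its names are illustrative, not code from either PR.

#include <cmath>
#include <cstdio>

int main() {
  // Element count of the all-ones tensor in the removed test:
  // 1 * 32 * 224 * 224 * 160 = 256,901,120.
  const long long n = 1LL * 32 * 224 * 224 * 160;

  // Naive single-accumulator sum of squares in float: once the running sum
  // reaches 2^24 = 16777216, adding 1.0f falls below the accumulator's ULP
  // and is rounded away, so the sum stops growing.
  float acc = 0.0f;
  for (long long i = 0; i < n; ++i) {
    acc += 1.0f * 1.0f;
  }
  std::printf("saturated sum: %.0f\n", acc);                  // 16777216
  std::printf("naive norm:    %.1f\n", std::sqrt(acc));       // 4096.0
  std::printf("true norm:     %.1f\n", std::sqrt((double)n)); // ~16028.1
  return 0;
}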

@@ -239,38 +239,22 @@ static void norm_kernel_tensor_iterator_impl(
         using Vec = Vectorized<scalar_t>;
         using fVec = Vectorized<acc_t>;
+        fVec acc_vec{acc_t(0)};
         acc_t buffer[fVec::size()];
-        auto inner_reduction = [&buffer](scalar_t* inner_self_data, int64_t inner_size) -> acc_t {
-          fVec acc_vec{acc_t(0)};
-          int64_t d = 0;
-          for (; d < inner_size - (inner_size % Vec::size()); d += Vec::size()) {
-            Vec data_vec = Vec::loadu(inner_self_data + d);
-            norm_two_reduce_step(acc_vec, data_vec);
-          }
-          acc_vec.store(buffer);
-          for (int j = 1; j < fVec::size(); j++) {
-            buffer[0] = buffer[0] + buffer[j];
-          }
-          for (; d < inner_size; d++) {
-            acc_t data_val = acc_t(inner_self_data[d]);
-            buffer[0] += data_val * data_val;
-          }
-          return buffer[0];
-        };
-        // Use group reduction to avoid overflow.
-        // See https://github.com/pytorch/pytorch/pull/123416
-        int64_t group_size = 32768L;
-        int64_t group_n = (size + group_size - 1) / group_size;
-        scalar_t* inner_self_data = self_data;
-        int64_t inner_size = group_size;
-        double result = 0;
-        for (int64_t g = 0; g < group_n; g++) {
-          inner_size = (g * inner_size + group_size) > size ? (size - g * inner_size) : group_size;
-          result += inner_reduction(inner_self_data, inner_size);
-          inner_self_data += inner_size;
+        int64_t d = 0;
+        for (; d < size - (size % Vec::size()); d += Vec::size()) {
+          Vec data_vec = Vec::loadu(self_data + d);
+          norm_two_reduce_step(acc_vec, data_vec);
         }
-        result_data[0] = scalar_t(std::sqrt(result));
+        acc_vec.store(buffer);
+        for (int j = 1; j < fVec::size(); j++) {
+          buffer[0] = buffer[0] + buffer[j];
+        }
+        for (; d < size; d++) {
+          acc_t data_val = acc_t(self_data[d]);
+          buffer[0] += data_val * data_val;
+        }
+        result_data[0] = scalar_t(std::sqrt(buffer[0]));
       });
     });
   } else {
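The block removed above is that group reduction. Stripped of the Vectorized machinery, the technique amounts to the following sketch (scalar code with illustrative names, assuming float input and the 32768-element group size the removed code used):

#include <algorithm>
#include <cmath>
#include <cstddef>

// Sketch of the grouped reduction being reverted; not the actual kernel.
// Each group's sum of squares is small enough for a float partial, and the
// running total across groups is kept in double, so no low-precision
// accumulator ever holds the full sum.
double grouped_sum_of_squares(const float* data, std::size_t size) {
  const std::size_t group_size = 32768;
  double total = 0.0;
  for (std::size_t start = 0; start < size; start += group_size) {
    const std::size_t end = std::min(start + group_size, size);
    float partial = 0.0f;
    for (std::size_t i = start; i < end; ++i) {
      partial += data[i] * data[i];
    }
    total += partial;  // promote per group, not per element
  }
  return total;
}

float l2_norm_grouped(const float* data, std::size_t size) {
  return static_cast<float>(std::sqrt(grouped_sum_of_squares(data, size)));
}

The revert restores the single-pass version, in which one vector accumulator (plus the scalar tail loop) sees the entire row; the test removed below exercised exactly the regime where that accumulator loses precision.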

@@ -486,17 +486,6 @@ class TestReductions(TestCase):
             y2 = op(x2, dim=-1)
             self.assertEqual(y, y2)
 
-    @onlyCPU
-    @dtypes(torch.float, torch.bfloat16)
-    def test_reduction_lastdim_overflow(self, device, dtype):
-        x1 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=torch.double)
-        x2 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=dtype)
-        ops = [torch.norm, torch.linalg.vector_norm]
-        for op in ops:
-            y1 = op(x1)
-            y2 = op(x2)
-            self.assertEqual(y1.to(dtype), y2)
-
     @skipIfNoSciPy
     @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
     def test_logsumexp(self, device, dtype):
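For reference, the removed test's input has 32 · 224 · 224 · 160 = 256,901,120 elements, so the correct L2 norm of the all-ones tensor is √256901120 ≈ 16028.1. The test compared the float and bfloat16 results of torch.norm and torch.linalg.vector_norm against a double reference, which is precisely the large-last-dim case the grouped accumulation was added to protect.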