Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Revert "Fix numerical instability for norm (#129352)"
This reverts commit 66340e67515cd3592bda6bdd9bfe2ffa22fe7413. Reverted https://github.com/pytorch/pytorch/pull/129352 on behalf of https://github.com/atalman due to Breaks Internal CI ([comment](https://github.com/pytorch/pytorch/pull/129352#issuecomment-2379989485))
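For context, the reverted PR targeted precision loss when an L2 norm accumulates a very large sum of squares in single precision: once the running sum grows large enough, further small contributions are rounded away. A minimal illustration of that absorption effect in Python (the numbers below are illustrative, not taken from the PR):

```python
import numpy as np

# float32 has a 24-bit significand, so once an accumulator reaches 2**24,
# adding 1.0 no longer changes it: later terms of the reduction are lost.
acc = np.float32(2**24)
print(acc + np.float32(1.0) == acc)              # True -- the contribution is absorbed

# A double accumulator still registers the same contribution.
print(np.float64(acc) + 1.0 == np.float64(acc))  # False
```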
@@ -239,38 +239,22 @@ static void norm_kernel_tensor_iterator_impl(
 
         using Vec = Vectorized<scalar_t>;
         using fVec = Vectorized<acc_t>;
+        fVec acc_vec{acc_t(0)};
         acc_t buffer[fVec::size()];
-        auto inner_reduction = [&buffer](scalar_t* inner_self_data, int64_t inner_size) -> acc_t {
-          fVec acc_vec{acc_t(0)};
-          int64_t d = 0;
-          for (; d < inner_size - (inner_size % Vec::size()); d += Vec::size()) {
-            Vec data_vec = Vec::loadu(inner_self_data + d);
-            norm_two_reduce_step(acc_vec, data_vec);
-          }
-          acc_vec.store(buffer);
-          for (int j = 1; j < fVec::size(); j++) {
-            buffer[0] = buffer[0] + buffer[j];
-          }
-          for (; d < inner_size; d++) {
-            acc_t data_val = acc_t(inner_self_data[d]);
-            buffer[0] += data_val * data_val;
-          }
-          return buffer[0];
-        };
-
-        // Use group reduction to avoid overflow.
-        // See https://github.com/pytorch/pytorch/pull/123416
-        int64_t group_size = 32768L;
-        int64_t group_n = (size + group_size - 1) / group_size;
-        scalar_t* inner_self_data = self_data;
-        int64_t inner_size = group_size;
-        double result = 0;
-        for (int64_t g = 0; g < group_n; g++) {
-          inner_size = (g * inner_size + group_size) > size ? (size - g * inner_size) : group_size;
-          result += inner_reduction(inner_self_data, inner_size);
-          inner_self_data += inner_size;
+        int64_t d = 0;
+        for (; d < size - (size % Vec::size()); d += Vec::size()) {
+          Vec data_vec = Vec::loadu(self_data + d);
+          norm_two_reduce_step(acc_vec, data_vec);
         }
-        result_data[0] = scalar_t(std::sqrt(result));
+        acc_vec.store(buffer);
+        for (int j = 1; j < fVec::size(); j++) {
+          buffer[0] = buffer[0] + buffer[j];
+        }
+        for (; d < size; d++) {
+          acc_t data_val = acc_t(self_data[d]);
+          buffer[0] += data_val * data_val;
+        }
+        result_data[0] = scalar_t(std::sqrt(buffer[0]));
       });
     });
   } else {
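The lines removed above implemented a grouped reduction: the last dimension was processed in blocks of 32768 elements, each block reduced with its own freshly zeroed vectorized accumulator, and the per-block partial sums were combined in a double (`result`). A rough Python sketch of that structure, assuming a flat 1-D float32 input; the function name `grouped_sum_of_squares` and the NumPy calls are mine, not PyTorch's:

```python
import numpy as np

def grouped_sum_of_squares(x, group_size=32768):
    """Sum of squares computed per group, with group results combined in double.

    Mirrors the shape of the removed kernel code: each block of `group_size`
    elements is reduced in float32, and the per-block partial sums are
    accumulated in a Python float (i.e. double), like the removed `result`.
    """
    x = np.asarray(x, dtype=np.float32)
    total = 0.0  # double accumulator across groups
    for start in range(0, x.size, group_size):
        block = x[start:start + group_size]
        total += float(np.dot(block, block))  # per-group partial sum in float32
    return total

# Usage sketch: the 2-norm of a large all-ones vector stays close to sqrt(n).
x = np.ones(50_000_000, dtype=np.float32)
print(np.sqrt(grouped_sum_of_squares(x)))  # ~7071.07
```

The code reinstated by this revert instead keeps a single accumulator of the accumulation type live across the entire last dimension, which is the behavior the reverted fix had changed.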
@@ -486,17 +486,6 @@ class TestReductions(TestCase):
             y2 = op(x2, dim=-1)
             self.assertEqual(y, y2)
 
-    @onlyCPU
-    @dtypes(torch.float, torch.bfloat16)
-    def test_reduction_lastdim_overflow(self, device, dtype):
-        x1 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=torch.double)
-        x2 = torch.ones((1, 32, 224, 224, 160), device=device, dtype=dtype)
-        ops = [torch.norm, torch.linalg.vector_norm]
-        for op in ops:
-            y1 = op(x1)
-            y2 = op(x2)
-            self.assertEqual(y1.to(dtype), y2)
-
     @skipIfNoSciPy
     @dtypes(torch.float32, torch.double, torch.complex64, torch.complex128)
     def test_logsumexp(self, device, dtype):