Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"

This reverts commit 36371b8ec7a1baed255c18451b2c716386a54c95.

Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/clee2000 because it was failing internal tests after merge (D84992607) ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3420373755))
This commit is contained in:
PyTorch MergeBot
2025-10-20 03:06:52 +00:00
parent 47804ce467
commit 602ace5eb4
3 changed files with 1 addition and 14 deletions

View File

@ -653,14 +653,8 @@ struct ReduceOp {
}
__syncthreads();
// Intra-warp reduction, fix CUDA to have offset decreasing for better numerics
// matching Triton, etc.
// todo for AMD
#ifdef USE_ROCM
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = ops.warp_shfl_down(value[i], offset);

View File

@ -466,11 +466,7 @@ struct ReduceJitOp {
__syncthreads();
#ifdef USE_ROCM
for (int offset = 1; offset < dim_x; offset <<= 1) {
#else
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
#endif
#pragma unroll
for (int i = 0; i < output_vec_size; i++) {
arg_t other = reducer::warp_shfl_down(value[i], offset);

View File

@ -220,8 +220,6 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
(torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
(torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
(torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
(torch.float16, torch.ops.aten._batch_norm_with_update.default): 2e-7,
(torch.bfloat16, torch.ops.aten._batch_norm_with_update.default): 2e-7,
# see https://github.com/pytorch/pytorch/pull/96264
(torch.float16, torch.ops.aten.mv.default): 1e-5,
(torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
@ -297,7 +295,6 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
rtol, atol = tol_table[(decomp.dtype, op)]
else:
rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
test_case.assertEqual(
orig,
decomp,