mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"
This reverts commit 36371b8ec7a1baed255c18451b2c716386a54c95. Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/clee2000 because it caused internal test failures after merge (D84992607) ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3420373755))
This commit is contained in:
@ -653,14 +653,8 @@ struct ReduceOp {
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
// Intra-warp reduction. On CUDA, iterate with a decreasing shuffle offset for better numerics,
|
||||
// matching Triton, etc.
|
||||
// todo for AMD
|
||||
#ifdef USE_ROCM
|
||||
|
||||
for (int offset = 1; offset < dim_x; offset <<= 1) {
|
||||
#else
|
||||
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
arg_t other = ops.warp_shfl_down(value[i], offset);
|
||||
|
@ -466,11 +466,7 @@ struct ReduceJitOp {
|
||||
|
||||
__syncthreads();
|
||||
|
||||
#ifdef USE_ROCM
|
||||
for (int offset = 1; offset < dim_x; offset <<= 1) {
|
||||
#else
|
||||
for (int offset = dim_x >> 1; offset > 0; offset >>= 1) {
|
||||
#endif
|
||||
#pragma unroll
|
||||
for (int i = 0; i < output_vec_size; i++) {
|
||||
arg_t other = reducer::warp_shfl_down(value[i], offset);
|
||||
|
@ -220,8 +220,6 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
|
||||
(torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
|
||||
(torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
|
||||
(torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
|
||||
(torch.float16, torch.ops.aten._batch_norm_with_update.default): 2e-7,
|
||||
(torch.bfloat16, torch.ops.aten._batch_norm_with_update.default): 2e-7,
|
||||
# see https://github.com/pytorch/pytorch/pull/96264
|
||||
(torch.float16, torch.ops.aten.mv.default): 1e-5,
|
||||
(torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
|
||||
@ -297,7 +295,6 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
|
||||
rtol, atol = tol_table[(decomp.dtype, op)]
|
||||
else:
|
||||
rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
|
||||
|
||||
test_case.assertEqual(
|
||||
orig,
|
||||
decomp,
|
||||
|
Reference in New Issue
Block a user