mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Revert "[ATen] Fix CUDA reduction warp shuffle order (#164790)"
This reverts commit 36371b8ec7a1baed255c18451b2c716386a54c95. Reverted https://github.com/pytorch/pytorch/pull/164790 on behalf of https://github.com/clee2000 because it caused internal tests to fail after merge (see D84992607) ([comment](https://github.com/pytorch/pytorch/pull/164790#issuecomment-3420373755))
This commit is contained in:
		| @ -653,14 +653,8 @@ struct ReduceOp { | ||||
|     } | ||||
|  | ||||
|     __syncthreads(); | ||||
|     // Intra-warp reduction, fix CUDA to have offset decreasing for better numerics | ||||
|     // matching Triton, etc. | ||||
|     // todo for AMD | ||||
|     #ifdef USE_ROCM | ||||
|  | ||||
|     for (int offset = 1; offset < dim_x; offset <<= 1) { | ||||
|     #else | ||||
|     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { | ||||
|     #endif | ||||
|       #pragma unroll | ||||
|       for (int i = 0; i < output_vec_size; i++) { | ||||
|         arg_t other = ops.warp_shfl_down(value[i], offset); | ||||
|  | ||||
| @ -466,11 +466,7 @@ struct ReduceJitOp { | ||||
|  | ||||
|     __syncthreads(); | ||||
|  | ||||
|     #ifdef USE_ROCM | ||||
|     for (int offset = 1; offset < dim_x; offset <<= 1) { | ||||
|     #else | ||||
|     for (int offset = dim_x >> 1; offset > 0; offset >>= 1) { | ||||
|     #endif | ||||
|       #pragma unroll | ||||
|       for (int i = 0; i < output_vec_size; i++) { | ||||
|         arg_t other = reducer::warp_shfl_down(value[i], offset); | ||||
|  | ||||
| @ -220,8 +220,6 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs) | ||||
|         (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3, | ||||
|         (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3, | ||||
|         (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2, | ||||
|         (torch.float16, torch.ops.aten._batch_norm_with_update.default): 2e-7, | ||||
|         (torch.bfloat16, torch.ops.aten._batch_norm_with_update.default): 2e-7, | ||||
|         # see https://github.com/pytorch/pytorch/pull/96264 | ||||
|         (torch.float16, torch.ops.aten.mv.default): 1e-5, | ||||
|         (torch.bfloat16, torch.ops.aten.mv.default): 1e-5, | ||||
| @ -297,7 +295,6 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs): | ||||
|         rtol, atol = tol_table[(decomp.dtype, op)] | ||||
|     else: | ||||
|         rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype) | ||||
|  | ||||
|     test_case.assertEqual( | ||||
|         orig, | ||||
|         decomp, | ||||
|  | ||||
		Reference in New Issue
	
	Block a user