diff --git a/aten/src/ATen/native/cuda/ForeachReduceOp.cu b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
index 04b7c12e9a1a..7c2a389351a2 100644
--- a/aten/src/ATen/native/cuda/ForeachReduceOp.cu
+++ b/aten/src/ATen/native/cuda/ForeachReduceOp.cu
@@ -68,8 +68,8 @@ struct LpMaxFunctor {
     T vals[kILP];
     T r_x[kILP];
     for (int64_t i = 0; i < kILP; i++) {
-      vals[i] = T(-INFINITY);
-      r_x[i] = T(-INFINITY);
+      vals[i] = T(std::numeric_limits<T>::lowest());
+      r_x[i] = T(std::numeric_limits<T>::lowest());
     }
 
     if (n % kILP == 0 && (chunk_size & kILP) == 0 && is_aligned(x)) {
@@ -96,7 +96,7 @@ struct LpMaxFunctor {
       }
     }
 
-    auto val = T(-INFINITY);
+    auto val = T(std::numeric_limits<T>::lowest());
     for (int i = 0; i < kILP; i++) {
       val = max_propagate_nan(val, vals[i]);
     }
@@ -118,7 +118,7 @@ __global__ void lpmax_cleanup(
   __shared__ T vals[512];
   const T* output_this_tensor =
       output_per_tensor + blockIdx.x * max_chunks_per_tensor;
-  T val = -INFINITY;
+  T val = std::numeric_limits<T>::lowest();
   for (size_t i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
     val = max_propagate_nan(val, output_this_tensor[i]);
   }
@@ -130,21 +130,11 @@ __global__ void lpmax_cleanup(
 
 std::vector<Tensor> foreach_tensor_max_cuda(TensorList tensors) {
   check_foreach_api_restrictions(tensors);
-  // we currently use -INF as the identity value to compare against, which
-  // does not work for int8, int16, nor bool. Fall back to slow path here.
-  const bool has_small_int_or_bool =
-      std::any_of(tensors.begin(), tensors.end(), [](const auto& t) {
-        const auto scalar_type = t.scalar_type();
-        return scalar_type == at::ScalarType::Short ||
-            scalar_type == at::ScalarType::Char ||
-            scalar_type == at::ScalarType::Bool;
-      });
-  if (!can_use_fast_route(tensors) || has_small_int_or_bool) {
+  if (!can_use_fast_route(tensors)) {
     return foreach_tensor_max_slow(tensors);
   }
 
-  // for parity with max in ReduceAllOps.cpp, though I think max(empty) should
-  // eventually be allowed.
+  // for parity with max in ReduceAllOps.cpp, as max(empty) is ???
   TORCH_CHECK(
       std::all_of(
           tensors.begin(),
diff --git a/aten/src/ATen/native/cuda/block_reduce.cuh b/aten/src/ATen/native/cuda/block_reduce.cuh
index c1e003ca8e53..df757a11761b 100644
--- a/aten/src/ATen/native/cuda/block_reduce.cuh
+++ b/aten/src/ATen/native/cuda/block_reduce.cuh
@@ -103,7 +103,7 @@ __inline__ __device__ T BlockReduceMax(T val, T* shared) {
     shared[wid] = val;
   }
   __syncthreads();
-  val = (tid < B::Warps()) ? shared[lid] : T(-INFINITY);
+  val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits<T>::lowest());
   if (wid == 0) {
     val = WarpReduceMax(val);
   }
diff --git a/test/test_foreach.py b/test/test_foreach.py
index 58595b628dc7..8465d538187c 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -1015,7 +1015,7 @@ class TestForeach(TestCase):
     def test_foreach_reduce_large_input(self, device, dtype, op):
         # test inputs larger than kChunkSize = 65536
         N = 65536 * 2
-        disable_fastpath = dtype in (torch.int8, torch.int16, torch.bool)
+        disable_fastpath = False
         kwargs = {}
         if op.name == "_foreach_norm":
             ord = 2
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4e17fcd5d277..9a83bf5b0038 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -9380,7 +9380,7 @@ class foreach_max_sample_func(foreach_inputs_sample_func):
         return []
 
     def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
-        return dtype in (torch.int8, torch.int16, torch.bool)
+        return False
 
 
 class foreach_norm_sample_func(foreach_inputs_sample_func):
@@ -11125,6 +11125,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
         supports_inplace_autograd=True,
         supports_forward_ad=True,
         decorators=(
+            # no complex support for ordering ops like max
             DecorateInfo(
                 unittest.expectedFailure,
                 "TestForeach",
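Note (not part of the patch): a minimal host-side sketch of why switching the identity value lets integer dtypes take the fast path. std::numeric_limits<T>::lowest() is defined for every arithmetic T, whereas converting -INFINITY to a narrow integer type is undefined behavior, which is why the removed code fell back to the slow path for int8/int16/bool. The real kernels additionally use max_propagate_nan for NaN handling; this sketch uses a plain comparison.

// Standalone illustration only; reduce_max is a hypothetical helper, not a PyTorch API.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Max reduction seeded with the same identity the patched kernels use.
template <typename T>
T reduce_max(const std::vector<T>& xs) {
  T acc = std::numeric_limits<T>::lowest();  // smallest finite value of T
  for (const T& x : xs) {
    acc = x > acc ? x : acc;
  }
  return acc;
}

int main() {
  std::vector<int8_t> a = {-5, 3, -128};    // int8 now has a valid identity
  std::vector<float> b = {-1.5f, 2.25f};
  std::cout << int(reduce_max(a)) << "\n";  // prints 3
  std::cout << reduce_max(b) << "\n";       // prints 2.25
  return 0;
}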