Allow int vals to go down the fastpath for _foreach_max (#127303)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127303
Approved by: https://github.com/albanD
ghstack dependencies: #127187
Author: Jane Xu
Date: 2024-05-29 08:18:19 -07:00
Committed by: PyTorch MergeBot
Parent: 601c5e085d
Commit: 05e99154ee
4 changed files with 10 additions and 19 deletions
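
This commit switches the max-reduction identity in the CUDA foreach kernels from -INFINITY to std::numeric_limits<T>::lowest(), which is representable for every dtype, so tensor lists of int8, int16, and bool no longer have to fall back to the slow per-tensor path. As a user-facing sanity check, _foreach_max should keep agreeing with per-tensor max() for those dtypes; the snippet below is an illustrative sketch (sizes and values are arbitrary, and the fused fast path itself is CUDA-specific).

import torch

tensors = [
    torch.randint(-128, 127, (1000,), dtype=torch.int8),
    torch.randint(-1000, 1000, (1000,), dtype=torch.int16),
    torch.rand(1000) > 0.5,  # bool
]
result = torch._foreach_max(tensors)
assert all(r.item() == t.max().item() for r, t in zip(result, tensors))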


@@ -68,8 +68,8 @@ struct LpMaxFunctor {
     T vals[kILP];
     T r_x[kILP];
     for (int64_t i = 0; i < kILP; i++) {
-      vals[i] = T(-INFINITY);
-      r_x[i] = T(-INFINITY);
+      vals[i] = T(std::numeric_limits<T>::lowest());
+      r_x[i] = T(std::numeric_limits<T>::lowest());
     }
     if (n % kILP == 0 && (chunk_size & kILP) == 0 && is_aligned(x)) {
@@ -96,7 +96,7 @@ struct LpMaxFunctor {
       }
     }
-    auto val = T(-INFINITY);
+    auto val = T(std::numeric_limits<T>::lowest());
     for (int i = 0; i < kILP; i++) {
       val = max_propagate_nan(val, vals[i]);
     }
@@ -118,7 +118,7 @@ __global__ void lpmax_cleanup(
   __shared__ T vals[512];
   const T* output_this_tensor =
       output_per_tensor + blockIdx.x * max_chunks_per_tensor;
-  T val = -INFINITY;
+  T val = std::numeric_limits<T>::lowest();
   for (size_t i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
     val = max_propagate_nan(val, output_this_tensor[i]);
   }
@@ -130,21 +130,11 @@ __global__ void lpmax_cleanup(
 std::vector<Tensor> foreach_tensor_max_cuda(TensorList tensors) {
   check_foreach_api_restrictions(tensors);
-  // we currently use -INF as the identity value to compare against, which
-  // does not work for int8, int16, nor bool. Fall back to slow path here.
-  const bool has_small_int_or_bool =
-      std::any_of(tensors.begin(), tensors.end(), [](const auto& t) {
-        const auto scalar_type = t.scalar_type();
-        return scalar_type == at::ScalarType::Short ||
-            scalar_type == at::ScalarType::Char ||
-            scalar_type == at::ScalarType::Bool;
-      });
-  if (!can_use_fast_route(tensors) || has_small_int_or_bool) {
+  if (!can_use_fast_route(tensors)) {
     return foreach_tensor_max_slow(tensors);
   }
-  // for parity with max in ReduceAllOps.cpp, though I think max(empty) should
-  // eventually be allowed.
+  // for parity with max in ReduceAllOps.cpp, as max(empty) is ???
   TORCH_CHECK(
       std::all_of(
           tensors.begin(),
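
The fallback removed above existed because -INFINITY is only meaningful for floating-point types, so it could not serve as the comparison identity for int8, int16, or bool; std::numeric_limits<T>::lowest() is defined for every dtype and can never exceed a real element, so the same kernel now covers them all. Below is a rough Python analogue of that identity choice, using torch.iinfo/torch.finfo in place of numeric_limits; it illustrates the idea and is not the CUDA kernel.

import torch

def max_identity(dtype):
    # Rough analogue of std::numeric_limits<T>::lowest(): the smallest value
    # representable in the dtype, so it can never win a max comparison.
    if dtype == torch.bool:
        return False
    if dtype.is_floating_point:
        return torch.finfo(dtype).min
    return torch.iinfo(dtype).min

for dtype in (torch.float32, torch.int16, torch.int8, torch.bool):
    t = (torch.randn(7) * 10).to(dtype)
    acc = torch.tensor(max_identity(dtype), dtype=dtype)
    for x in t:
        acc = torch.maximum(acc, x)  # sequential stand-in for the parallel reduction
    assert acc.item() == t.max().item()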


@@ -103,7 +103,7 @@ __inline__ __device__ T BlockReduceMax(T val, T* shared) {
     shared[wid] = val;
   }
   __syncthreads();
-  val = (tid < B::Warps()) ? shared[lid] : T(-INFINITY);
+  val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits<T>::lowest());
   if (wid == 0) {
     val = WarpReduceMax(val);
   }
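
BlockReduceMax fills the lanes that hold no warp-level partial result with the identity before the final reduction, so the same representability argument applies here. The sketch below is a plain-Python stand-in for that pattern, a pairwise reduction over an array padded to a fixed width; it assumes a non-bool dtype and a power-of-two width, and is conceptual only, not the CUDA code.

import torch

def padded_tree_max(values, width):
    # Pad to a fixed (power-of-two) width with the dtype's lowest value (the
    # analogue of std::numeric_limits<T>::lowest()) so the padding can never
    # win, then reduce pairwise the way a block/warp reduction does.
    dtype = values.dtype
    identity = torch.finfo(dtype).min if dtype.is_floating_point else torch.iinfo(dtype).min
    padded = torch.full((width,), identity, dtype=dtype)
    padded[: values.numel()] = values
    while padded.numel() > 1:
        half = padded.numel() // 2
        padded = torch.maximum(padded[:half], padded[half:])
    return padded[0]

vals = torch.tensor([3, -7, 12, 5, -1], dtype=torch.int16)
assert padded_tree_max(vals, width=8) == vals.max()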


@@ -1015,7 +1015,7 @@ class TestForeach(TestCase):
     def test_foreach_reduce_large_input(self, device, dtype, op):
         # test inputs larger than kChunkSize = 65536
         N = 65536 * 2
-        disable_fastpath = dtype in (torch.int8, torch.int16, torch.bool)
+        disable_fastpath = False
         kwargs = {}
         if op.name == "_foreach_norm":
             ord = 2
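
With the fallback gone, the large-input test no longer disables the fast path for int8, int16, or bool. The snippet below sketches the scenario the test exercises, assuming a CUDA device is available (the fused path is CUDA-only); the dtype and tensor count are illustrative.

import torch

N = 65536 * 2  # larger than one CUDA chunk (kChunkSize = 65536)
tensors = [
    torch.randint(-1000, 1000, (N,), dtype=torch.int16, device="cuda")
    for _ in range(3)
]
fused = torch._foreach_max(tensors)      # now expected to take the fused path for ints
reference = [t.max() for t in tensors]   # per-tensor reference
assert all(torch.equal(a, b) for a, b in zip(fused, reference))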


@@ -9380,7 +9380,7 @@ class foreach_max_sample_func(foreach_inputs_sample_func):
         return []

     def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
-        return dtype in (torch.int8, torch.int16, torch.bool)
+        return False


 class foreach_norm_sample_func(foreach_inputs_sample_func):
@@ -11125,6 +11125,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
         supports_inplace_autograd=True,
         supports_forward_ad=True,
         decorators=(
+            # no complex support for ordering ops like max
             DecorateInfo(
                 unittest.expectedFailure,
                 "TestForeach",