Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Allow int vals to go down the fastpath for _foreach_max (#127303)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127303
Approved by: https://github.com/albanD
ghstack dependencies: #127187
Committed by: PyTorch MergeBot
Parent: 601c5e085d
Commit: 05e99154ee
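The change running through the kernels below is the identity element used to seed the max reduction: T(-INFINITY) is only meaningful when T is a floating-point type, which is why int8, int16, and bool tensors previously had to fall back to the slow path, whereas std::numeric_limits<T>::lowest() is a valid "smaller than everything" seed for any arithmetic T. A minimal host-side sketch of the idea (the reduce_max helper here is illustrative only, not PyTorch code):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Max reduction seeded with the lowest representable value of T. This seed is
// well defined for floating-point and integral types alike, unlike T(-INFINITY),
// which int8/int16/bool cannot represent.
template <typename T>
T reduce_max(const std::vector<T>& xs) {
  T acc = std::numeric_limits<T>::lowest();
  for (const T& x : xs) {
    acc = std::max(acc, x);
  }
  return acc;
}

int main() {
  std::cout << reduce_max<int16_t>({-7, 3, 12, -42}) << "\n";  // 12
  std::cout << reduce_max<float>({-7.f, 3.f, 12.f}) << "\n";   // 12
}
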
@@ -68,8 +68,8 @@ struct LpMaxFunctor {
     T vals[kILP];
     T r_x[kILP];
     for (int64_t i = 0; i < kILP; i++) {
-      vals[i] = T(-INFINITY);
-      r_x[i] = T(-INFINITY);
+      vals[i] = T(std::numeric_limits<T>::lowest());
+      r_x[i] = T(std::numeric_limits<T>::lowest());
     }
 
     if (n % kILP == 0 && (chunk_size & kILP) == 0 && is_aligned(x)) {
@@ -96,7 +96,7 @@ struct LpMaxFunctor {
       }
     }
 
-    auto val = T(-INFINITY);
+    auto val = T(std::numeric_limits<T>::lowest());
     for (int i = 0; i < kILP; i++) {
       val = max_propagate_nan(val, vals[i]);
     }
@@ -118,7 +118,7 @@ __global__ void lpmax_cleanup(
   __shared__ T vals[512];
   const T* output_this_tensor =
       output_per_tensor + blockIdx.x * max_chunks_per_tensor;
-  T val = -INFINITY;
+  T val = std::numeric_limits<T>::lowest();
   for (size_t i = threadIdx.x; i < max_chunks_per_tensor; i += blockDim.x) {
     val = max_propagate_nan(val, output_this_tensor[i]);
   }
@@ -130,21 +130,11 @@ __global__ void lpmax_cleanup(
 
 std::vector<Tensor> foreach_tensor_max_cuda(TensorList tensors) {
   check_foreach_api_restrictions(tensors);
-  // we currently use -INF as the identity value to compare against, which
-  // does not work for int8, int16, nor bool. Fall back to slow path here.
-  const bool has_small_int_or_bool =
-      std::any_of(tensors.begin(), tensors.end(), [](const auto& t) {
-        const auto scalar_type = t.scalar_type();
-        return scalar_type == at::ScalarType::Short ||
-            scalar_type == at::ScalarType::Char ||
-            scalar_type == at::ScalarType::Bool;
-      });
-  if (!can_use_fast_route(tensors) || has_small_int_or_bool) {
+  if (!can_use_fast_route(tensors)) {
     return foreach_tensor_max_slow(tensors);
   }
 
-  // for parity with max in ReduceAllOps.cpp, though I think max(empty) should
-  // eventually be allowed.
+  // for parity with max in ReduceAllOps.cpp, as max(empty) is ???
   TORCH_CHECK(
       std::all_of(
           tensors.begin(),
@@ -103,7 +103,7 @@ __inline__ __device__ T BlockReduceMax(T val, T* shared) {
     shared[wid] = val;
   }
   __syncthreads();
-  val = (tid < B::Warps()) ? shared[lid] : T(-INFINITY);
+  val = (tid < B::Warps()) ? shared[lid] : T(std::numeric_limits<T>::lowest());
   if (wid == 0) {
     val = WarpReduceMax(val);
   }
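In the BlockReduceMax hunk above, lanes of the first warp that do not correspond to an active warp are filled with the identity before the final warp reduction, so the same reasoning applies: the filler must never win a max comparison for any supported T. Below is a generic sketch of the shuffle-based warp/block max-reduction pattern; it is an assumed structure for illustration, not the actual PyTorch helpers, and it assumes the build permits std::numeric_limits in device code, as the changed kernels themselves do.

#include <cuda_runtime.h>
#include <limits>

// Warp-level max: after log2(warpSize) shuffle steps, lane 0 of the warp
// holds the maximum over all lanes' inputs.
template <typename T>
__device__ __forceinline__ T warp_reduce_max(T val) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    T other = __shfl_down_sync(0xffffffffu, val, offset);
    val = other > val ? other : val;
  }
  return val;
}

// Block-level max: warp leaders write their partial maxima to shared memory,
// then the first warp reduces those partials. Slots belonging to warps that do
// not exist in this block are seeded with numeric_limits<T>::lowest(), the
// identity this PR switches to, so they can never win the comparison.
template <typename T>
__device__ T block_reduce_max(T val, T* shared /* >= num warps elements */) {
  const int lane = threadIdx.x % warpSize;
  const int wid = threadIdx.x / warpSize;
  val = warp_reduce_max(val);
  if (lane == 0) {
    shared[wid] = val;
  }
  __syncthreads();
  const int num_warps = (blockDim.x + warpSize - 1) / warpSize;
  val = (static_cast<int>(threadIdx.x) < num_warps)
      ? shared[lane]
      : std::numeric_limits<T>::lowest();
  if (wid == 0) {
    val = warp_reduce_max(val);
  }
  return val;  // meaningful on thread 0 of the block
}
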
@@ -1015,7 +1015,7 @@ class TestForeach(TestCase):
     def test_foreach_reduce_large_input(self, device, dtype, op):
         # test inputs larger than kChunkSize = 65536
         N = 65536 * 2
-        disable_fastpath = dtype in (torch.int8, torch.int16, torch.bool)
+        disable_fastpath = False
         kwargs = {}
         if op.name == "_foreach_norm":
             ord = 2
@@ -9380,7 +9380,7 @@ class foreach_max_sample_func(foreach_inputs_sample_func):
         return []
 
     def _should_disable_fastpath(self, opinfo, rightmost_arg, rightmost_arg_type, dtype):
-        return dtype in (torch.int8, torch.int16, torch.bool)
+        return False
 
 
 class foreach_norm_sample_func(foreach_inputs_sample_func):
@@ -11125,6 +11125,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
        supports_inplace_autograd=True,
        supports_forward_ad=True,
        decorators=(
            # no complex support for ordering ops like max
            DecorateInfo(
                unittest.expectedFailure,
                "TestForeach",