Add Half support for aminmax on CPU (#106853)

Add Half support for aminmax on CPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106853
Approved by: https://github.com/cpuhrsch
Author:    CaoE
Date:      2023-10-23 17:43:47 +00:00
Committer: PyTorch MergeBot
Commit:    4b324a8717
Parent:    ad4ccf9689

5 changed files with 5 additions and 6 deletions
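
For context (not part of the diff), a minimal usage sketch of what this commit enables: calling torch.aminmax on a float16 tensor on CPU, which the CPU dispatch lists below previously did not cover. Shapes and values are arbitrary.

    import torch

    # Minimal sketch: float16 (Half) input on a CPU tensor.
    x = torch.randn(4, 5).to(torch.float16)
    mn, mx = torch.aminmax(x)      # full reduction -> 0-dim min and max
    print(mn.dtype, mx.dtype)      # torch.float16 torch.float16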


@@ -199,7 +199,7 @@ static void aminmax_allreduce_kernel(
       }
     );
   } else {
-    AT_DISPATCH_ALL_TYPES_AND(kBFloat16, input.scalar_type(), "aminmax_cpu", [&] {
+    AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
       using Vec = Vectorized<opmath_type<scalar_t>>;
       using scalar_t_pair = std::pair<scalar_t, scalar_t>;
       reduce_all_impl_vec_two_outputs<scalar_t>(
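
As a rough Python-level illustration of the AND -> AND2 change in this full-reduction kernel (the macro itself is C++, and "ALL_TYPES" below is my approximation of ATen's standard all-types list):

    import torch

    # Illustration only: approximate dtype coverage implied by the dispatch macros.
    ALL_TYPES = {torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64,
                 torch.float32, torch.float64}
    before = ALL_TYPES | {torch.bfloat16}                 # AT_DISPATCH_ALL_TYPES_AND(kBFloat16, ...)
    after  = ALL_TYPES | {torch.bfloat16, torch.float16}  # AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, ...)
    print(after - before)                                 # {torch.float16}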


@@ -185,7 +185,7 @@ static void aminmax_kernel(
     return;
   }
 
-  AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "aminmax_cpu", [&] {
+  AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] {
     compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
       scalar_t* min_result_data, scalar_t* max_result_data,
       const scalar_t* self_data, auto self_dim_stride) {
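
A hedged sketch (not from the PR) of the dim-wise form this kernel handles, now dispatched for Half on CPU as well:

    import torch

    # Dim-wise reduction on a float16 CPU tensor; outputs keep the input dtype.
    x = torch.arange(12.).reshape(3, 4).to(torch.float16)
    mn, mx = torch.aminmax(x, dim=1, keepdim=True)
    print(mn.shape, mn.dtype)   # torch.Size([3, 1]) torch.float16
    print(mx.squeeze(1))        # per-row maxima: 3., 7., 11. (still float16)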


@@ -79,7 +79,7 @@ def mps_ops_grad_modifier(ops):
         'cdist': [torch.float32],
         'masked.scatter': [torch.float16, torch.float32],
         'index_fill': [torch.float16, torch.float32],  # missing `aten::_unique`.
-        'aminmax': [torch.float32],
+        'aminmax': [torch.float32, torch.float16],
         'polar': [torch.float32],
 
         # Correctness issues


@@ -1207,7 +1207,7 @@ class TestReductions(TestCase):
         self._test_minmax_helper(torch.amax, np.amax, device, dtype)
 
     @onlyNativeDeviceTypes
-    @dtypes(torch.float, torch.double)
+    @dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
     @dtypesIfCUDA(torch.half, torch.float, torch.bfloat16)
     def test_aminmax(self, device, dtype):
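
A standalone sketch of the kind of check this test performs through TestReductions._test_minmax_helper (my own simplified version, assuming a NumPy reference computed in float32; min/max select existing elements, so exact comparison is fine):

    import numpy as np
    import torch

    # Compare CPU float16 aminmax results against a NumPy reference.
    x = torch.randn(3, 7).to(torch.float16)
    mn, mx = torch.aminmax(x, dim=1)
    np.testing.assert_array_equal(mn.float().numpy(), np.amin(x.float().numpy(), axis=1))
    np.testing.assert_array_equal(mx.float().numpy(), np.amax(x.float().numpy(), axis=1))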


@@ -12268,8 +12268,7 @@ op_db: List[OpInfo] = [
            supports_fwgrad_bwgrad=True),
     OpInfo('aminmax',
            ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)),
-           dtypes=all_types_and(torch.bool),
-           dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
+           dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
            decorators=(onlyNativeDeviceTypes,),
            supports_autograd=False,
            sample_inputs_func=sample_inputs_aminmax,