mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add Half support for aminmax on CPU (#106853)
Add Half support for aminmax on CPU. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106853 Approved by: https://github.com/cpuhrsch
This commit is contained in:
@ -199,7 +199,7 @@ static void aminmax_allreduce_kernel(
|
||||
}
|
||||
);
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES_AND(kBFloat16, input.scalar_type(), "aminmax_cpu", [&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] {
|
||||
using Vec = Vectorized<opmath_type<scalar_t>>;
|
||||
using scalar_t_pair = std::pair<scalar_t, scalar_t>;
|
||||
reduce_all_impl_vec_two_outputs<scalar_t>(
|
||||
|
@ -185,7 +185,7 @@ static void aminmax_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Bool, self.scalar_type(), "aminmax_cpu", [&] {
|
||||
AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] {
|
||||
compare_base_kernel<scalar_t, scalar_t>(min_result, max_result, self, wrap_dim, keepdim, [&] (
|
||||
scalar_t* min_result_data, scalar_t* max_result_data,
|
||||
const scalar_t* self_data, auto self_dim_stride) {
|
||||
|
@ -79,7 +79,7 @@ def mps_ops_grad_modifier(ops):
|
||||
'cdist': [torch.float32],
|
||||
'masked.scatter': [torch.float16, torch.float32],
|
||||
'index_fill': [torch.float16, torch.float32], # missing `aten::_unique`.
|
||||
'aminmax': [torch.float32],
|
||||
'aminmax': [torch.float32, torch.float16],
|
||||
'polar': [torch.float32],
|
||||
|
||||
# Correctness issues
|
||||
|
@ -1207,7 +1207,7 @@ class TestReductions(TestCase):
|
||||
self._test_minmax_helper(torch.amax, np.amax, device, dtype)
|
||||
|
||||
@onlyNativeDeviceTypes
|
||||
@dtypes(torch.float, torch.double)
|
||||
@dtypes(torch.float, torch.double, torch.bfloat16, torch.half)
|
||||
@dtypesIfCUDA(torch.half, torch.float, torch.bfloat16)
|
||||
def test_aminmax(self, device, dtype):
|
||||
|
||||
|
@ -12268,8 +12268,7 @@ op_db: List[OpInfo] = [
|
||||
supports_fwgrad_bwgrad=True),
|
||||
OpInfo('aminmax',
|
||||
ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)),
|
||||
dtypes=all_types_and(torch.bool),
|
||||
dtypesIfCUDA=all_types_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
decorators=(onlyNativeDeviceTypes,),
|
||||
supports_autograd=False,
|
||||
sample_inputs_func=sample_inputs_aminmax,
|
||||
|
Reference in New Issue
Block a user