[MPS] Update avg_pool2d to use Metal kernel when ceil_mode=True (#161011)

Fixes #160743

The MPS impl of `avg_pool2d` appears to give incorrect results only when `ceil_mode=True`. I wrote a performance measurement script (0ee6e58643/avg_pool_mps/perf_2d.py) which tests a range of cases and also flags the cases where the MPS and CPU results do not match.
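
For context, a minimal sketch of the kind of correctness check such a script performs; the shape and pooling parameters below are illustrative stand-ins, not the script's actual cases:

```python
import torch
import torch.nn.functional as F

# Compare MPS against CPU for one illustrative avg_pool2d configuration.
for ceil_mode in (False, True):
    x = torch.arange(2 * 2 * 40 * 60, dtype=torch.float).reshape(2, 2, 40, 60)
    out_cpu = F.avg_pool2d(x, kernel_size=[1, 3], stride=[2, 3], ceil_mode=ceil_mode)
    out_mps = F.avg_pool2d(x.to("mps"), kernel_size=[1, 3], stride=[2, 3], ceil_mode=ceil_mode)
    status = "OK" if torch.allclose(out_cpu, out_mps.cpu()) else "MISMATCH"
    print(f"{ceil_mode=}: {status}")
```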

I found that updating `avg_pool2d` to use the new Metal kernel in all cases fixes all the mismatches, but it also decreases performance for some of the `ceil_mode=False` cases. So I opted to run the new Metal kernel only when `ceil_mode=True`, which does not significantly decrease performance in any of the cases tested.
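
A rough sketch of how such a comparison can be timed on MPS (again illustrative, not the actual perf_2d.py; `torch.mps.synchronize` is needed so queued GPU work is included in the measurement):

```python
import time
import torch
import torch.nn.functional as F

def time_avg_pool2d(x, iters=100, **kwargs):
    F.avg_pool2d(x, **kwargs)      # warm-up run
    torch.mps.synchronize()        # drain pending GPU work before timing
    start = time.perf_counter()
    for _ in range(iters):
        F.avg_pool2d(x, **kwargs)
    torch.mps.synchronize()        # make sure all timed work has finished
    return (time.perf_counter() - start) / iters

x = torch.randn(2, 2, 40, 60, device="mps")
for ceil_mode in (False, True):
    t = time_avg_pool2d(x, kernel_size=[1, 3], stride=[2, 3], ceil_mode=ceil_mode)
    print(f"{ceil_mode=}: {t * 1e6:.1f} us/iter")
```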
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161011
Approved by: https://github.com/malfet
Author: Kurt Mohler
Date: 2025-08-20 12:50:26 -05:00
Committed by: PyTorch MergeBot
Parent: d228a776e9
Commit: 121afd6a8f
3 changed files with 126 additions and 23 deletions

aten/src/ATen/native/mps/kernels/Pooling.metal

@@ -1,5 +1,6 @@
 #include <ATen/native/mps/kernels/Pooling.h>
 #include <c10/metal/atomic.h>
+#include <c10/metal/utils.h>
 #include <metal_array>
 #include <metal_stdlib>
 
@@ -523,6 +524,53 @@ void avg_pool_3d_input_iter(
   *output = value_sum / static_cast<T>(divisor);
 }
 
+// Iterates through all the input elements that this kernel needs to
+// apply avg to. Specialized for 2 pooling dimensions.
+template <typename T>
+void avg_pool_2d_input_iter(
+    constant T* input,
+    device T* output,
+    constant int32_t* input_sizes,
+    constant int32_t* input_strides,
+    thread int32_t (&pooling_dim_indices)[3],
+    constant int32_t* kernel_size,
+    constant int32_t* stride,
+    constant int32_t* padding,
+    bool count_include_pad,
+    bool has_divisor_override,
+    int32_t divisor_override) {
+  auto bounds0 = get_avg_pool_input_iter_bounds<0>(
+      input_sizes,
+      pooling_dim_indices,
+      kernel_size,
+      stride,
+      padding,
+      count_include_pad);
+  auto bounds1 = get_avg_pool_input_iter_bounds<1>(
+      input_sizes,
+      pooling_dim_indices,
+      kernel_size,
+      stride,
+      padding,
+      count_include_pad);
+
+  opmath_t<T> value_sum = 0;
+  opmath_t<T> divisor = has_divisor_override
+      ? divisor_override
+      : (bounds0.count) * (bounds1.count);
+
+  for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
+    auto offset0 = input_strides[0] * i0;
+
+    for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
+      auto offset1 = input_strides[1] * i1;
+      auto input_value = input[offset0 + offset1];
+      value_sum += static_cast<opmath_t<T>>(input_value);
+    }
+  }
+  *output = static_cast<T>(value_sum / divisor);
+}
+
 template <typename T>
 void avg_pool_backward_3d_input_iter(
     device AtomicType_t<T>* grad_input,
@@ -619,18 +667,33 @@ kernel void avg_pool(
   input_sizes += leading_dims;
   input_strides += leading_dims;
 
-  avg_pool_3d_input_iter<T>(
-      input,
-      output,
-      input_sizes,
-      input_strides,
-      pooling_dim_indices,
-      kernel_size,
-      stride,
-      padding,
-      params.count_include_pad,
-      params.has_divisor_override,
-      params.divisor_override);
+  if (pooling_dims == 3) {
+    avg_pool_3d_input_iter<T>(
+        input,
+        output,
+        input_sizes,
+        input_strides,
+        pooling_dim_indices,
+        kernel_size,
+        stride,
+        padding,
+        params.count_include_pad,
+        params.has_divisor_override,
+        params.divisor_override);
+  } else if (pooling_dims == 2) {
+    avg_pool_2d_input_iter<T>(
+        input,
+        output,
+        input_sizes,
+        input_strides,
+        pooling_dim_indices,
+        kernel_size,
+        stride,
+        padding,
+        params.count_include_pad,
+        params.has_divisor_override,
+        params.divisor_override);
+  }
 }
 
 template <typename T>
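
For intuition, here is a hedged Python sketch of what `avg_pool_2d_input_iter` computes for a single output element; the `(start, end, count)` windows stand in for what `get_avg_pool_input_iter_bounds` returns, and all names below are illustrative:

```python
# One output element of 2D average pooling, mirroring the kernel above.
# window = (start, end, count): start/end are the window clamped to the
# input; count is the divisor contribution (with count_include_pad=True
# it keeps padded positions, matching bounds.count in the Metal code).
def avg_pool_2d_one_output(input2d, window0, window1,
                           has_divisor_override=False, divisor_override=0):
    (s0, e0, count0), (s1, e1, count1) = window0, window1
    divisor = divisor_override if has_divisor_override else count0 * count1
    value_sum = 0.0
    for i0 in range(s0, e0):
        for i1 in range(s1, e1):
            value_sum += input2d[i0][i1]
    return value_sum / divisor

# Example: a 1x3 window fully inside the input, no override -> plain mean.
grid = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(avg_pool_2d_one_output(grid, (0, 1, 1), (0, 3, 3)))  # 2.0
```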

aten/src/ATen/native/mps/operations/Pooling.mm

@@ -1137,17 +1137,30 @@ TORCH_IMPL_FUNC(avg_pool2d_out_mps)
  bool count_include_pad,
  std::optional<int64_t> divisor_override,
  const Tensor& output) {
-  mps::avg_pool2d_template(input,
-                           output,
-                           std::nullopt,
-                           {kH, kW},
-                           {dH, dW},
-                           {padH, padW},
-                           {1, 1},
-                           ceil_mode,
-                           count_include_pad,
-                           divisor_override,
-                           "avg_pool2d");
+  if (ceil_mode) {
+    mps::avg_pool_out_mps_template(output,
+                                   input,
+                                   {kH, kW},
+                                   {dH, dW},
+                                   {padH, padW},
+                                   ceil_mode,
+                                   count_include_pad,
+                                   divisor_override,
+                                   /*pooling_dims=*/2,
+                                   "avg_pool3d");
+  } else {
+    mps::avg_pool2d_template(input,
+                             output,
+                             std::nullopt,
+                             {kH, kW},
+                             {dH, dW},
+                             {padH, padW},
+                             {1, 1},
+                             ceil_mode,
+                             count_include_pad,
+                             divisor_override,
+                             "avg_pool2d");
+  }
 }
 
 TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps)

test/test_mps.py

@@ -738,6 +738,33 @@ class TestAvgPool(TestCaseMPS):
                              padding=(0, 1), stride=2)
         self.assertFalse(torch.isnan(y).any())
 
+    # Test some cases for avg_pool2d which used to mismatch CPU results.
+    # Addresses this issue: https://github.com/pytorch/pytorch/issues/160743
+    def test_avg_pool2d_ceil_mode_mismatch(self):
+        sizes = [
+            (4, 2, 3),
+            (5, 2, 3),
+            (50, 2, 3),
+            (4, 1, 2, 3),
+            (4, 4, 2, 3),
+            (2, 2, 4, 6),
+            (5, 40, 60),
+            (2, 2, 40, 60),
+        ]
+        kwargs = dict(kernel_size=[1, 3],
+                      stride=[2, 3],
+                      ceil_mode=True,
+                      divisor_override=7)
+
+        for input_size in sizes:
+            model = torch.nn.AvgPool2d(**kwargs)
+            x = torch.arange(math.prod(input_size), dtype=torch.float).reshape(input_size)
+            out_cpu = model(x)
+            out_mps = model(x.to("mps"))
+
+            msg = f'{input_size=}, {kwargs=}'
+            self.assertEqual(out_mps, out_cpu, msg=msg)
+
 
 class TestMPS(TestCaseMPS):
     def test_exp(self, device="mps", dtype=torch.float):