[MPS] Update avg_pool2d to use Metal kernel when ceil_mode=True (#161011)
Fixes #160743
The MPS impl of `avg_pool2d` seems to give incorrect results only when `ceil_mode=True`. I wrote a performance measurement script (0ee6e58643/avg_pool_mps/perf_2d.py) which tests a bunch of different cases and also marks the cases where MPS and CPU results do not match.
I found that updating `avg_pool2d` to use the new Metal kernel in all cases fixes all the mismatches, but it also decreases performance for some of the `ceil_mode=False` cases. So I opted to run the new Metal kernel only when `ceil_mode=True`, which does not significantly decrease performance in any of the cases tested.
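As a quick repro of the kind of mismatch this fixes, here is a minimal sketch that mirrors one of the cases covered by the new test added below (the shape and pooling parameters come from that test; this is not the full perf script):

```python
import torch

# Compare CPU and MPS results for avg_pool2d with ceil_mode=True.
# Shape and parameters mirror one case from test_avg_pool2d_ceil_mode_mismatch.
pool = torch.nn.AvgPool2d(kernel_size=[1, 3], stride=[2, 3],
                          ceil_mode=True, divisor_override=7)
x = torch.arange(2 * 2 * 40 * 60, dtype=torch.float).reshape(2, 2, 40, 60)

out_cpu = pool(x)
out_mps = pool(x.to("mps"))

# Before this change, the MPS result could differ from the CPU reference.
torch.testing.assert_close(out_mps.cpu(), out_cpu)
```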
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161011
Approved by: https://github.com/malfet
committed by PyTorch MergeBot
parent d228a776e9
commit 121afd6a8f
@@ -1,5 +1,6 @@
#include <ATen/native/mps/kernels/Pooling.h>
#include <c10/metal/atomic.h>
#include <c10/metal/utils.h>
#include <metal_array>
#include <metal_stdlib>
@@ -523,6 +524,53 @@ void avg_pool_3d_input_iter(
   *output = value_sum / static_cast<T>(divisor);
 }
 
+// Iterates through all the input elements that this kernel needs to
+// average over. Specialized for 2 pooling dimensions.
+template <typename T>
+void avg_pool_2d_input_iter(
+    constant T* input,
+    device T* output,
+    constant int32_t* input_sizes,
+    constant int32_t* input_strides,
+    thread int32_t (&pooling_dim_indices)[3],
+    constant int32_t* kernel_size,
+    constant int32_t* stride,
+    constant int32_t* padding,
+    bool count_include_pad,
+    bool has_divisor_override,
+    int32_t divisor_override) {
+  auto bounds0 = get_avg_pool_input_iter_bounds<0>(
+      input_sizes,
+      pooling_dim_indices,
+      kernel_size,
+      stride,
+      padding,
+      count_include_pad);
+  auto bounds1 = get_avg_pool_input_iter_bounds<1>(
+      input_sizes,
+      pooling_dim_indices,
+      kernel_size,
+      stride,
+      padding,
+      count_include_pad);
+
+  opmath_t<T> value_sum = 0;
+  opmath_t<T> divisor = has_divisor_override
+      ? divisor_override
+      : (bounds0.count) * (bounds1.count);
+
+  for (auto i0 = bounds0.start; i0 < bounds0.end; i0++) {
+    auto offset0 = input_strides[0] * i0;
+
+    for (auto i1 = bounds1.start; i1 < bounds1.end; i1++) {
+      auto offset1 = input_strides[1] * i1;
+      auto input_value = input[offset0 + offset1];
+      value_sum += static_cast<opmath_t<T>>(input_value);
+    }
+  }
+  *output = static_cast<T>(value_sum / divisor);
+}
+
 template <typename T>
 void avg_pool_backward_3d_input_iter(
     device AtomicType_t<T>* grad_input,
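For readers less familiar with the Metal side, the per-output computation of the new 2D helper can be modeled in Python roughly as follows. This is a simplified sketch, not the exact implementation: `get_avg_pool_input_iter_bounds` is replaced by simple clamping, and the extra divisor handling that `ceil_mode`/`count_include_pad` need at padded edges is not reproduced.

```python
import numpy as np

def avg_pool2d_one_output(x, oh, ow, kernel_size, stride, padding,
                          count_include_pad=False, divisor_override=None):
    """Simplified model of one output element of 2D average pooling."""
    kh, kw = kernel_size
    sh, sw = stride
    ph, pw = padding
    H, W = x.shape

    # Window extent in input coordinates, before clipping to the input.
    h0, h1 = oh * sh - ph, oh * sh - ph + kh
    w0, w1 = ow * sw - pw, ow * sw - pw + kw

    # Clip to the valid input range; padded positions contribute zero to the sum.
    h0c, h1c = max(h0, 0), min(h1, H)
    w0c, w1c = max(w0, 0), min(w1, W)
    value_sum = x[h0c:h1c, w0c:w1c].sum()

    # An explicit divisor override wins; otherwise divide by the number of
    # window positions counted (the full window if padding is included,
    # else only the positions that fall inside the input).
    if divisor_override is not None:
        divisor = divisor_override
    elif count_include_pad:
        divisor = (h1 - h0) * (w1 - w0)
    else:
        divisor = (h1c - h0c) * (w1c - w0c)
    return value_sum / divisor

# 4x6 input, 2x2 window, stride 2, no padding, output element (1, 2) -> 19.5
x = np.arange(24, dtype=np.float32).reshape(4, 6)
print(avg_pool2d_one_output(x, 1, 2, (2, 2), (2, 2), (0, 0)))
```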
@@ -619,18 +667,33 @@ kernel void avg_pool(
   input_sizes += leading_dims;
   input_strides += leading_dims;
 
-  avg_pool_3d_input_iter<T>(
-      input,
-      output,
-      input_sizes,
-      input_strides,
-      pooling_dim_indices,
-      kernel_size,
-      stride,
-      padding,
-      params.count_include_pad,
-      params.has_divisor_override,
-      params.divisor_override);
+  if (pooling_dims == 3) {
+    avg_pool_3d_input_iter<T>(
+        input,
+        output,
+        input_sizes,
+        input_strides,
+        pooling_dim_indices,
+        kernel_size,
+        stride,
+        padding,
+        params.count_include_pad,
+        params.has_divisor_override,
+        params.divisor_override);
+  } else if (pooling_dims == 2) {
+    avg_pool_2d_input_iter<T>(
+        input,
+        output,
+        input_sizes,
+        input_strides,
+        pooling_dim_indices,
+        kernel_size,
+        stride,
+        padding,
+        params.count_include_pad,
+        params.has_divisor_override,
+        params.divisor_override);
+  }
 }
 
 template <typename T>
@@ -1137,17 +1137,30 @@ TORCH_IMPL_FUNC(avg_pool2d_out_mps)
  bool count_include_pad,
  std::optional<int64_t> divisor_override,
  const Tensor& output) {
-  mps::avg_pool2d_template(input,
-                           output,
-                           std::nullopt,
-                           {kH, kW},
-                           {dH, dW},
-                           {padH, padW},
-                           {1, 1},
-                           ceil_mode,
-                           count_include_pad,
-                           divisor_override,
-                           "avg_pool2d");
+  if (ceil_mode) {
+    mps::avg_pool_out_mps_template(output,
+                                   input,
+                                   {kH, kW},
+                                   {dH, dW},
+                                   {padH, padW},
+                                   ceil_mode,
+                                   count_include_pad,
+                                   divisor_override,
+                                   /*pooling_dims=*/2,
+                                   "avg_pool3d");
+  } else {
+    mps::avg_pool2d_template(input,
+                             output,
+                             std::nullopt,
+                             {kH, kW},
+                             {dH, dW},
+                             {padH, padW},
+                             {1, 1},
+                             ceil_mode,
+                             count_include_pad,
+                             divisor_override,
+                             "avg_pool2d");
+  }
 }
 
 TORCH_IMPL_FUNC(avg_pool2d_backward_out_mps)
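As background for why `ceil_mode` is the special case here: it only changes how the output spatial size is rounded, which creates edge windows that partially overlap the (padded) input and therefore need careful divisor handling. A short sketch of the documented output-size rule (`pooled_size` is just an illustrative helper name):

```python
import math

def pooled_size(in_size, k, s, p, ceil_mode):
    # Standard pooling output-size formula; ceil_mode rounds up instead of down.
    if ceil_mode:
        out = math.ceil((in_size + 2 * p - k) / s) + 1
        # A window that would start entirely inside the right padding is dropped.
        if (out - 1) * s >= in_size + p:
            out -= 1
        return out
    return math.floor((in_size + 2 * p - k) / s) + 1

# Width 60 with kernel 3, stride 3: both modes give 20 outputs.
# Width 40: floor gives 13, ceil gives 14, and the 14th window only
# partially overlaps the input.
print(pooled_size(60, 3, 3, 0, False), pooled_size(60, 3, 3, 0, True))
print(pooled_size(40, 3, 3, 0, False), pooled_size(40, 3, 3, 0, True))
```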
@@ -738,6 +738,33 @@ class TestAvgPool(TestCaseMPS):
                                            padding=(0, 1), stride=2)
         self.assertFalse(torch.isnan(y).any())
 
+    # Test some cases for avg_pool2d which used to mismatch CPU results.
+    # Addresses this issue: https://github.com/pytorch/pytorch/issues/160743
+    def test_avg_pool2d_ceil_mode_mismatch(self):
+        sizes = [
+            (4, 2, 3),
+            (5, 2, 3),
+            (50, 2, 3),
+            (4, 1, 2, 3),
+            (4, 4, 2, 3),
+            (2, 2, 4, 6),
+            (5, 40, 60),
+            (2, 2, 40, 60),
+        ]
+
+        kwargs = dict(kernel_size=[1, 3],
+                      stride=[2, 3],
+                      ceil_mode=True,
+                      divisor_override=7)
+
+        for input_size in sizes:
+            model = torch.nn.AvgPool2d(**kwargs)
+            x = torch.arange(math.prod(input_size), dtype=torch.float).reshape(input_size)
+            out_cpu = model(x)
+            out_mps = model(x.to("mps"))
+            msg = f'{input_size=}, {kwargs=}'
+            self.assertEqual(out_mps, out_cpu, msg=msg)
+
 
 class TestMPS(TestCaseMPS):
     def test_exp(self, device="mps", dtype=torch.float):