[MPS] Improve performance of max_pool3d (#157875)

To measure how the changes in this PR affect performance, I wrote a benchmark script: 55ef32a127/max_pool_mps/perf.py.

Before this PR, I get this:

```
===================
max_pool3d
===================
0: 0.013105 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2}
1: 0.038003 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5}
2: 0.212963 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5}
3: 1.224645 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5}
4: 7.317867 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1}
5: 34.679233 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20}
6: 34.626383 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1}
7: 44.835892 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40}
8: 0.083579 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2}
9: 0.936575 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2}
10: 5.329883 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2}
11: 11.713617 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2}
12: 25.450454 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2}
13: 0.058375 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2}
14: 3.757558 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2}
15: 33.451588 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2}
```

After this PR, I get this:

```
===================
max_pool3d
===================
0: 0.007202 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2}
1: 0.018596 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5}
2: 0.130717 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5}
3: 0.966795 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5}
4: 4.095804 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1}
5: 12.833446 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20}
6: 12.859346 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1}
7: 14.080529 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40}
8: 0.029283 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2}
9: 0.175700 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2}
10: 0.742750 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2}
11: 1.939596 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2}
12: 4.074821 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2}
13: 0.028425 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2}
14: 0.384375 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2}
15: 2.623346 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2}
```

Every case is faster after this PR: the speedup ranges from about 1.3x (case 3) to about 12.8x (case 15).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157875
Approved by: https://github.com/malfet
This commit is contained in:
Kurt Mohler
2025-07-15 17:04:00 -05:00
committed by PyTorch MergeBot
parent 66c9bc5062
commit 1b88da1cac
3 changed files with 214 additions and 169 deletions

View File

@ -5,29 +5,30 @@
// maximum allowed pooling dimensions is N-2, because the input may have up to 2
// leading dimensions that are not pooled. To support up to 3-D pooling, N=5 is
// the default.
template <unsigned N = 5>
template <unsigned N = 5, typename idx_type_t = int32_t>
struct PoolingParams {
int32_t dims;
int32_t pooling_dims;
::c10::metal::array<int64_t, N> input_sizes;
::c10::metal::array<int64_t, N> input_strides;
::c10::metal::array<int64_t, N> output_sizes;
::c10::metal::array<int64_t, N> output_strides;
::c10::metal::array<int64_t, N> indices_sizes;
::c10::metal::array<int64_t, N> indices_strides;
::c10::metal::array<int64_t, N - 2> kernel_size;
::c10::metal::array<int64_t, N - 2> stride;
::c10::metal::array<int64_t, N - 2> padding;
::c10::metal::array<int64_t, N - 2> dilation;
::c10::metal::array<idx_type_t, N> input_sizes;
::c10::metal::array<idx_type_t, N> input_strides;
::c10::metal::array<idx_type_t, N> output_sizes;
::c10::metal::array<idx_type_t, N> output_strides;
::c10::metal::array<idx_type_t, N> indices_sizes;
::c10::metal::array<idx_type_t, N> indices_strides;
::c10::metal::array<idx_type_t, N - 2> kernel_size;
::c10::metal::array<idx_type_t, N - 2> stride;
::c10::metal::array<idx_type_t, N - 2> padding;
::c10::metal::array<idx_type_t, N - 2> dilation;
bool return_indices;
};
template <unsigned N = 5>
template <unsigned N = 5, typename idx_type_t = int32_t>
struct PoolingBackwardParams {
int32_t dims;
int32_t pooling_dims;
::c10::metal::array<int64_t, N> grad_input_sizes;
::c10::metal::array<int64_t, N> grad_input_strides;
::c10::metal::array<int64_t, N> grad_output_sizes;
::c10::metal::array<int64_t, N> grad_output_strides;
::c10::metal::array<int64_t, N> indices_strides;
::c10::metal::array<idx_type_t, N> grad_input_sizes;
::c10::metal::array<idx_type_t, N> grad_input_strides;
::c10::metal::array<idx_type_t, N> grad_output_sizes;
::c10::metal::array<idx_type_t, N> grad_output_strides;
::c10::metal::array<idx_type_t, N> indices_strides;
};

View File

@ -6,6 +6,28 @@
using namespace metal;
using namespace c10::metal;
// Pair of iteration limits along a single dimension. The pooling loops in this
// file start at `start` and keep going while the index is strictly less than
// `end`, i.e. the range is half-open: [start, end).
template <typename T>
struct IterBounds {
// First index visited by the loop.
T start;
// Exclusive upper limit; the loop stops once the index reaches this value.
T end;
};
// Computes the half-open range [start, end) of input indices along pooling
// dimension `dim` that belong to the pooling window of one output element.
//
// The window nominally begins at `stride * output_index - padding`, which is
// negative when it begins inside the padding region. In that case the start is
// moved forward by a whole number of dilation steps so that it is the first
// window element with a non-negative index. The end is clamped to the input
// size, so callers can iterate `for (i = start; i < end; i += dilation)`
// without any per-element bounds check.
//
// `input_sizes`, `kernel_size`, `stride`, `padding` and `dilation` are indexed
// by pooling dimension; `pooling_dim_indices` holds the output element's index
// along each pooling dimension.
template <int32_t dim>
IterBounds<int32_t> get_input_iter_bounds(
    constant int32_t* input_sizes,
    thread int32_t (&pooling_dim_indices)[3],
    constant int32_t* kernel_size,
    constant int32_t* stride,
    constant int32_t* padding,
    constant int32_t* dilation) {
  const auto step = dilation[dim];

  // Unclamped window origin in input coordinates.
  const auto window_begin =
      stride[dim] * pooling_dim_indices[dim] - padding[dim];

  // The exclusive end is derived from the unclamped origin, then limited to
  // the extent of the input along this dimension.
  const auto window_end =
      min(window_begin + kernel_size[dim] * step, input_sizes[dim]);

  auto first = window_begin;
  if (first < 0) {
    // Skip ceil(-window_begin / step) dilation steps so that `first` becomes
    // the smallest window index that is >= 0.
    first += step * ((step - 1 - window_begin) / step);
  }

  return IterBounds<int32_t>{first, window_end};
}
// Iterates through all the input elements that this kernel needs to
// apply max to. Specialized for 3 pooling dimensions.
// TODO: Support any number of pooling dims
@ -14,82 +36,62 @@ void max_pool_3d_input_iter(
constant T* input,
device T* output,
device int64_t* indices,
constant int64_t* input_sizes,
constant int64_t* input_strides,
device int64_t* work_pooling_dim_indices,
constant int64_t* kernel_size,
constant int64_t* stride,
constant int64_t* padding,
constant int64_t* dilation) {
int64_t o0 = work_pooling_dim_indices[0];
int64_t o1 = work_pooling_dim_indices[1];
int64_t o2 = work_pooling_dim_indices[2];
constant int32_t* input_sizes,
constant int32_t* input_strides,
thread int32_t (&pooling_dim_indices)[3],
constant int32_t* kernel_size,
constant int32_t* stride,
constant int32_t* padding,
constant int32_t* dilation,
bool return_indices) {
auto bounds0 = get_input_iter_bounds<0>(
input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
auto bounds1 = get_input_iter_bounds<1>(
input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
auto bounds2 = get_input_iter_bounds<2>(
input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
int64_t k0 = kernel_size[0];
int64_t k1 = kernel_size[1];
int64_t k2 = kernel_size[2];
auto d0 = dilation[0];
auto d1 = dilation[1];
auto d2 = dilation[2];
int64_t s0 = stride[0];
int64_t s1 = stride[1];
int64_t s2 = stride[2];
T max_value = input
[input_strides[0] * bounds0.start + input_strides[1] * bounds1.start +
input_strides[2] * bounds2.start];
auto size12 = input_sizes[1] * input_sizes[2];
auto max_index =
bounds0.start * size12 + bounds1.start * input_sizes[2] + bounds2.start;
int64_t d0 = dilation[0];
int64_t d1 = dilation[1];
int64_t d2 = dilation[2];
for (auto i0 = bounds0.start; i0 < bounds0.end; i0 += d0) {
auto offset0 = input_strides[0] * i0;
T max_value = 0;
int64_t max_index = -1;
for (auto i1 = bounds1.start; i1 < bounds1.end; i1 += d1) {
auto offset1 = input_strides[1] * i1;
int64_t size12 = input_sizes[1] * input_sizes[2];
for (auto i2 = bounds2.start; i2 < bounds2.end; i2 += d2) {
auto offset2 = input_strides[2] * i2;
auto input_value = input[offset0 + offset1 + offset2];
bool is_greater = input_value > max_value;
for (int64_t i0 = (s0 * o0) - padding[0];
i0 < (s0 * o0 - padding[0] + k0 * d0) && i0 < input_sizes[0];
i0 += d0) {
if (i0 < 0) {
continue;
}
int64_t offset0 = input_strides[0] * i0;
max_value = is_greater ? input_value : max_value;
for (int64_t i1 = (s1 * o1) - padding[1];
i1 < (s1 * o1 - padding[1] + k1 * d1) && i1 < input_sizes[1];
i1 += d1) {
if (i1 < 0) {
continue;
}
int64_t offset1 = input_strides[1] * i1;
for (int64_t i2 = (s2 * o2) - padding[2];
i2 < (s2 * o2 - padding[2] + k2 * d2) && i2 < input_sizes[2];
i2 += d2) {
if (i2 < 0) {
continue;
}
int64_t offset2 = input_strides[2] * i2;
const T input_value = input[offset0 + offset1 + offset2];
int64_t input_index = i0 * size12 + i1 * input_sizes[2] + i2;
T new_max_value = (max_index == -1 || input_value > max_value)
? input_value
: max_value;
int64_t new_max_index = (max_index == -1 || input_value > max_value)
? input_index
: max_index;
max_value = new_max_value;
max_index = new_max_index;
if (return_indices) {
auto input_index = i0 * size12 + i1 * input_sizes[2] + i2;
max_index = is_greater ? input_index : max_index;
}
}
}
}
*output = max_value;
if (return_indices) {
*indices = max_index;
}
}
struct PoolOffsets {
int64_t output;
int64_t indices;
int64_t input_leading;
int32_t output;
int32_t indices;
int32_t input_leading;
PoolOffsets() : output(0), indices(0), input_leading(0) {}
};
@ -98,30 +100,35 @@ struct PoolOffsets {
// calculate, `output[N, C, d, h, w]`. Also, find the offset of the input for
// the leading dim indices, `input[N, C]`. Optionally, keep track of the output
// pooling dimension indices, `[d, h , w]`.
PoolOffsets find_pool_offsets(
constant int64_t* output_sizes,
constant int64_t* output_strides,
constant int64_t* indices_strides,
constant int64_t* input_strides,
device int64_t* work_pooling_dim_indices,
int32_t dims,
// NOTE: This is templated per number of dimensions so that the compiler can
// unroll the loop, giving better performance.
template <int32_t dims>
PoolOffsets find_pool_offsets_dim_specific(
constant int32_t* output_sizes,
constant int32_t* output_strides,
constant int32_t* indices_strides,
constant int32_t* input_strides,
int32_t pooling_dim_indices[3],
int32_t leading_dims,
bool return_indices,
uint tid) {
int64_t output_idx = static_cast<int64_t>(tid);
auto output_idx = static_cast<int32_t>(tid);
PoolOffsets offsets;
for (int64_t dim = dims - 1; dim >= 0; dim--) {
int64_t dim_idx = output_idx % (output_sizes[dim]);
for (auto dim = dims - 1; dim >= 0; dim--) {
auto dim_idx = output_idx % (output_sizes[dim]);
offsets.output += output_strides[dim] * dim_idx;
if (return_indices) {
offsets.indices += indices_strides[dim] * dim_idx;
}
if (dim < leading_dims) {
offsets.input_leading += input_strides[dim] * dim_idx;
} else {
// Keep track of pooling dimension indices of the output element, so we
// can use them in the input iteration later on.
if (work_pooling_dim_indices != nullptr) {
work_pooling_dim_indices[dim - leading_dims] = dim_idx;
if (pooling_dim_indices != nullptr) {
pooling_dim_indices[dim - leading_dims] = dim_idx;
}
}
output_idx = output_idx / output_sizes[dim];
@ -130,45 +137,76 @@ PoolOffsets find_pool_offsets(
return offsets;
}
// Runtime dispatcher for `find_pool_offsets_dim_specific`.
//
// The worker is templated on the number of tensor dimensions so that the
// compiler can unroll its loop; this wrapper selects the instantiation that
// matches the runtime value of `dims` and forwards every other argument
// unchanged.
//
// Only `dims == 4` and `dims == 5` have instantiations. For any other value
// this returns a default-constructed `PoolOffsets` and leaves
// `pooling_dim_indices` untouched, so that the function always returns a value
// instead of flowing off the end of a non-void function.
PoolOffsets find_pool_offsets(
    constant int32_t* output_sizes,
    constant int32_t* output_strides,
    constant int32_t* indices_strides,
    constant int32_t* input_strides,
    int32_t pooling_dim_indices[3],
    int32_t dims,
    int32_t leading_dims,
    bool return_indices,
    uint tid) {
  switch (dims) {
    case 5:
      return find_pool_offsets_dim_specific<5>(
          output_sizes,
          output_strides,
          indices_strides,
          input_strides,
          pooling_dim_indices,
          leading_dims,
          return_indices,
          tid);
    case 4:
      return find_pool_offsets_dim_specific<4>(
          output_sizes,
          output_strides,
          indices_strides,
          input_strides,
          pooling_dim_indices,
          leading_dims,
          return_indices,
          tid);
    default:
      // Unsupported dimensionality: return a well-defined value rather than
      // falling off the end of the function.
      return PoolOffsets();
  }
}
// Kernel computes one element of the output per kernel call.
template <typename T>
kernel void max_pool(
constant void* input_ [[buffer(0)]],
device void* output_ [[buffer(1)]],
device void* indices_ [[buffer(2)]],
device int64_t* work_pooling_dim_indices_ [[buffer(3)]],
constant PoolingParams<5>& params [[buffer(4)]],
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant PoolingParams<5>& params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
int32_t pooling_dims = params.pooling_dims;
int32_t dims = params.dims;
constant int64_t* input_sizes = params.input_sizes.data();
constant int64_t* input_strides = params.input_strides.data();
constant int64_t* output_sizes = params.output_sizes.data();
constant int64_t* output_strides = params.output_strides.data();
constant int64_t* indices_strides = params.indices_strides.data();
constant int64_t* kernel_size = params.kernel_size.data();
constant int64_t* stride = params.stride.data();
constant int64_t* padding = params.padding.data();
constant int64_t* dilation = params.dilation.data();
bool return_indices = params.return_indices;
auto pooling_dims = params.pooling_dims;
auto dims = params.dims;
auto input_sizes = params.input_sizes.data();
auto input_strides = params.input_strides.data();
auto output_sizes = params.output_sizes.data();
auto output_strides = params.output_strides.data();
auto indices_strides = params.indices_strides.data();
auto kernel_size = params.kernel_size.data();
auto stride = params.stride.data();
auto padding = params.padding.data();
auto dilation = params.dilation.data();
int32_t leading_dims = dims - pooling_dims;
constant T* input = reinterpret_cast<constant T*>(input_);
device T* output = reinterpret_cast<device T*>(output_);
device int64_t* indices = reinterpret_cast<device int64_t*>(indices_);
auto leading_dims = dims - pooling_dims;
// This buffer keeps track of the pooling dimension indices of this thread's
// element of the output. We need to fill it with the proper values below.
device int64_t* work_pooling_dim_indices =
work_pooling_dim_indices_ + tid * pooling_dims;
int32_t pooling_dim_indices[3];
PoolOffsets offsets = find_pool_offsets(
output_sizes,
output_strides,
indices_strides,
input_strides,
work_pooling_dim_indices,
pooling_dim_indices,
dims,
leading_dims,
return_indices,
tid);
output += offsets.output;
@ -181,11 +219,12 @@ kernel void max_pool(
indices,
input_sizes + leading_dims,
input_strides + leading_dims,
work_pooling_dim_indices,
pooling_dim_indices,
kernel_size,
stride,
padding,
dilation);
dilation,
return_indices);
}
// Finds the element in the grad input which corresponds to the index into the
@ -195,15 +234,15 @@ void max_pool_backward_impl(
device AtomicType_t<T>* grad_input,
T grad_output_element,
int32_t input_index,
constant int64_t* grad_input_sizes,
constant int64_t* grad_input_strides,
constant int32_t* grad_input_sizes,
constant int32_t* grad_input_strides,
int32_t grad_input_leading_offset,
int32_t pooling_dims) {
int32_t size_prod = 1;
int32_t pool_offset = 0;
for (int32_t dim = pooling_dims - 1; dim >= 0; dim--) {
int32_t next_size_prod = grad_input_sizes[dim] * size_prod;
for (auto dim = pooling_dims - 1; dim >= 0; dim--) {
auto next_size_prod = grad_input_sizes[dim] * size_prod;
pool_offset +=
grad_input_strides[dim] * ((input_index % next_size_prod) / size_prod);
size_prod *= grad_input_sizes[dim];
@ -221,15 +260,15 @@ kernel void max_pool_backward(
constant int64_t* indices [[buffer(2)]],
constant PoolingBackwardParams<5>& params [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
int32_t pooling_dims = params.pooling_dims;
int32_t dims = params.dims;
constant int64_t* grad_input_sizes = params.grad_input_sizes.data();
constant int64_t* grad_input_strides = params.grad_input_strides.data();
constant int64_t* grad_output_sizes = params.grad_output_sizes.data();
constant int64_t* grad_output_strides = params.grad_output_strides.data();
constant int64_t* indices_strides = params.indices_strides.data();
auto pooling_dims = params.pooling_dims;
auto dims = params.dims;
auto grad_input_sizes = params.grad_input_sizes.data();
auto grad_input_strides = params.grad_input_strides.data();
auto grad_output_sizes = params.grad_output_sizes.data();
auto grad_output_strides = params.grad_output_strides.data();
auto indices_strides = params.indices_strides.data();
int32_t leading_dims = dims - pooling_dims;
auto leading_dims = dims - pooling_dims;
PoolOffsets offsets = find_pool_offsets(
grad_output_sizes,
@ -239,6 +278,7 @@ kernel void max_pool_backward(
nullptr,
dims,
leading_dims,
/*return_indices=*/true,
tid);
max_pool_backward_impl<T>(
@ -253,11 +293,10 @@ kernel void max_pool_backward(
#define REGISTER_MAX_POOL_OP(DTYPE) \
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
constant void* input_ [[buffer(0)]], \
device void* output_ [[buffer(1)]], \
device void* indices_ [[buffer(2)]], \
device int64_t* work_pooling_dim_indices_ [[buffer(3)]], \
constant PoolingParams<5>& params [[buffer(4)]], \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant PoolingParams<5>& params [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);
#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \

View File

@ -252,22 +252,20 @@ static void pool2d_template(const Tensor& input,
}
}
static std::vector<int64_t> copy_and_maybe_expand(IntArrayRef a, int32_t pooling_dims) {
std::vector<int64_t> b;
if (a.size() == 1) {
b.assign(pooling_dims, a[0]);
} else {
b.assign(a.data(), a.data() + pooling_dims);
static std::vector<int32_t> copy_and_maybe_expand(IntArrayRef a, int32_t pooling_dims) {
std::vector<int32_t> b(pooling_dims);
for (const auto dim : c10::irange(pooling_dims)) {
b[dim] = safe_downcast<int32_t, int64_t>(a[a.size() == 1 ? 0 : dim]);
}
return b;
}
using PoolSizes = std::tuple<int32_t,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>;
std::vector<int32_t>,
std::vector<int32_t>,
std::vector<int32_t>,
std::vector<int32_t>>;
static PoolSizes process_pool_sizes(const Tensor& input,
IntArrayRef kernel_size,
@ -368,7 +366,7 @@ static PoolSizes process_pool_sizes(const Tensor& input,
}
static void max_pool_with_indices_out_mps_template(const Tensor& output,
const Tensor& indices,
const std::optional<Tensor>& indices_opt,
const Tensor& input,
IntArrayRef _kernel_size,
IntArrayRef _stride,
@ -379,10 +377,14 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,
const std::string& op_name) {
auto [dims, output_size, kernel_size, stride, padding, dilation] =
process_pool_sizes(input, _kernel_size, _stride, _padding, _dilation, ceil_mode, pooling_dims, op_name);
const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt));
const bool return_indices = indices.defined();
const auto memory_format = input.suggest_memory_format();
output.resize_(output_size, memory_format);
if (return_indices) {
indices.resize_(output_size, memory_format);
}
auto iter = TensorIteratorConfig().add_output(output).resize_outputs(false).check_all_same_dtype(false).build();
@ -395,33 +397,33 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,
params.dims = dims;
params.pooling_dims = pooling_dims;
memcpy(params.input_sizes.data(), input.sizes().data(), dims * sizeof(int64_t));
memcpy(params.input_strides.data(), input.strides().data(), dims * sizeof(int64_t));
memcpy(params.output_strides.data(), output.strides().data(), dims * sizeof(int64_t));
memcpy(params.output_sizes.data(), output.sizes().data(), dims * sizeof(int64_t));
memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t));
memcpy(params.indices_sizes.data(), indices.sizes().data(), dims * sizeof(int64_t));
memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int64_t));
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int64_t));
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int64_t));
memcpy(params.dilation.data(), dilation.data(), pooling_dims * sizeof(int64_t));
params.return_indices = return_indices;
for (const auto dim : c10::irange(dims)) {
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
if (return_indices) {
params.indices_sizes[dim] = safe_downcast<int32_t, int64_t>(indices.size(dim));
params.indices_strides[dim] = safe_downcast<int32_t, int64_t>(indices.stride(dim));
}
}
memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));
memcpy(params.dilation.data(), dilation.data(), pooling_dims * sizeof(int32_t));
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
auto maxPoolPSO = lib.getPipelineStateForFunc("max_pool_" + scalarToMetalTypeString(input));
// Each thread needs to keep track of the indices into the pooling
// dimensions for the element of the output that it calculates. In other
// words, if the thread calculates `output[N, C, d, h, w]` for a 3D pool,
// the kernel needs to keep track of the indices `[d, h, w]`. So we create
// a device-side buffer for the threads to store these indices.
id<MTLBuffer> work_pooling_dim_indices = [[device newBufferWithLength:numThreads * pooling_dims * sizeof(int64_t)
options:0] autorelease];
getMPSProfiler().beginProfileKernel(maxPoolPSO, op_name, {input});
[computeEncoder setComputePipelineState:maxPoolPSO];
mtl_setArgs(computeEncoder, input, output, indices, work_pooling_dim_indices, params);
mtl_setArgs(
computeEncoder, input, output, return_indices ? std::optional<Tensor>(indices) : std::nullopt, params);
mtl_dispatch1DJob(computeEncoder, maxPoolPSO, numThreads);
getMPSProfiler().endProfileKernel(maxPoolPSO);
@ -454,11 +456,14 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
params.dims = dims;
params.pooling_dims = pooling_dims;
memcpy(params.grad_input_sizes.data(), grad_input.sizes().data(), dims * sizeof(int64_t));
memcpy(params.grad_input_strides.data(), grad_input.strides().data(), dims * sizeof(int64_t));
memcpy(params.grad_output_strides.data(), grad_output.strides().data(), dims * sizeof(int64_t));
memcpy(params.grad_output_sizes.data(), grad_output.sizes().data(), dims * sizeof(int64_t));
memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t));
for (const auto dim : c10::irange(dims)) {
params.grad_input_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_input.size(dim));
params.grad_input_strides[dim] = safe_downcast<int32_t, int64_t>(grad_input.stride(dim));
params.grad_output_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_output.size(dim));
params.grad_output_strides[dim] = safe_downcast<int32_t, int64_t>(grad_output.stride(dim));
params.indices_strides[dim] = safe_downcast<int32_t, int64_t>(indices.stride(dim));
}
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {