mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Improve performance of max_pool3d (#157875)
To check how the changes from this PR affect performance, I wrote a benchmark script, `max_pool_mps/perf.py` (at commit 55ef32a127).
Before this PR, I get this:
```
===================
max_pool3d
===================
0: 0.013105 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2}
1: 0.038003 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5}
2: 0.212963 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5}
3: 1.224645 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5}
4: 7.317867 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1}
5: 34.679233 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20}
6: 34.626383 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1}
7: 44.835892 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40}
8: 0.083579 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2}
9: 0.936575 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2}
10: 5.329883 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2}
11: 11.713617 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2}
12: 25.450454 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2}
13: 0.058375 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2}
14: 3.757558 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2}
15: 33.451588 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2}
```
After this PR, I get this:
```
===================
max_pool3d
===================
0: 0.007202 ms, max_pool3d, (3, 2, 2, 2), {'kernel_size': 2}
1: 0.018596 ms, max_pool3d, (3, 10, 10, 10), {'kernel_size': 5}
2: 0.130717 ms, max_pool3d, (3, 100, 100, 100), {'kernel_size': 5}
3: 0.966795 ms, max_pool3d, (3, 200, 200, 200), {'kernel_size': 5}
4: 4.095804 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 4, 'padding': 1}
5: 12.833446 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20}
6: 12.859346 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1}
7: 14.080529 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 50, 'padding': 20, 'dilation': 1, 'stride': 40}
8: 0.029283 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2}
9: 0.175700 ms, max_pool3d, (10, 10, 30, 30, 30), {'kernel_size': 2}
10: 0.742750 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2}
11: 1.939596 ms, max_pool3d, (10, 10, 70, 70, 70), {'kernel_size': 2}
12: 4.074821 ms, max_pool3d, (10, 10, 90, 90, 90), {'kernel_size': 2}
13: 0.028425 ms, max_pool3d, (10, 10, 10, 10, 10), {'kernel_size': 2, 'dilation': 2}
14: 0.384375 ms, max_pool3d, (10, 10, 50, 50, 50), {'kernel_size': 2, 'dilation': 2}
15: 2.623346 ms, max_pool3d, (10, 10, 100, 100, 100), {'kernel_size': 2, 'dilation': 2}
```
Every case is improved: speedups range from roughly 1.3x (case 3) to roughly 12.8x (case 15).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157875
Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
66c9bc5062
commit
1b88da1cac
@ -5,29 +5,30 @@
|
||||
// maximum allowed pooling dimensions is N-2, because the input may have up to 2
|
||||
// leading dimensions that are not pooled. To support up to 3-D pooling, N=5 is
|
||||
// the default.
|
||||
template <unsigned N = 5>
|
||||
template <unsigned N = 5, typename idx_type_t = int32_t>
|
||||
struct PoolingParams {
|
||||
int32_t dims;
|
||||
int32_t pooling_dims;
|
||||
::c10::metal::array<int64_t, N> input_sizes;
|
||||
::c10::metal::array<int64_t, N> input_strides;
|
||||
::c10::metal::array<int64_t, N> output_sizes;
|
||||
::c10::metal::array<int64_t, N> output_strides;
|
||||
::c10::metal::array<int64_t, N> indices_sizes;
|
||||
::c10::metal::array<int64_t, N> indices_strides;
|
||||
::c10::metal::array<int64_t, N - 2> kernel_size;
|
||||
::c10::metal::array<int64_t, N - 2> stride;
|
||||
::c10::metal::array<int64_t, N - 2> padding;
|
||||
::c10::metal::array<int64_t, N - 2> dilation;
|
||||
::c10::metal::array<idx_type_t, N> input_sizes;
|
||||
::c10::metal::array<idx_type_t, N> input_strides;
|
||||
::c10::metal::array<idx_type_t, N> output_sizes;
|
||||
::c10::metal::array<idx_type_t, N> output_strides;
|
||||
::c10::metal::array<idx_type_t, N> indices_sizes;
|
||||
::c10::metal::array<idx_type_t, N> indices_strides;
|
||||
::c10::metal::array<idx_type_t, N - 2> kernel_size;
|
||||
::c10::metal::array<idx_type_t, N - 2> stride;
|
||||
::c10::metal::array<idx_type_t, N - 2> padding;
|
||||
::c10::metal::array<idx_type_t, N - 2> dilation;
|
||||
bool return_indices;
|
||||
};
|
||||
|
||||
template <unsigned N = 5>
|
||||
template <unsigned N = 5, typename idx_type_t = int32_t>
|
||||
struct PoolingBackwardParams {
|
||||
int32_t dims;
|
||||
int32_t pooling_dims;
|
||||
::c10::metal::array<int64_t, N> grad_input_sizes;
|
||||
::c10::metal::array<int64_t, N> grad_input_strides;
|
||||
::c10::metal::array<int64_t, N> grad_output_sizes;
|
||||
::c10::metal::array<int64_t, N> grad_output_strides;
|
||||
::c10::metal::array<int64_t, N> indices_strides;
|
||||
::c10::metal::array<idx_type_t, N> grad_input_sizes;
|
||||
::c10::metal::array<idx_type_t, N> grad_input_strides;
|
||||
::c10::metal::array<idx_type_t, N> grad_output_sizes;
|
||||
::c10::metal::array<idx_type_t, N> grad_output_strides;
|
||||
::c10::metal::array<idx_type_t, N> indices_strides;
|
||||
};
|
||||
|
@ -6,6 +6,28 @@
|
||||
using namespace metal;
|
||||
using namespace c10::metal;
|
||||
|
||||
// Pair of loop limits for walking one dimension of the input.
// `max_pool_3d_input_iter` iterates with `i < end`, so `end` is an exclusive
// upper limit.
template <typename T>
|
||||
struct IterBounds {
|
||||
// First index to visit.
T start;
|
||||
// Exclusive upper limit of the iteration.
T end;
|
||||
};
|
||||
|
||||
// Computes the iteration limits along pooling dimension `dim` for the pooling
// window that belongs to the output element at `pooling_dim_indices`.
//
// All array arguments are indexed by `dim` directly, so callers must pass
// pointers that already start at the first pooling dimension.
//
// Returns `{start, end}` where `start` is the first window tap that is not
// negative and `end` is the exclusive upper limit, clamped to the input size.
// Callers step from `start` by `dilation[dim]`.
template <int32_t dim>
|
||||
IterBounds<int32_t> get_input_iter_bounds(
|
||||
constant int32_t* input_sizes,
|
||||
thread int32_t (&pooling_dim_indices)[3],
|
||||
constant int32_t* kernel_size,
|
||||
constant int32_t* stride,
|
||||
constant int32_t* padding,
|
||||
constant int32_t* dilation) {
|
||||
auto d = dilation[dim];
|
||||
// Position of the window's first tap; negative when it falls in the padding.
auto start = stride[dim] * pooling_dim_indices[dim] - padding[dim];
|
||||
// The window spans kernel_size taps spaced `d` apart; never run past the input.
auto end = min(start + kernel_size[dim] * d, input_sizes[dim]);
|
||||
// d * ceil(-start / d): the smallest whole number of dilation steps that
// moves a negative `start` to a non-negative index on the same tap grid.
auto start_correction = d * ((-start - 1 + d) / d);
|
||||
// Only apply the correction when the window actually begins in the padding.
start += start < 0 ? start_correction : 0;
|
||||
return IterBounds<int32_t>{start, end};
|
||||
}
|
||||
|
||||
// Iterates through all the input elements that this kernel needs to
|
||||
// apply max to. Specialized for 3 pooling dimensions.
|
||||
// TODO: Support any number of pooling dims
|
||||
@ -14,82 +36,62 @@ void max_pool_3d_input_iter(
|
||||
constant T* input,
|
||||
device T* output,
|
||||
device int64_t* indices,
|
||||
constant int64_t* input_sizes,
|
||||
constant int64_t* input_strides,
|
||||
device int64_t* work_pooling_dim_indices,
|
||||
constant int64_t* kernel_size,
|
||||
constant int64_t* stride,
|
||||
constant int64_t* padding,
|
||||
constant int64_t* dilation) {
|
||||
int64_t o0 = work_pooling_dim_indices[0];
|
||||
int64_t o1 = work_pooling_dim_indices[1];
|
||||
int64_t o2 = work_pooling_dim_indices[2];
|
||||
constant int32_t* input_sizes,
|
||||
constant int32_t* input_strides,
|
||||
thread int32_t (&pooling_dim_indices)[3],
|
||||
constant int32_t* kernel_size,
|
||||
constant int32_t* stride,
|
||||
constant int32_t* padding,
|
||||
constant int32_t* dilation,
|
||||
bool return_indices) {
|
||||
auto bounds0 = get_input_iter_bounds<0>(
|
||||
input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
|
||||
auto bounds1 = get_input_iter_bounds<1>(
|
||||
input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
|
||||
auto bounds2 = get_input_iter_bounds<2>(
|
||||
input_sizes, pooling_dim_indices, kernel_size, stride, padding, dilation);
|
||||
|
||||
int64_t k0 = kernel_size[0];
|
||||
int64_t k1 = kernel_size[1];
|
||||
int64_t k2 = kernel_size[2];
|
||||
auto d0 = dilation[0];
|
||||
auto d1 = dilation[1];
|
||||
auto d2 = dilation[2];
|
||||
|
||||
int64_t s0 = stride[0];
|
||||
int64_t s1 = stride[1];
|
||||
int64_t s2 = stride[2];
|
||||
T max_value = input
|
||||
[input_strides[0] * bounds0.start + input_strides[1] * bounds1.start +
|
||||
input_strides[2] * bounds2.start];
|
||||
auto size12 = input_sizes[1] * input_sizes[2];
|
||||
auto max_index =
|
||||
bounds0.start * size12 + bounds1.start * input_sizes[2] + bounds2.start;
|
||||
|
||||
int64_t d0 = dilation[0];
|
||||
int64_t d1 = dilation[1];
|
||||
int64_t d2 = dilation[2];
|
||||
for (auto i0 = bounds0.start; i0 < bounds0.end; i0 += d0) {
|
||||
auto offset0 = input_strides[0] * i0;
|
||||
|
||||
T max_value = 0;
|
||||
int64_t max_index = -1;
|
||||
for (auto i1 = bounds1.start; i1 < bounds1.end; i1 += d1) {
|
||||
auto offset1 = input_strides[1] * i1;
|
||||
|
||||
int64_t size12 = input_sizes[1] * input_sizes[2];
|
||||
for (auto i2 = bounds2.start; i2 < bounds2.end; i2 += d2) {
|
||||
auto offset2 = input_strides[2] * i2;
|
||||
auto input_value = input[offset0 + offset1 + offset2];
|
||||
bool is_greater = input_value > max_value;
|
||||
|
||||
for (int64_t i0 = (s0 * o0) - padding[0];
|
||||
i0 < (s0 * o0 - padding[0] + k0 * d0) && i0 < input_sizes[0];
|
||||
i0 += d0) {
|
||||
if (i0 < 0) {
|
||||
continue;
|
||||
}
|
||||
int64_t offset0 = input_strides[0] * i0;
|
||||
max_value = is_greater ? input_value : max_value;
|
||||
|
||||
for (int64_t i1 = (s1 * o1) - padding[1];
|
||||
i1 < (s1 * o1 - padding[1] + k1 * d1) && i1 < input_sizes[1];
|
||||
i1 += d1) {
|
||||
if (i1 < 0) {
|
||||
continue;
|
||||
}
|
||||
int64_t offset1 = input_strides[1] * i1;
|
||||
|
||||
for (int64_t i2 = (s2 * o2) - padding[2];
|
||||
i2 < (s2 * o2 - padding[2] + k2 * d2) && i2 < input_sizes[2];
|
||||
i2 += d2) {
|
||||
if (i2 < 0) {
|
||||
continue;
|
||||
}
|
||||
int64_t offset2 = input_strides[2] * i2;
|
||||
|
||||
const T input_value = input[offset0 + offset1 + offset2];
|
||||
int64_t input_index = i0 * size12 + i1 * input_sizes[2] + i2;
|
||||
|
||||
T new_max_value = (max_index == -1 || input_value > max_value)
|
||||
? input_value
|
||||
: max_value;
|
||||
int64_t new_max_index = (max_index == -1 || input_value > max_value)
|
||||
? input_index
|
||||
: max_index;
|
||||
|
||||
max_value = new_max_value;
|
||||
max_index = new_max_index;
|
||||
if (return_indices) {
|
||||
auto input_index = i0 * size12 + i1 * input_sizes[2] + i2;
|
||||
max_index = is_greater ? input_index : max_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
*output = max_value;
|
||||
if (return_indices) {
|
||||
*indices = max_index;
|
||||
}
|
||||
}
|
||||
|
||||
struct PoolOffsets {
|
||||
int64_t output;
|
||||
int64_t indices;
|
||||
int64_t input_leading;
|
||||
int32_t output;
|
||||
int32_t indices;
|
||||
int32_t input_leading;
|
||||
|
||||
PoolOffsets() : output(0), indices(0), input_leading(0) {}
|
||||
};
|
||||
@ -98,30 +100,35 @@ struct PoolOffsets {
|
||||
// calculate, `output[N, C, d, h, w]`. Also, find the offset of the input for
|
||||
// the leading dim indices, `input[N, C]`. Optionally, keep track of the output
|
||||
// pooling dimension indices, `[d, h, w]`.
|
||||
PoolOffsets find_pool_offsets(
|
||||
constant int64_t* output_sizes,
|
||||
constant int64_t* output_strides,
|
||||
constant int64_t* indices_strides,
|
||||
constant int64_t* input_strides,
|
||||
device int64_t* work_pooling_dim_indices,
|
||||
int32_t dims,
|
||||
// NOTE: This is templated per number of dimensions so that the compiler can
|
||||
// unroll the loop, giving better performance.
|
||||
template <int32_t dims>
|
||||
PoolOffsets find_pool_offsets_dim_specific(
|
||||
constant int32_t* output_sizes,
|
||||
constant int32_t* output_strides,
|
||||
constant int32_t* indices_strides,
|
||||
constant int32_t* input_strides,
|
||||
int32_t pooling_dim_indices[3],
|
||||
int32_t leading_dims,
|
||||
bool return_indices,
|
||||
uint tid) {
|
||||
int64_t output_idx = static_cast<int64_t>(tid);
|
||||
auto output_idx = static_cast<int32_t>(tid);
|
||||
PoolOffsets offsets;
|
||||
|
||||
for (int64_t dim = dims - 1; dim >= 0; dim--) {
|
||||
int64_t dim_idx = output_idx % (output_sizes[dim]);
|
||||
for (auto dim = dims - 1; dim >= 0; dim--) {
|
||||
auto dim_idx = output_idx % (output_sizes[dim]);
|
||||
offsets.output += output_strides[dim] * dim_idx;
|
||||
if (return_indices) {
|
||||
offsets.indices += indices_strides[dim] * dim_idx;
|
||||
}
|
||||
|
||||
if (dim < leading_dims) {
|
||||
offsets.input_leading += input_strides[dim] * dim_idx;
|
||||
} else {
|
||||
// Keep track of pooling dimension indices of the output element, so we
|
||||
// can use them in the input iteration later on.
|
||||
if (work_pooling_dim_indices != nullptr) {
|
||||
work_pooling_dim_indices[dim - leading_dims] = dim_idx;
|
||||
if (pooling_dim_indices != nullptr) {
|
||||
pooling_dim_indices[dim - leading_dims] = dim_idx;
|
||||
}
|
||||
}
|
||||
output_idx = output_idx / output_sizes[dim];
|
||||
@ -130,45 +137,76 @@ PoolOffsets find_pool_offsets(
|
||||
return offsets;
|
||||
}
|
||||
|
||||
// Runtime-to-compile-time dispatcher for `find_pool_offsets_dim_specific`.
//
// Forwards every argument unchanged to the instantiation whose template
// parameter equals `dims`, so the per-dimension loop inside it has a
// compile-time trip count. Only 4 and 5 total dimensions are instantiated.
//
// For any other value of `dims`, default-constructed (all-zero) offsets are
// returned and `pooling_dim_indices` is left untouched. This guarantees the
// function never reaches its end without returning a value.
PoolOffsets find_pool_offsets(
    constant int32_t* output_sizes,
    constant int32_t* output_strides,
    constant int32_t* indices_strides,
    constant int32_t* input_strides,
    int32_t pooling_dim_indices[3],
    int32_t dims,
    int32_t leading_dims,
    bool return_indices,
    uint tid) {
  switch (dims) {
    case 5:
      return find_pool_offsets_dim_specific<5>(
          output_sizes,
          output_strides,
          indices_strides,
          input_strides,
          pooling_dim_indices,
          leading_dims,
          return_indices,
          tid);
    case 4:
      return find_pool_offsets_dim_specific<4>(
          output_sizes,
          output_strides,
          indices_strides,
          input_strides,
          pooling_dim_indices,
          leading_dims,
          return_indices,
          tid);
    default:
      // Unsupported dimensionality: return well-defined zero offsets instead
      // of flowing off the end of a non-void function.
      return PoolOffsets();
  }
}
|
||||
|
||||
// Kernel computes one element of the output per kernel call.
|
||||
template <typename T>
|
||||
kernel void max_pool(
|
||||
constant void* input_ [[buffer(0)]],
|
||||
device void* output_ [[buffer(1)]],
|
||||
device void* indices_ [[buffer(2)]],
|
||||
device int64_t* work_pooling_dim_indices_ [[buffer(3)]],
|
||||
constant PoolingParams<5>& params [[buffer(4)]],
|
||||
constant T* input [[buffer(0)]],
|
||||
device T* output [[buffer(1)]],
|
||||
device int64_t* indices [[buffer(2)]],
|
||||
constant PoolingParams<5>& params [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
int32_t pooling_dims = params.pooling_dims;
|
||||
int32_t dims = params.dims;
|
||||
constant int64_t* input_sizes = params.input_sizes.data();
|
||||
constant int64_t* input_strides = params.input_strides.data();
|
||||
constant int64_t* output_sizes = params.output_sizes.data();
|
||||
constant int64_t* output_strides = params.output_strides.data();
|
||||
constant int64_t* indices_strides = params.indices_strides.data();
|
||||
constant int64_t* kernel_size = params.kernel_size.data();
|
||||
constant int64_t* stride = params.stride.data();
|
||||
constant int64_t* padding = params.padding.data();
|
||||
constant int64_t* dilation = params.dilation.data();
|
||||
bool return_indices = params.return_indices;
|
||||
auto pooling_dims = params.pooling_dims;
|
||||
auto dims = params.dims;
|
||||
auto input_sizes = params.input_sizes.data();
|
||||
auto input_strides = params.input_strides.data();
|
||||
auto output_sizes = params.output_sizes.data();
|
||||
auto output_strides = params.output_strides.data();
|
||||
auto indices_strides = params.indices_strides.data();
|
||||
auto kernel_size = params.kernel_size.data();
|
||||
auto stride = params.stride.data();
|
||||
auto padding = params.padding.data();
|
||||
auto dilation = params.dilation.data();
|
||||
|
||||
int32_t leading_dims = dims - pooling_dims;
|
||||
constant T* input = reinterpret_cast<constant T*>(input_);
|
||||
device T* output = reinterpret_cast<device T*>(output_);
|
||||
device int64_t* indices = reinterpret_cast<device int64_t*>(indices_);
|
||||
auto leading_dims = dims - pooling_dims;
|
||||
|
||||
// This buffer keeps track of the pooling dimension indices of this thread's
|
||||
// element of the output. We need to fill it with the proper values below.
|
||||
device int64_t* work_pooling_dim_indices =
|
||||
work_pooling_dim_indices_ + tid * pooling_dims;
|
||||
int32_t pooling_dim_indices[3];
|
||||
|
||||
PoolOffsets offsets = find_pool_offsets(
|
||||
output_sizes,
|
||||
output_strides,
|
||||
indices_strides,
|
||||
input_strides,
|
||||
work_pooling_dim_indices,
|
||||
pooling_dim_indices,
|
||||
dims,
|
||||
leading_dims,
|
||||
return_indices,
|
||||
tid);
|
||||
|
||||
output += offsets.output;
|
||||
@ -181,11 +219,12 @@ kernel void max_pool(
|
||||
indices,
|
||||
input_sizes + leading_dims,
|
||||
input_strides + leading_dims,
|
||||
work_pooling_dim_indices,
|
||||
pooling_dim_indices,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation);
|
||||
dilation,
|
||||
return_indices);
|
||||
}
|
||||
|
||||
// Finds the element in the grad input which corresponds to the index into the
|
||||
@ -195,15 +234,15 @@ void max_pool_backward_impl(
|
||||
device AtomicType_t<T>* grad_input,
|
||||
T grad_output_element,
|
||||
int32_t input_index,
|
||||
constant int64_t* grad_input_sizes,
|
||||
constant int64_t* grad_input_strides,
|
||||
constant int32_t* grad_input_sizes,
|
||||
constant int32_t* grad_input_strides,
|
||||
int32_t grad_input_leading_offset,
|
||||
int32_t pooling_dims) {
|
||||
int32_t size_prod = 1;
|
||||
int32_t pool_offset = 0;
|
||||
|
||||
for (int32_t dim = pooling_dims - 1; dim >= 0; dim--) {
|
||||
int32_t next_size_prod = grad_input_sizes[dim] * size_prod;
|
||||
for (auto dim = pooling_dims - 1; dim >= 0; dim--) {
|
||||
auto next_size_prod = grad_input_sizes[dim] * size_prod;
|
||||
pool_offset +=
|
||||
grad_input_strides[dim] * ((input_index % next_size_prod) / size_prod);
|
||||
size_prod *= grad_input_sizes[dim];
|
||||
@ -221,15 +260,15 @@ kernel void max_pool_backward(
|
||||
constant int64_t* indices [[buffer(2)]],
|
||||
constant PoolingBackwardParams<5>& params [[buffer(3)]],
|
||||
uint tid [[thread_position_in_grid]]) {
|
||||
int32_t pooling_dims = params.pooling_dims;
|
||||
int32_t dims = params.dims;
|
||||
constant int64_t* grad_input_sizes = params.grad_input_sizes.data();
|
||||
constant int64_t* grad_input_strides = params.grad_input_strides.data();
|
||||
constant int64_t* grad_output_sizes = params.grad_output_sizes.data();
|
||||
constant int64_t* grad_output_strides = params.grad_output_strides.data();
|
||||
constant int64_t* indices_strides = params.indices_strides.data();
|
||||
auto pooling_dims = params.pooling_dims;
|
||||
auto dims = params.dims;
|
||||
auto grad_input_sizes = params.grad_input_sizes.data();
|
||||
auto grad_input_strides = params.grad_input_strides.data();
|
||||
auto grad_output_sizes = params.grad_output_sizes.data();
|
||||
auto grad_output_strides = params.grad_output_strides.data();
|
||||
auto indices_strides = params.indices_strides.data();
|
||||
|
||||
int32_t leading_dims = dims - pooling_dims;
|
||||
auto leading_dims = dims - pooling_dims;
|
||||
|
||||
PoolOffsets offsets = find_pool_offsets(
|
||||
grad_output_sizes,
|
||||
@ -239,6 +278,7 @@ kernel void max_pool_backward(
|
||||
nullptr,
|
||||
dims,
|
||||
leading_dims,
|
||||
/*return_indices=*/true,
|
||||
tid);
|
||||
|
||||
max_pool_backward_impl<T>(
|
||||
@ -253,11 +293,10 @@ kernel void max_pool_backward(
|
||||
|
||||
#define REGISTER_MAX_POOL_OP(DTYPE) \
|
||||
template [[host_name("max_pool_" #DTYPE)]] kernel void max_pool<DTYPE>( \
|
||||
constant void* input_ [[buffer(0)]], \
|
||||
device void* output_ [[buffer(1)]], \
|
||||
device void* indices_ [[buffer(2)]], \
|
||||
device int64_t* work_pooling_dim_indices_ [[buffer(3)]], \
|
||||
constant PoolingParams<5>& params [[buffer(4)]], \
|
||||
constant DTYPE * input [[buffer(0)]], \
|
||||
device DTYPE * output [[buffer(1)]], \
|
||||
device int64_t* indices [[buffer(2)]], \
|
||||
constant PoolingParams<5>& params [[buffer(3)]], \
|
||||
uint tid [[thread_position_in_grid]]);
|
||||
|
||||
#define REGISTER_MAX_POOL_BACKWARD_OP(DTYPE) \
|
||||
|
@ -252,22 +252,20 @@ static void pool2d_template(const Tensor& input,
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int64_t> copy_and_maybe_expand(IntArrayRef a, int32_t pooling_dims) {
|
||||
std::vector<int64_t> b;
|
||||
if (a.size() == 1) {
|
||||
b.assign(pooling_dims, a[0]);
|
||||
} else {
|
||||
b.assign(a.data(), a.data() + pooling_dims);
|
||||
static std::vector<int32_t> copy_and_maybe_expand(IntArrayRef a, int32_t pooling_dims) {
|
||||
std::vector<int32_t> b(pooling_dims);
|
||||
for (const auto dim : c10::irange(pooling_dims)) {
|
||||
b[dim] = safe_downcast<int32_t, int64_t>(a[a.size() == 1 ? 0 : dim]);
|
||||
}
|
||||
return b;
|
||||
}
|
||||
|
||||
using PoolSizes = std::tuple<int32_t,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>>;
|
||||
std::vector<int32_t>,
|
||||
std::vector<int32_t>,
|
||||
std::vector<int32_t>,
|
||||
std::vector<int32_t>>;
|
||||
|
||||
static PoolSizes process_pool_sizes(const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
@ -368,7 +366,7 @@ static PoolSizes process_pool_sizes(const Tensor& input,
|
||||
}
|
||||
|
||||
static void max_pool_with_indices_out_mps_template(const Tensor& output,
|
||||
const Tensor& indices,
|
||||
const std::optional<Tensor>& indices_opt,
|
||||
const Tensor& input,
|
||||
IntArrayRef _kernel_size,
|
||||
IntArrayRef _stride,
|
||||
@ -379,10 +377,14 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,
|
||||
const std::string& op_name) {
|
||||
auto [dims, output_size, kernel_size, stride, padding, dilation] =
|
||||
process_pool_sizes(input, _kernel_size, _stride, _padding, _dilation, ceil_mode, pooling_dims, op_name);
|
||||
const Tensor& indices = *(at::borrow_from_optional_tensor(indices_opt));
|
||||
const bool return_indices = indices.defined();
|
||||
|
||||
const auto memory_format = input.suggest_memory_format();
|
||||
output.resize_(output_size, memory_format);
|
||||
if (return_indices) {
|
||||
indices.resize_(output_size, memory_format);
|
||||
}
|
||||
|
||||
auto iter = TensorIteratorConfig().add_output(output).resize_outputs(false).check_all_same_dtype(false).build();
|
||||
|
||||
@ -395,33 +397,33 @@ static void max_pool_with_indices_out_mps_template(const Tensor& output,
|
||||
|
||||
params.dims = dims;
|
||||
params.pooling_dims = pooling_dims;
|
||||
memcpy(params.input_sizes.data(), input.sizes().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.input_strides.data(), input.strides().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.output_strides.data(), output.strides().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.output_sizes.data(), output.sizes().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.indices_sizes.data(), indices.sizes().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int64_t));
|
||||
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int64_t));
|
||||
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int64_t));
|
||||
memcpy(params.dilation.data(), dilation.data(), pooling_dims * sizeof(int64_t));
|
||||
params.return_indices = return_indices;
|
||||
|
||||
for (const auto dim : c10::irange(dims)) {
|
||||
params.input_sizes[dim] = safe_downcast<int32_t, int64_t>(input.size(dim));
|
||||
params.input_strides[dim] = safe_downcast<int32_t, int64_t>(input.stride(dim));
|
||||
params.output_sizes[dim] = safe_downcast<int32_t, int64_t>(output.size(dim));
|
||||
params.output_strides[dim] = safe_downcast<int32_t, int64_t>(output.stride(dim));
|
||||
if (return_indices) {
|
||||
params.indices_sizes[dim] = safe_downcast<int32_t, int64_t>(indices.size(dim));
|
||||
params.indices_strides[dim] = safe_downcast<int32_t, int64_t>(indices.stride(dim));
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(params.kernel_size.data(), kernel_size.data(), pooling_dims * sizeof(int32_t));
|
||||
memcpy(params.stride.data(), stride.data(), pooling_dims * sizeof(int32_t));
|
||||
memcpy(params.padding.data(), padding.data(), pooling_dims * sizeof(int32_t));
|
||||
memcpy(params.dilation.data(), dilation.data(), pooling_dims * sizeof(int32_t));
|
||||
|
||||
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
|
||||
auto maxPoolPSO = lib.getPipelineStateForFunc("max_pool_" + scalarToMetalTypeString(input));
|
||||
|
||||
// Each thread needs to keep track of the indices into the pooling
|
||||
// dimensions for the element of the output that it calculates. In other
|
||||
// words, if the thread calculates `output[N, C, d, h, w]` for a 3D pool,
|
||||
// the kernel needs to keep track of the indices `[d, h, w]`. So we create
|
||||
// a device-side buffer for the threads to store these indices.
|
||||
id<MTLBuffer> work_pooling_dim_indices = [[device newBufferWithLength:numThreads * pooling_dims * sizeof(int64_t)
|
||||
options:0] autorelease];
|
||||
|
||||
getMPSProfiler().beginProfileKernel(maxPoolPSO, op_name, {input});
|
||||
[computeEncoder setComputePipelineState:maxPoolPSO];
|
||||
mtl_setArgs(computeEncoder, input, output, indices, work_pooling_dim_indices, params);
|
||||
mtl_setArgs(
|
||||
computeEncoder, input, output, return_indices ? std::optional<Tensor>(indices) : std::nullopt, params);
|
||||
|
||||
mtl_dispatch1DJob(computeEncoder, maxPoolPSO, numThreads);
|
||||
getMPSProfiler().endProfileKernel(maxPoolPSO);
|
||||
@ -454,11 +456,14 @@ static void max_pool_with_indices_backward_out_mps_template(Tensor& grad_input,
|
||||
|
||||
params.dims = dims;
|
||||
params.pooling_dims = pooling_dims;
|
||||
memcpy(params.grad_input_sizes.data(), grad_input.sizes().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.grad_input_strides.data(), grad_input.strides().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.grad_output_strides.data(), grad_output.strides().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.grad_output_sizes.data(), grad_output.sizes().data(), dims * sizeof(int64_t));
|
||||
memcpy(params.indices_strides.data(), indices.strides().data(), dims * sizeof(int64_t));
|
||||
|
||||
for (const auto dim : c10::irange(dims)) {
|
||||
params.grad_input_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_input.size(dim));
|
||||
params.grad_input_strides[dim] = safe_downcast<int32_t, int64_t>(grad_input.stride(dim));
|
||||
params.grad_output_sizes[dim] = safe_downcast<int32_t, int64_t>(grad_output.size(dim));
|
||||
params.grad_output_strides[dim] = safe_downcast<int32_t, int64_t>(grad_output.stride(dim));
|
||||
params.indices_strides[dim] = safe_downcast<int32_t, int64_t>(indices.stride(dim));
|
||||
}
|
||||
|
||||
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
|
||||
@autoreleasepool {
|
||||
|
Reference in New Issue
Block a user