[iOS GPU][Kernel] Implement channel split in Metal shaders (#56074)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56074

To run ShuffleNet we need to support at::chunk on GPU. The current implementation only splits the tensor into two parts along the channel dimension. We'll come back and fully implement it in Metal shaders.
ghstack-source-id: 127522377

Test Plan:
```
2021-03-26 01:37:07.693411-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 2, 2, 2]
2021-03-26 01:37:07.693499-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 2, 2, 2]
2021-03-26 01:37:07.693544-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 4, 2, 2]
2021-03-26 01:37:07.695415-0700 PyTorchPlayground[2279:235793] [bool test_chunk()],[1 4 2 2 ],[SUCCEED]
2021-03-26 01:37:07.695862-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 4, 2, 2]
2021-03-26 01:37:07.695927-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 5, 2, 2]
2021-03-26 01:37:07.695971-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 9, 2, 2]
2021-03-26 01:37:07.698215-0700 PyTorchPlayground[2279:235793] [bool test_chunk2()],[1 9 2 2 ],[SUCCEED]
2021-03-26 01:37:07.699086-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 8, 2, 2]
2021-03-26 01:37:07.699154-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 16, 2, 2]
2021-03-26 01:37:07.699197-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 8, 2, 2]
2021-03-26 01:37:07.700842-0700 PyTorchPlayground[2279:235793] [bool test_chunk3()],[1 16 2 2 ],[SUCCEED]
```
- Sandcastle
- CircleCI

Reviewed By: SS-JIA

Differential Revision: D27357096

fbshipit-source-id: fd3908ad2c26466e4f714d531790be2f1ae24153
This commit is contained in:
Tao Xu
2021-04-28 00:50:47 -07:00
committed by Facebook GitHub Bot
parent 0df574017d
commit 8134806e23
4 changed files with 222 additions and 0 deletions

View File

@ -733,6 +733,101 @@ kernel void transpose(texture2d_array<half, access::read>in_arr[[texture(0),func
}
}
// Function constants used to specialize split_channels at pipeline-creation
// time. ushort_arg_0/1/2 carry the channel counts of the input and of the two
// outputs (supplied host-side via specializedPipelineState:Constants:).
// More than 4 channels means the image is backed by a texture2d_array
// (4 channels per RGBA slice); otherwise a plain texture2d is bound instead,
// and the unused binding of each pair is compiled out.
constant bool split_channels_in_is_arr = (ushort_arg_0 > 4);
constant bool split_channels_in_is_tex = !split_channels_in_is_arr;
constant bool split_channels_out1_is_arr = (ushort_arg_1 > 4);
constant bool split_channels_out1_is_tex = !split_channels_out1_is_arr;
constant bool split_channels_out2_is_arr = (ushort_arg_2 > 4);
constant bool split_channels_out2_is_tex = !(split_channels_out2_is_arr);
// A naive channel split: routes the first ushort_arg_1 channels of the input
// to out1 and the remaining ushort_arg_2 channels to out2. Channels are packed
// 4-per-RGBA-slice, so when ushort_arg_1 is not a multiple of 4 each output
// slice must be stitched together from two adjacent input slices. One thread
// handles one (x, y) position of one input slice (gid.z).
kernel void split_channels(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(split_channels_in_is_arr)]],
texture2d<half, access::read> in_tex[[texture(0), function_constant(split_channels_in_is_tex)]],
texture2d_array<half, access::write> out1_arr[[texture(1),function_constant(split_channels_out1_is_arr)]],
texture2d<half, access::write> out1_tex[[texture(1),function_constant(split_channels_out1_is_tex)]],
texture2d_array<half, access::write> out2_arr[[texture(2), function_constant(split_channels_out2_is_arr)]],
texture2d<half, access::write> out2_tex[[texture(2),function_constant(split_channels_out2_is_tex)]],
ushort3 gid[[thread_position_in_grid]]) {
// Width/height come from whichever input binding is active.
ushort W,H;
if(split_channels_in_is_arr) {
W = in_arr.get_width();
H = in_arr.get_height();
} else {
W = in_tex.get_width();
H = in_tex.get_height();
}
// The dispatch grid may overshoot the image extent; drop excess threads.
if(gid.x >= W || gid.y >= H){
return;
}
// C1: channel count of the first output.
// s1: number of RGBA slices out1 occupies, i.e. ceil(C1 / 4).
// c_offset: position of the split inside out1's last slice (0 = 4-aligned).
const ushort C1 = ushort_arg_1;
const ushort s1 = divRoundUp(C1, 4);
const ushort c_offset = C1 % 4;
half4 tmp1(0.0, 0.0, 0.0, 0.0);
half4 tmp2(0.0, 0.0, 0.0, 0.0);
// Current input slice plus the next one, which supplies the high lanes when
// the split is unaligned. NOTE(review): gid.z+1 can address one slice past
// the end of in_arr when gid.z is the last slice — presumably those lanes
// are never consumed in that case; confirm for the final-slice, unaligned
// (c_offset != 0) path.
half4 in41 = split_channels_in_is_arr ? in_arr.read(gid.xy, gid.z) : in_tex.read(gid.xy);
half4 in42 = split_channels_in_is_arr ? in_arr.read(gid.xy, gid.z+1) : half4(0,0,0,0);
if(gid.z < s1 - 1) {
// Slices strictly before the boundary belong entirely to out1. (out1 must
// be an array here: a plain texture2d implies C1 <= 4, hence s1 == 1, and
// this branch is unreachable for it.)
if(split_channels_out1_is_arr) {
out1_arr.write(in41, gid.xy, gid.z);
}
}
else if(gid.z == s1 - 1) {
// Boundary slice: holds out1's last channels and, when unaligned, the
// first channels of out2.
if(c_offset == 0){
// 4-aligned split: the whole slice belongs to out1.
if(split_channels_out1_is_arr) {
out1_arr.write(in41, gid.xy, gid.z);
} else {
out1_tex.write(in41, gid.xy);
}
return;
} else if(c_offset == 1) {
// tmp1 takes the low c_offset lanes (end of out1); tmp2 takes the rest,
// topped up with lanes borrowed from the next input slice (start of out2).
tmp1.x = in41.x;
tmp2.xyz = in41.yzw;
tmp2.w = in42.x;
} else if (c_offset == 2) {
tmp1.xy = in41.xy;
tmp2.xy = in41.zw;
tmp2.zw = in42.xy;
} else {
tmp1.xyz = in41.xyz;
tmp2.x = in41.w;
tmp2.yzw = in42.xyz;
}
if(split_channels_out1_is_arr) {
out1_arr.write(tmp1, gid.xy, gid.z);
} else {
out1_tex.write(tmp1, gid.xy);
}
// tmp2 built on the boundary slice is always out2's first slice.
if(split_channels_out2_is_arr) {
out2_arr.write(tmp2, gid.xy, 0);
} else {
out2_tex.write(tmp2, gid.xy);
}
}
else {
// Slices past the boundary: everything here belongs to out2.
if (c_offset == 0) {
// Aligned split: copy the slice through, re-based to out2's index space.
if(split_channels_out2_is_arr) {
out2_arr.write(in41, gid.xy, gid.z - s1);
} else {
out2_tex.write(in41, gid.xy);
}
return;
}
else if (c_offset == 1 ){
// Unaligned: shift the channel window down by c_offset lanes so out2's
// slices stay densely packed, pulling the high lanes from the next
// input slice.
tmp2.xyz = in41.yzw;
tmp2.w = in42.x;
} else if (c_offset == 2){
tmp2.xy = in41.zw;
tmp2.zw = in42.xy;
} else {
tmp2.x = in41.w;
tmp2.yzw = in42.xyz;
}
if(split_channels_out2_is_arr) {
out2_arr.write(tmp2, gid.xy, gid.z - s1 + 1);
} else {
out2_tex.write(tmp2, gid.xy);
}
}
}
)PT_METAL_SHADERS";
#endif /* MPSCNNShaders_h */

View File

@ -49,5 +49,8 @@ bool test_reshape();
bool test_mean_dim();
bool test_mean_dim2();
bool test_mean_dim3();
bool test_chunk();
bool test_chunk2();
bool test_chunk3();
#endif

View File

@ -842,3 +842,55 @@ bool test_mean_dim3() {
return almostEqual(Y1, Y2);
});
}
bool test_chunk() {
// Split a 4-channel tensor in half (2 + 2) on CPU and on Metal, then
// compare the two results chunk by chunk. 4 channels is the 4-aligned case.
__block std::vector<int64_t> size{1, 4, 2, 2};
return TEST(size, __PRETTY_FUNCTION__, ^bool {
  auto cpuInput = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto cpuChunks = at::chunk(cpuInput, 2, 1);
  auto metalChunks = at::chunk(cpuInput.metal(), 2, 1);
  auto expected1 = cpuChunks[0].contiguous();
  auto expected2 = cpuChunks[1].contiguous();
  auto actual1 = metalChunks[0].cpu();
  auto actual2 = metalChunks[1].cpu();
  const bool firstMatches = checkRtol(expected1 - actual1, {expected1, actual1});
  const bool secondMatches = checkRtol(expected2 - actual2, {expected2, actual2});
  return firstMatches && secondMatches;
});
}
bool test_chunk2() {
// Odd channel count (9 -> 5 + 4): the split point is not 4-aligned, so the
// shader's slice-stitching path is exercised. Compare CPU vs Metal chunks.
__block std::vector<int64_t> size{1, 9, 2, 2};
return TEST(size, __PRETTY_FUNCTION__, ^bool {
  auto input = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto reference = at::chunk(input, 2, 1);
  auto gpuResult = at::chunk(input.metal(), 2, 1);
  auto ref1 = reference[0].contiguous();
  auto ref2 = reference[1].contiguous();
  auto out1 = gpuResult[0].cpu();
  auto out2 = gpuResult[1].cpu();
  const bool ok1 = checkRtol(ref1 - out1, {ref1, out1});
  const bool ok2 = checkRtol(ref2 - out2, {ref2, out2});
  return ok1 && ok2;
});
}
bool test_chunk3() {
// Larger aligned case (16 -> 8 + 8): both halves span multiple RGBA slices.
// Compare each Metal-produced chunk against its CPU counterpart.
__block std::vector<int64_t> size{1, 16, 2, 2};
return TEST(size, __PRETTY_FUNCTION__, ^bool {
  auto hostTensor = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
  auto hostChunks = at::chunk(hostTensor, 2, 1);
  auto deviceChunks = at::chunk(hostTensor.metal(), 2, 1);
  auto want1 = hostChunks[0].contiguous();
  auto want2 = hostChunks[1].contiguous();
  auto got1 = deviceChunks[0].cpu();
  auto got2 = deviceChunks[1].cpu();
  const bool halfOneOK = checkRtol(want1 - got1, {want1, got1});
  const bool halfTwoOK = checkRtol(want2 - got2, {want2, got2});
  return halfOneOK && halfTwoOK;
});
}

View File

@ -0,0 +1,72 @@
#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>
namespace at {
namespace native {
namespace metal {
// Split the input tensor into two on channel dimension
// TODO: [T87567124] Fully implement chunk in Metal shader
// Splits `input` into two tensors along the channel dimension (NCHW dim 1),
// mirroring at::chunk(input, 2, 1): the first chunk gets ceil(C / 2)
// channels, the second gets the remainder.
// Restrictions of this initial Metal implementation: exactly 2 chunks,
// dim == 1, 4-D input with batch size 1 and at least 2 channels.
// TODO: [T87567124] Fully implement chunk in Metal shader
std::vector<Tensor> chunk(const Tensor& input, int64_t chunks, int64_t dim) {
TORCH_CHECK(chunks == 2 && dim == 1);
TORCH_CHECK(input.dim() == 4);
TORCH_CHECK(input.size(0) == 1);
int64_t dim_size = input.size(dim);
// A single channel cannot be split into two non-empty chunks; at::chunk
// would return one tensor here, which the two-output shader cannot express,
// so reject it explicitly instead of producing a garbage second chunk.
TORCH_CHECK(dim_size > 1);
// First chunk holds ceil(dim_size / chunks) channels, as at::chunk does;
// the second holds the (possibly smaller) remainder.
int64_t split_size = (dim_size + chunks - 1) / chunks;
int64_t last_split_size = dim_size - split_size;
MPSImage* X = imageFromTensor(input);
MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
auto outputSize1 = {input.size(0), split_size, input.size(2), input.size(3)};
auto outputSize2 = {input.size(0), last_split_size, input.size(2), input.size(3)};
MetalTensorImplStorage mt1(outputSize1);
MetalTensorImplStorage mt2(outputSize2);
// Temporary texture storage is scoped to this command buffer, keeping the
// intermediates on-GPU (presumably MPSTemporaryImage-backed — see the
// "Found a temporary image" logs in the test plan).
mt1.texture()->allocateTemporaryTextureStorage(outputSize1, commandBuffer);
mt2.texture()->allocateTemporaryTextureStorage(outputSize2, commandBuffer);
MPSImage* Y1 = mt1.texture()->image();
MPSImage* Y2 = mt2.texture()->image();
// Specialize the shader on the channel counts of the input and each output;
// these drive the function constants selecting texture2d vs texture2d_array
// bindings in split_channels.
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
specializedPipelineState:@"split_channels"
Constants:@[
@(X.featureChannels),
@(Y1.featureChannels),
@(Y2.featureChannels)]];
id<MTLComputeCommandEncoder> encoder =
[commandBuffer.buffer computeCommandEncoder];
[encoder setComputePipelineState:state];
[encoder setTexture:[X texture] atIndex:0];
[encoder setTexture:[Y1 texture] atIndex:1];
[encoder setTexture:[Y2 texture] atIndex:2];
// One thread per (x, y, slice) of the input image.
const auto& launchParams =
mpscnn::spatialPointwiseKernelLaunchParams(state, X);
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
[encoder endEncoding];
[X markRead];
[Y1 markRead];
[Y2 markRead];
auto output1 = makeTensor(std::move(mt1), input.options());
auto output2 = makeTensor(std::move(mt2), input.options());
return {output1, output2};
}
// Register the Metal (GPU) implementation of aten::chunk with the dispatcher.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
m.impl("chunk", TORCH_FN(chunk));
}
}
}
}