[iOS GPU][Kernel] Implement channel split in Metal shaders (#56074)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56074

To run ShuffleNet we need to support at::chunk on the GPU. The current implementation only splits the tensor into two along the channel dimension. We'll come back and fully implement it in Metal shaders.

ghstack-source-id: 127522377

Test Plan:
```
2021-03-26 01:37:07.693411-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 2, 2, 2]
2021-03-26 01:37:07.693499-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 2, 2, 2]
2021-03-26 01:37:07.693544-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 4, 2, 2]
2021-03-26 01:37:07.695415-0700 PyTorchPlayground[2279:235793] [bool test_chunk()],[1 4 2 2 ],[SUCCEED]
2021-03-26 01:37:07.695862-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 4, 2, 2]
2021-03-26 01:37:07.695927-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 5, 2, 2]
2021-03-26 01:37:07.695971-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 9, 2, 2]
2021-03-26 01:37:07.698215-0700 PyTorchPlayground[2279:235793] [bool test_chunk2()],[1 9 2 2 ],[SUCCEED]
2021-03-26 01:37:07.699086-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 8, 2, 2]
2021-03-26 01:37:07.699154-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 16, 2, 2]
2021-03-26 01:37:07.699197-0700 PyTorchPlayground[2279:235793] [MPSImageWrapper] Found a temporary image: [1, 8, 2, 2]
2021-03-26 01:37:07.700842-0700 PyTorchPlayground[2279:235793] [bool test_chunk3()],[1 16 2 2 ],[SUCCEED]
```
- Sandcastle
- CircleCI

Reviewed By: SS-JIA

Differential Revision: D27357096

fbshipit-source-id: fd3908ad2c26466e4f714d531790be2f1ae24153
Committed by: Facebook GitHub Bot
Parent: 0df574017d
Commit: 8134806e23
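For context on the summary above: at::chunk(X, 2, 1) splits the channel dimension with ceil division, so the 4-, 9-, and 16-channel test inputs below produce chunks of 2+2, 5+4, and 8+8 channels, which is exactly what the [1, 5, 2, 2] / [1, 4, 2, 2] temporary images in the test log show. A minimal standalone sketch of that sizing rule (the helper name chunk2_sizes is hypothetical, not an ATen API):

```cpp
#include <cstdint>
#include <iostream>
#include <utility>

// Hypothetical helper mirroring at::chunk's sizing on one dimension when
// chunks == 2: the first chunk gets ceil(C / 2), the last gets the remainder.
std::pair<int64_t, int64_t> chunk2_sizes(int64_t channels) {
  int64_t split_size = (channels + 1) / 2;    // ceil(C / 2)
  int64_t last_size = channels - split_size;  // whatever is left
  return {split_size, last_size};
}

int main() {
  for (int64_t c : {4, 9, 16}) {
    auto [first, second] = chunk2_sizes(c);
    std::cout << c << " channels -> " << first << " + " << second << "\n";
    // 4 -> 2 + 2, 9 -> 5 + 4, 16 -> 8 + 8 (the shapes seen in the test plan)
  }
}
```

The same arithmetic shows up in MetalChunk.mm below as split_size and last_split_size.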
@@ -733,6 +733,101 @@ kernel void transpose(texture2d_array<half, access::read>in_arr[[texture(0),func
  }
}

constant bool split_channels_in_is_arr = (ushort_arg_0 > 4);
constant bool split_channels_in_is_tex = !split_channels_in_is_arr;
constant bool split_channels_out1_is_arr = (ushort_arg_1 > 4);
constant bool split_channels_out1_is_tex = !split_channels_out1_is_arr;
constant bool split_channels_out2_is_arr = (ushort_arg_2 > 4);
constant bool split_channels_out2_is_tex = !split_channels_out2_is_arr;
// A naive implementation to split the input texture into two on the channel dimension
kernel void split_channels(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(split_channels_in_is_arr)]],
                           texture2d<half, access::read> in_tex[[texture(0), function_constant(split_channels_in_is_tex)]],
                           texture2d_array<half, access::write> out1_arr[[texture(1), function_constant(split_channels_out1_is_arr)]],
                           texture2d<half, access::write> out1_tex[[texture(1), function_constant(split_channels_out1_is_tex)]],
                           texture2d_array<half, access::write> out2_arr[[texture(2), function_constant(split_channels_out2_is_arr)]],
                           texture2d<half, access::write> out2_tex[[texture(2), function_constant(split_channels_out2_is_tex)]],
                           ushort3 gid[[thread_position_in_grid]]) {
  ushort W, H;
  if (split_channels_in_is_arr) {
    W = in_arr.get_width();
    H = in_arr.get_height();
  } else {
    W = in_tex.get_width();
    H = in_tex.get_height();
  }
  if (gid.x >= W || gid.y >= H) {
    return;
  }
  const ushort C1 = ushort_arg_1;
  const ushort s1 = divRoundUp(C1, 4);
  const ushort c_offset = C1 % 4;
  half4 tmp1(0.0, 0.0, 0.0, 0.0);
  half4 tmp2(0.0, 0.0, 0.0, 0.0);
  half4 in41 = split_channels_in_is_arr ? in_arr.read(gid.xy, gid.z) : in_tex.read(gid.xy);
  half4 in42 = split_channels_in_is_arr ? in_arr.read(gid.xy, gid.z + 1) : half4(0, 0, 0, 0);
  if (gid.z < s1 - 1) {
    if (split_channels_out1_is_arr) {
      out1_arr.write(in41, gid.xy, gid.z);
    }
  } else if (gid.z == s1 - 1) {
    if (c_offset == 0) {
      if (split_channels_out1_is_arr) {
        out1_arr.write(in41, gid.xy, gid.z);
      } else {
        out1_tex.write(in41, gid.xy);
      }
      return;
    } else if (c_offset == 1) {
      tmp1.x = in41.x;
      tmp2.xyz = in41.yzw;
      tmp2.w = in42.x;
    } else if (c_offset == 2) {
      tmp1.xy = in41.xy;
      tmp2.xy = in41.zw;
      tmp2.zw = in42.xy;
    } else {
      tmp1.xyz = in41.xyz;
      tmp2.x = in41.w;
      tmp2.yzw = in42.xyz;
    }
    if (split_channels_out1_is_arr) {
      out1_arr.write(tmp1, gid.xy, gid.z);
    } else {
      out1_tex.write(tmp1, gid.xy);
    }
    if (split_channels_out2_is_arr) {
      out2_arr.write(tmp2, gid.xy, 0);
    } else {
      out2_tex.write(tmp2, gid.xy);
    }
  } else {
    if (c_offset == 0) {
      if (split_channels_out2_is_arr) {
        out2_arr.write(in41, gid.xy, gid.z - s1);
      } else {
        out2_tex.write(in41, gid.xy);
      }
      return;
    } else if (c_offset == 1) {
      tmp2.xyz = in41.yzw;
      tmp2.w = in42.x;
    } else if (c_offset == 2) {
      tmp2.xy = in41.zw;
      tmp2.zw = in42.xy;
    } else {
      tmp2.x = in41.w;
      tmp2.yzw = in42.xyz;
    }
    if (split_channels_out2_is_arr) {
      out2_arr.write(tmp2, gid.xy, gid.z - s1 + 1);
    } else {
      out2_tex.write(tmp2, gid.xy);
    }
  }
}
)PT_METAL_SHADERS";

#endif /* MPSCNNShaders_h */
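The "naive implementation" comment in the shader above hides the subtle part: channels are packed four to a texel (one RGBA slice per group of four), so whenever C1 % 4 != 0 every packed texel of the second output straddles two adjacent input slices; that is exactly what the in41/in42 reads and the c_offset swizzle cases handle. Below is a minimal CPU sketch of that slice/lane remapping for a single pixel; the names (Slice, pack) are illustrative only and not part of the PyTorch code:

```cpp
#include <array>
#include <cstdio>
#include <vector>

// One packed texel for a single pixel: four channel values (RGBA lanes).
using Slice = std::array<float, 4>;

// Pack a flat list of channel values into 4-wide slices, zero-padding the tail.
std::vector<Slice> pack(const std::vector<float>& channels) {
  std::vector<Slice> slices((channels.size() + 3) / 4, Slice{0, 0, 0, 0});
  for (size_t c = 0; c < channels.size(); ++c) {
    slices[c / 4][c % 4] = channels[c];
  }
  return slices;
}

int main() {
  // 9 input channels split into C1 = 5 and C2 = 4 (the test_chunk2 shape).
  std::vector<float> in;
  for (int c = 0; c < 9; ++c) in.push_back(float(c));
  auto packed = pack(in);

  const int C = 9, C1 = 5;
  // Output 2, packed slice z2, lane l reads input channel C1 + 4*z2 + l,
  // i.e. input slice (C1 + 4*z2 + l) / 4 at lane (C1 + 4*z2 + l) % 4.
  // Because C1 % 4 == 1 here, each output texel gathers from two input slices.
  for (int z2 = 0; z2 < (C - C1 + 3) / 4; ++z2) {
    Slice out{0, 0, 0, 0};
    for (int l = 0; l < 4; ++l) {
      int src = C1 + 4 * z2 + l;
      if (src < C) out[l] = packed[src / 4][src % 4];
    }
    std::printf("out2 slice %d: %g %g %g %g\n", z2, out[0], out[1], out[2], out[3]);
    // Expected: 5 6 7 8 -- channels 5..8 gathered from input slices 1 and 2.
  }
}
```

The shader performs the same remapping per pixel, but with texture reads and component swizzles instead of index arithmetic.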
@@ -49,5 +49,8 @@ bool test_reshape();
bool test_mean_dim();
bool test_mean_dim2();
bool test_mean_dim3();
bool test_chunk();
bool test_chunk2();
bool test_chunk3();

#endif
@@ -842,3 +842,55 @@ bool test_mean_dim3() {
    return almostEqual(Y1, Y2);
  });
}

bool test_chunk() {
  __block std::vector<int64_t> size{1, 4, 2, 2};
  return TEST(size, __PRETTY_FUNCTION__, ^bool {
    auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto Y1 = at::chunk(X1, 2, 1);
    auto X2 = X1.metal();
    auto Y2 = at::chunk(X2, 2, 1);
    auto A1 = Y1[0].contiguous();
    auto A2 = Y1[1].contiguous();
    auto Z1 = Y2[0].cpu();
    auto Z2 = Y2[1].cpu();
    bool b1 = checkRtol(A1 - Z1, {A1, Z1});
    bool b2 = checkRtol(A2 - Z2, {A2, Z2});
    return b1 && b2;
  });
}

bool test_chunk2() {
  __block std::vector<int64_t> size{1, 9, 2, 2};
  return TEST(size, __PRETTY_FUNCTION__, ^bool {
    auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto Y1 = at::chunk(X1, 2, 1);
    auto X2 = X1.metal();
    auto Y2 = at::chunk(X2, 2, 1);
    auto A1 = Y1[0].contiguous();
    auto A2 = Y1[1].contiguous();
    auto Z1 = Y2[0].cpu();
    auto Z2 = Y2[1].cpu();
    bool b1 = checkRtol(A1 - Z1, {A1, Z1});
    bool b2 = checkRtol(A2 - Z2, {A2, Z2});
    return b1 && b2;
  });
}

bool test_chunk3() {
  __block std::vector<int64_t> size{1, 16, 2, 2};
  return TEST(size, __PRETTY_FUNCTION__, ^bool {
    auto X1 = at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat));
    auto Y1 = at::chunk(X1, 2, 1);
    auto X2 = X1.metal();
    auto Y2 = at::chunk(X2, 2, 1);
    auto A1 = Y1[0].contiguous();
    auto A2 = Y1[1].contiguous();
    auto Z1 = Y2[0].cpu();
    auto Z2 = Y2[1].cpu();
    bool b1 = checkRtol(A1 - Z1, {A1, Z1});
    bool b2 = checkRtol(A2 - Z2, {A2, Z2});
    return b1 && b2;
  });
}
aten/src/ATen/native/metal/ops/MetalChunk.mm (new file, 72 lines)
@@ -0,0 +1,72 @@
#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>

namespace at {
namespace native {
namespace metal {

// Split the input tensor into two on the channel dimension
// TODO: [T87567124] Fully implement chunk in Metal shader
std::vector<Tensor> chunk(const Tensor& input, int64_t chunks, int64_t dim) {
  TORCH_CHECK(chunks == 2 && dim == 1);
  TORCH_CHECK(input.dim() == 4);
  TORCH_CHECK(input.size(0) == 1);
  int64_t dim_size = input.size(dim);
  int64_t split_size = (dim_size + chunks - 1) / chunks;
  int64_t num_splits = 1;
  if (split_size != 0) {
    num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
  }
  std::vector<Tensor> splits(num_splits);
  int64_t last_split_size = split_size - (split_size * num_splits - dim_size);
  MPSImage* X = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBufferFromTensor(input);
  auto outputSize1 = {input.size(0), split_size, input.size(2), input.size(3)};
  auto outputSize2 = {input.size(0), last_split_size, input.size(2), input.size(3)};
  MetalTensorImplStorage mt1(outputSize1);
  MetalTensorImplStorage mt2(outputSize2);
  mt1.texture()->allocateTemporaryTextureStorage(outputSize1, commandBuffer);
  mt2.texture()->allocateTemporaryTextureStorage(outputSize2, commandBuffer);
  MPSImage* Y1 = mt1.texture()->image();
  MPSImage* Y2 = mt2.texture()->image();
  NSString* kernelFunc = @"split_channels";
  id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
      specializedPipelineState:kernelFunc
                     Constants:@[
                       @(X.featureChannels),
                       @(Y1.featureChannels),
                       @(Y2.featureChannels)
                     ]];
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setTexture:[X texture] atIndex:0];
  [encoder setTexture:[Y1 texture] atIndex:1];
  [encoder setTexture:[Y2 texture] atIndex:2];
  const auto& launchParams =
      mpscnn::spatialPointwiseKernelLaunchParams(state, X);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  [X markRead];
  [Y1 markRead];
  [Y2 markRead];
  auto output1 = makeTensor(std::move(mt1), input.options());
  auto output2 = makeTensor(std::move(mt2), input.options());
  return {output1, output2};
}

TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl("chunk", TORCH_FN(chunk));
};

}
}
}