Leaky relu in metal shader (#78544)

Summary:
Heavily referenced how Hardswish was implemented.

This is a great intro task to get a taste of how a torch method is implemented in shader and tested.

Test Plan:
Compared in metal shader metal version and cpu version result in tests.

https://pxl.cl/251kT

Reviewed By: SS-JIA

Differential Revision: D36732187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78544
Approved by: https://github.com/SS-JIA
This commit is contained in:
Matt Guo
2022-06-02 18:13:51 +00:00
committed by PyTorch MergeBot
parent 849b08f14b
commit 9c8eb2cf1b
5 changed files with 148 additions and 0 deletions

View File

@ -421,6 +421,33 @@ kernel void hardswish(texture2d_array<half, access::read> in_arr[[texture(0), fu
}
}
constant bool leaky_relu_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool leaky_relu_is_tex = !leaky_relu_is_arr;
kernel void leaky_relu(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(leaky_relu_is_arr)]],
texture2d<half, access::read> in_tex[[texture(0), function_constant(leaky_relu_is_tex)]],
texture2d_array<half, access::write> out_arr[[texture(1), function_constant(leaky_relu_is_arr)]],
texture2d<half, access::write> out_tex[[texture(1), function_constant(leaky_relu_is_tex)]],
ushort3 gid[[thread_position_in_grid]]) {
const ushort oH = ushort_arg_2;
const ushort oW = ushort_arg_3;
const half negative_slope = (half)float_arg_0;
if (gid.x >= oW || gid.y >= oH) {
return;
}
ushort2 gid_ = gid.xy;
if (leaky_relu_is_arr) {
half4 value = in_arr.read(gid_, gid.z);
half4 is_negative = half4(value < 0.0);
half4 outval = is_negative*value*negative_slope + (1-is_negative)*value;
out_arr.write(outval, gid_, gid.z);
} else {
half4 value = in_tex.read(gid_);
half4 is_negative = half4(value < 0.0);
half4 outval = is_negative*value*negative_slope + (1-is_negative)*value;
out_tex.write(outval, gid_);
}
}
constant bool out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool out_is_tex = !out_is_arr;
constant bool in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);

View File

@ -42,6 +42,8 @@ bool test_sigmoid();
bool test_hardsigmoid();
bool test_hardswish_();
bool test_hardswish();
bool test_leaky_relu_();
bool test_leaky_relu();
bool test_upsampling_nearest2d_vec();
bool test_upsampling_nearest2d_vec2();
bool test_adaptive_avg_pool2d();

View File

@ -274,6 +274,30 @@ bool test_hardswish() {
});
}
bool test_leaky_relu_() {
__block std::vector<int64_t> size{3, 3, 44, 44};
return TEST(size, __PRETTY_FUNCTION__, ^bool {
auto X =
at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) * 12 - 6;
auto X2 = X.metal();
auto Y1 = at::leaky_relu_(X, -0.0125);
auto Y2 = at::leaky_relu_(X2, -0.0125).cpu();
return almostEqual(Y1, Y2);
});
}
bool test_leaky_relu() {
__block std::vector<int64_t> size{1, 3, 44, 44};
return TEST(size, __PRETTY_FUNCTION__, ^bool {
auto X =
at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) * 12 - 6;
auto X2 = X.metal();
auto Y1 = at::leaky_relu(X, 0.025);
auto Y2 = at::leaky_relu(X2, 0.025).cpu();
return almostEqual(Y1, Y2);
});
}
bool test_addmm() {
bool result = true;
for (int i = 0; i < ITER_COUNT; ++i) {

View File

@ -70,6 +70,8 @@
REG_TEST("test_hardsigmoid", test_hardsigmoid);
REG_TEST("test_hardswish_", test_hardswish_);
REG_TEST("test_hardswish", test_hardswish);
REG_TEST("test_leaky_relu_", test_leaky_relu_);
REG_TEST("test_leaky_relu", test_leaky_relu);
REG_TEST("test_upsampling_nearest2d_vec", test_upsampling_nearest2d_vec);
REG_TEST("test_upsampling_nearest2d_vec2", test_upsampling_nearest2d_vec2);
REG_TEST("test_adaptive_avg_pool2d", test_adaptive_avg_pool2d);

View File

@ -0,0 +1,93 @@
#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>
namespace at {
namespace native {
namespace metal {
using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
Tensor& leaky_relu_(Tensor& input, const Scalar& negative_slope_val) {
MPSImage* X = imageFromTensor(input);
MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
IntArrayRef outputSize = input.sizes();
std::vector<int64_t> imageSize = computeImageSize(outputSize);
float negative_slope = negative_slope_val.toFloat();
MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);
id<MTLComputeCommandEncoder> encoder =
[commandBuffer.buffer computeCommandEncoder];
id<MTLComputePipelineState> state =
[[MetalContext sharedInstance] specializedPipelineState:"leaky_relu"
Constants:@[
@(X.numberOfImages),
@(X.featureChannels),
@(X.height),
@(X.width),
@(negative_slope)
]];
[encoder setComputePipelineState:state];
[encoder setTexture:[X texture] atIndex:0];
[encoder setTexture:[Y texture] atIndex:1];
const auto& launchParams =
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
[encoder endEncoding];
MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
implStorage.texture()->setImage(Y);
return input;
}
Tensor leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) {
MPSImage* X = imageFromTensor(input);
IntArrayRef outputSize = input.sizes();
MetalTensorImplStorage mt{outputSize.vec()};
MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
float negative_slope = negative_slope_val.toFloat();
MPSImage* Y = mt.texture()->image();
id<MTLComputeCommandEncoder> encoder =
[commandBuffer.buffer computeCommandEncoder];
id<MTLComputePipelineState> state =
[[MetalContext sharedInstance] specializedPipelineState:"leaky_relu"
Constants:@[
@(X.numberOfImages),
@(X.featureChannels),
@(X.height),
@(X.width),
@(negative_slope)
]];
[encoder setComputePipelineState:state];
[encoder setTexture:[X texture] atIndex:0];
[encoder setTexture:[Y texture] atIndex:1];
const auto& launchParams =
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
[encoder endEncoding];
auto output = makeTensor(std::move(mt), input.options());
return output;
}
TORCH_LIBRARY_IMPL(aten, Metal, m) {
m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu_"), TORCH_FN(leaky_relu_));
m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu"), TORCH_FN(leaky_relu));
};
}
}
}