Add hardshrink op to Metal backend (#82224)

Summary:
Add a Metal implementation of the hardshrink op to the PyTorch codebase.

Heavily based on the existing Hardswish implementation.
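For reference, hardshrink(x, lambda) passes x through unchanged when |x| > lambda and returns 0 otherwise. A minimal scalar sketch of that definition (illustrative only, not part of this diff):

    // Illustrative reference only; the actual op operates on tensors/textures.
    inline float hardshrink_ref(float x, float lambda = 0.5f) {
      return (x > lambda || x < -lambda) ? x : 0.0f;
    }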

Test Plan:
Add two unit tests that compare the results of the Metal-based hardshrink function to those of the CPU implementation. Both tests pass:

{F755783049}

Differential Revision: D38152454

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82224
Approved by: https://github.com/SS-JIA
Author: David Tsenter
Date: 2022-07-27 19:48:15 +00:00
Committed by: PyTorch MergeBot
Parent: 3a4343afc5
Commit: 1231194dd3
6 changed files with 203 additions and 0 deletions

View File

@@ -421,6 +421,35 @@ kernel void hardswish(texture2d_array<half, access::read> in_arr[[texture(0), fu
}
}
constant bool hardshrink_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool hardshrink_is_tex = !hardshrink_is_arr;
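// Pick the texture-array vs. single-texture kernel variant at pipeline
// specialization time (batch > 1 or more than 4 channels needs an array),
// mirroring the hardswish kernel above.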
kernel void hardshrink(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(hardshrink_is_arr)]],
texture2d<half, access::read> in_tex[[texture(0), function_constant(hardshrink_is_tex)]],
texture2d_array<half, access::write> out_arr[[texture(1), function_constant(hardshrink_is_arr)]],
texture2d<half, access::write> out_tex[[texture(1), function_constant(hardshrink_is_tex)]],
ushort3 gid[[thread_position_in_grid]]) {
const ushort oH = ushort_arg_2;
const ushort oW = ushort_arg_3;
const half lambda = (half)float_arg_0;
if (gid.x >= oW || gid.y >= oH) {
return;
}
ushort2 gid_ = gid.xy;
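// Hardshrink: pass x through when |x| > lambda, otherwise write 0.
// (1 - mask1) selects x > lambda and (1 - mask2) selects x < -lambda; for
// lambda >= 0 the two regions are disjoint, so their sum keeps exactly the
// pass-through values.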
if (hardshrink_is_arr) {
half4 value = in_arr.read(gid_, gid.z);
half4 mask1 = half4(value <= lambda);
half4 mask2 = half4(value >= -lambda);
half4 outval = (1 - mask1)*value + (1 - mask2)*value;
out_arr.write(outval, gid_, gid.z);
} else {
half4 value = in_tex.read(gid_);
half4 mask1 = half4(value <= lambda);
half4 mask2 = half4(value >= -lambda);
half4 outval = (1 - mask1)*value + (1 - mask2)*value;
out_tex.write(outval, gid_);
}
}
constant bool leaky_relu_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool leaky_relu_is_tex = !leaky_relu_is_arr;
kernel void leaky_relu(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(leaky_relu_is_arr)]],

View File

@@ -42,6 +42,8 @@ bool test_sigmoid();
bool test_hardsigmoid();
bool test_hardswish_();
bool test_hardswish();
bool test_hardshrink_();
bool test_hardshrink();
bool test_leaky_relu_();
bool test_leaky_relu();
bool test_upsampling_nearest2d_vec();

View File

@@ -25,6 +25,44 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
}
return diff.abs().max().item<float>() < (0.01 + 2e-2 * maxValue);
}
bool checkHardShrink(const at::Tensor& ref, const at::Tensor& out, const float clamp_thresh) {
float* ref_ptr = ref.data_ptr<float>();
float* out_ptr = out.data_ptr<float>();
float ref_max = ref.abs().max().item<float>();
float out_max = out.abs().max().item<float>();
float max_val = std::fmax(ref_max, out_max);
float kTolerance = 1e-2;
float abs_clamp_thresh = std::abs(clamp_thresh);
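// Per-element check, three cases: near the shrink threshold the output may
// be either zero or close to the reference; strictly inside the band it must
// be exactly zero; elsewhere it must agree within kTolerance relative to the
// largest observed magnitude.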
for (int i = 0; i < ref.numel(); ++i) {
float ref_val = ref_ptr[i];
float out_val = out_ptr[i];
float abs_diff = std::abs(ref_val - out_val);
// For values near the clamp threshold, results may be ambiguous.
float distance_from_thresh = std::abs(std::abs(ref_val) - abs_clamp_thresh);
if (distance_from_thresh < kTolerance * abs_clamp_thresh) {
if (out_val != 0.0f) {
if (abs_diff >= kTolerance * max_val) {
return false;
}
}
}
else if (std::abs(ref_val) < std::abs(abs_clamp_thresh)) {
if (out_val != 0.0f) {
return false;
}
}
else if (abs_diff >= kTolerance * max_val) {
return false;
}
}
return true;
}
bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
return checkRtol(a - b, {a, b}) && a.strides().vec() == b.strides().vec();
}
@@ -274,6 +312,44 @@ bool test_hardswish() {
});
}
bool test_hardshrink_() {
__block std::vector<int64_t> size{3, 3, 44, 44};
bool result = true;
for (const auto lambd_value : {0.42, 1.0, 4.2, 13.7}) {
bool b = TEST(size, __PRETTY_FUNCTION__, ^bool {
auto X =
(at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) - 0.5) * 20;
auto X2 = X.metal();
auto Y1 = X.hardshrink(lambd_value);
auto Y2 = X2.hardshrink(lambd_value).cpu();
return checkHardShrink(Y1, Y2, lambd_value);
});
if (!b) {
result = false;
}
}
return result;
}
bool test_hardshrink() {
__block std::vector<int64_t> size{3, 3, 44, 44};
bool result = true;
for (const auto lambd_value : {0.42, 1.0, 4.2, 13.7}) {
bool b = TEST(size, __PRETTY_FUNCTION__, ^bool {
auto X =
(at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) - 0.5) * 20;
auto X2 = X.metal();
auto Y1 = at::hardshrink(X, lambd_value);
auto Y2 = at::hardshrink(X2, lambd_value).cpu();
return checkHardShrink(Y1, Y2, lambd_value);
});
if (!b) {
result = false;
}
}
return result;
}
bool test_leaky_relu_() {
__block std::vector<int64_t> size{3, 3, 44, 44};
return TEST(size, __PRETTY_FUNCTION__, ^bool {

View File

@@ -70,6 +70,8 @@
REG_TEST("test_hardsigmoid", test_hardsigmoid);
REG_TEST("test_hardswish_", test_hardswish_);
REG_TEST("test_hardswish", test_hardswish);
REG_TEST("test_hardshrink_", test_hardshrink_);
REG_TEST("test_hardshrink", test_hardshrink);
REG_TEST("test_leaky_relu_", test_leaky_relu_);
REG_TEST("test_leaky_relu", test_leaky_relu);
REG_TEST("test_upsampling_nearest2d_vec", test_upsampling_nearest2d_vec);

View File

@@ -0,0 +1,93 @@
#include <ATen/Tensor.h>
#import <ATen/native/metal/MetalCommandBuffer.h>
#import <ATen/native/metal/MetalContext.h>
#import <ATen/native/metal/MetalTensorImpl.h>
#import <ATen/native/metal/MetalTensorImplStorage.h>
#import <ATen/native/metal/MetalTensorUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
#include <torch/library.h>
namespace at {
namespace native {
namespace metal {
using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
Tensor& hardshrink_(Tensor& input, const at::Scalar& lambda=0.5) {
float l = lambda.toFloat();
MPSImage* X = imageFromTensor(input);
MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
IntArrayRef outputSize = input.sizes();
std::vector<int64_t> imageSize = computeImageSize(outputSize);
MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);
id<MTLComputeCommandEncoder> encoder =
[commandBuffer.buffer computeCommandEncoder];
id<MTLComputePipelineState> state =
[[MetalContext sharedInstance] specializedPipelineState:"hardshrink"
Constants:@[
@(X.numberOfImages),
@(X.featureChannels),
@(X.height),
@(X.width),
@(l)
]];
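// The constants above are consumed by the shader as ushort_arg_0..3
// (N, C, H, W) and float_arg_0 (lambda).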
[encoder setComputePipelineState:state];
[encoder setTexture:[X texture] atIndex:0];
[encoder setTexture:[Y texture] atIndex:1];
const auto& launchParams =
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
[encoder endEncoding];
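// Swap the freshly written image into the input tensor's storage so the op
// behaves in-place at the tensor level.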
MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
implStorage.texture()->setImage(Y);
return input;
}
Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) {
float l = lambda.toFloat();
MPSImage* X = imageFromTensor(input);
IntArrayRef outputSize = input.sizes();
MetalTensorImplStorage mt{outputSize.vec()};
MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
MPSImage* Y = mt.texture()->image();
id<MTLComputeCommandEncoder> encoder =
[commandBuffer.buffer computeCommandEncoder];
id<MTLComputePipelineState> state =
[[MetalContext sharedInstance] specializedPipelineState:"hardshrink"
Constants:@[
@(X.numberOfImages),
@(X.featureChannels),
@(X.height),
@(X.width),
@(l)
]];
[encoder setComputePipelineState:state];
[encoder setTexture:[X texture] atIndex:0];
[encoder setTexture:[Y texture] atIndex:1];
const auto& launchParams =
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
[encoder endEncoding];
auto output = makeTensor(std::move(mt), input.options());
return output;
}
TORCH_LIBRARY_IMPL(aten, Metal, m) {
m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink_"), TORCH_FN(hardshrink_));
m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink));
};
} // namespace metal
} // namespace native
} // namespace at
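With the kernels registered under TORCH_LIBRARY_IMPL(aten, Metal, ...) above, a caller reaches them through the regular ATen dispatcher. A hedged sketch (not part of this diff), mirroring the unit tests:

    #include <ATen/ATen.h>

    // Sketch only: assumes a build with the Metal backend enabled and a
    // Metal-capable device.
    at::Tensor hardshrink_via_metal(const at::Tensor& x_cpu, double lambda) {
      at::Tensor x_metal = x_cpu.metal();           // copy the tensor to the Metal backend
      return at::hardshrink(x_metal, lambda).cpu(); // dispatches to the kernel registered above
    }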

View File

@@ -59,6 +59,7 @@ METAL_SOURCE_LIST = [
"aten/src/ATen/native/metal/ops/MetalConvolution.mm",
"aten/src/ATen/native/metal/ops/MetalCopy.mm",
"aten/src/ATen/native/metal/ops/MetalHardswish.mm",
"aten/src/ATen/native/metal/ops/MetalHardshrink.mm",
"aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm",
"aten/src/ATen/native/metal/ops/MetalNeurons.mm",
"aten/src/ATen/native/metal/ops/MetalPadding.mm",