mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add hardshrink op to metal backend (#82224)
Summary: Add hardshrink op's metal implementation to pytorch codebase Heavily based on already existing Hardswish implementation Test Plan: Add two unit tests to compare the result of metal-based hardshrink function to that of the CPU implementation. Both tests pass: {F755783049} Differential Revision: D38152454 Pull Request resolved: https://github.com/pytorch/pytorch/pull/82224 Approved by: https://github.com/SS-JIA
This commit is contained in:
committed by
PyTorch MergeBot
parent
3a4343afc5
commit
1231194dd3
@ -421,6 +421,35 @@ kernel void hardswish(texture2d_array<half, access::read> in_arr[[texture(0), fu
|
||||
}
|
||||
}
|
||||
|
||||
constant bool hardshrink_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
|
||||
constant bool hardshrink_is_tex = !hardshrink_is_arr;
|
||||
kernel void hardshrink(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(hardshrink_is_arr)]],
|
||||
texture2d<half, access::read> in_tex[[texture(0), function_constant(hardshrink_is_tex)]],
|
||||
texture2d_array<half, access::write> out_arr[[texture(1), function_constant(hardshrink_is_arr)]],
|
||||
texture2d<half, access::write> out_tex[[texture(1), function_constant(hardshrink_is_tex)]],
|
||||
ushort3 gid[[thread_position_in_grid]]) {
|
||||
const ushort oH = ushort_arg_2;
|
||||
const ushort oW = ushort_arg_3;
|
||||
const half lambda = (half)float_arg_0;
|
||||
if (gid.x >= oW || gid.y >= oH) {
|
||||
return;
|
||||
}
|
||||
ushort2 gid_ = gid.xy;
|
||||
if (hardshrink_is_arr) {
|
||||
half4 value = in_arr.read(gid_, gid.z);
|
||||
half4 mask1 = half4(value <= lambda);
|
||||
half4 mask2 = half4(value >= -lambda);
|
||||
half4 outval = (1 - mask1)*value + (1 - mask2)*value;
|
||||
out_arr.write(outval, gid_, gid.z);
|
||||
} else {
|
||||
half4 value = in_tex.read(gid_);
|
||||
half4 mask1 = half4(value <= lambda);
|
||||
half4 mask2 = half4(value >= -lambda);
|
||||
half4 outval = (1 - mask1)*value + (1 - mask2)*value;
|
||||
out_tex.write(outval, gid_);
|
||||
}
|
||||
}
|
||||
|
||||
constant bool leaky_relu_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
|
||||
constant bool leaky_relu_is_tex = !leaky_relu_is_arr;
|
||||
kernel void leaky_relu(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(leaky_relu_is_arr)]],
|
||||
|
@ -42,6 +42,8 @@ bool test_sigmoid();
|
||||
bool test_hardsigmoid();
|
||||
bool test_hardswish_();
|
||||
bool test_hardswish();
|
||||
bool test_hardshrink_();
|
||||
bool test_hardshrink();
|
||||
bool test_leaky_relu_();
|
||||
bool test_leaky_relu();
|
||||
bool test_upsampling_nearest2d_vec();
|
||||
|
@ -25,6 +25,44 @@ bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
|
||||
}
|
||||
return diff.abs().max().item<float>() < (0.01 + 2e-2 * maxValue);
|
||||
}
|
||||
|
||||
bool checkHardShrink(const at::Tensor& ref, const at::Tensor& out, const float clamp_thresh) {
|
||||
float* ref_ptr = ref.data_ptr<float>();
|
||||
float* out_ptr = out.data_ptr<float>();
|
||||
float ref_max = ref.abs().max().item<float>();
|
||||
float out_max = out.abs().max().item<float>();
|
||||
float max_val = std::fmax(ref_max, out_max);
|
||||
float kTolerance = 1e-2;
|
||||
|
||||
float abs_clamp_thresh = std::abs(clamp_thresh);
|
||||
|
||||
for (int i = 0; i < ref.numel(); ++i) {
|
||||
float ref_val = ref_ptr[i];
|
||||
float out_val = out_ptr[i];
|
||||
|
||||
float abs_diff = std::abs(ref_val - out_val);
|
||||
|
||||
// For values near the clamp threshold, results may be ambiguous.
|
||||
float distance_from_thresh = std::abs(std::abs(ref_val) - abs_clamp_thresh);
|
||||
if (distance_from_thresh < kTolerance * abs_clamp_thresh) {
|
||||
if (out_val != 0.0f) {
|
||||
if (abs_diff >= kTolerance * max_val) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (std::abs(ref_val) < std::abs(abs_clamp_thresh)) {
|
||||
if (out_val != 0.0f) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if (abs_diff >= kTolerance * max_val) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
|
||||
return checkRtol(a - b, {a, b}) && a.strides().vec() == b.strides().vec();
|
||||
}
|
||||
@ -274,6 +312,44 @@ bool test_hardswish() {
|
||||
});
|
||||
}
|
||||
|
||||
bool test_hardshrink_() {
|
||||
__block std::vector<int64_t> size{3, 3, 44, 44};
|
||||
bool result = true;
|
||||
for (const auto lambd_value : {0.42, 1.0, 4.2, 13.7}) {
|
||||
bool b = TEST(size, __PRETTY_FUNCTION__, ^bool {
|
||||
auto X =
|
||||
(at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) - 0.5) * 20;
|
||||
auto X2 = X.metal();
|
||||
auto Y1 = X.hardshrink(lambd_value);
|
||||
auto Y2 = X2.hardshrink(lambd_value).cpu();
|
||||
return checkHardShrink(Y1, Y2, lambd_value);
|
||||
});
|
||||
if (!b) {
|
||||
result = false;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool test_hardshrink() {
|
||||
__block std::vector<int64_t> size{3, 3, 44, 44};
|
||||
bool result = true;
|
||||
for (const auto lambd_value : {0.42, 1.0, 4.2, 13.7}) {
|
||||
bool b = TEST(size, __PRETTY_FUNCTION__, ^bool {
|
||||
auto X =
|
||||
(at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) - 0.5) * 20;
|
||||
auto X2 = X.metal();
|
||||
auto Y1 = at::hardshrink(X, lambd_value);
|
||||
auto Y2 = at::hardshrink(X2, lambd_value).cpu();
|
||||
return checkHardShrink(Y1, Y2, lambd_value);
|
||||
});
|
||||
if (!b) {
|
||||
result = false;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool test_leaky_relu_() {
|
||||
__block std::vector<int64_t> size{3, 3, 44, 44};
|
||||
return TEST(size, __PRETTY_FUNCTION__, ^bool {
|
||||
|
@ -70,6 +70,8 @@
|
||||
REG_TEST("test_hardsigmoid", test_hardsigmoid);
|
||||
REG_TEST("test_hardswish_", test_hardswish_);
|
||||
REG_TEST("test_hardswish", test_hardswish);
|
||||
REG_TEST("test_hardshrink_", test_hardshrink_);
|
||||
REG_TEST("test_hardshrink", test_hardshrink);
|
||||
REG_TEST("test_leaky_relu_", test_leaky_relu_);
|
||||
REG_TEST("test_leaky_relu", test_leaky_relu);
|
||||
REG_TEST("test_upsampling_nearest2d_vec", test_upsampling_nearest2d_vec);
|
||||
|
93
aten/src/ATen/native/metal/ops/MetalHardshrink.mm
Normal file
93
aten/src/ATen/native/metal/ops/MetalHardshrink.mm
Normal file
@ -0,0 +1,93 @@
|
||||
#include <ATen/Tensor.h>
|
||||
#import <ATen/native/metal/MetalCommandBuffer.h>
|
||||
#import <ATen/native/metal/MetalContext.h>
|
||||
#import <ATen/native/metal/MetalTensorImpl.h>
|
||||
#import <ATen/native/metal/MetalTensorImplStorage.h>
|
||||
#import <ATen/native/metal/MetalTensorUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
|
||||
|
||||
Tensor& hardshrink_(Tensor& input, const at::Scalar& lambda=0.5) {
|
||||
float l = lambda.toFloat();
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
|
||||
IntArrayRef outputSize = input.sizes();
|
||||
std::vector<int64_t> imageSize = computeImageSize(outputSize);
|
||||
MPSImage* Y = createTemporaryImage(commandBuffer, imageSize);
|
||||
id<MTLComputeCommandEncoder> encoder =
|
||||
[commandBuffer.buffer computeCommandEncoder];
|
||||
id<MTLComputePipelineState> state =
|
||||
[[MetalContext sharedInstance] specializedPipelineState:"hardshrink"
|
||||
Constants:@[
|
||||
@(X.numberOfImages),
|
||||
@(X.featureChannels),
|
||||
@(X.height),
|
||||
@(X.width),
|
||||
@(l)
|
||||
]];
|
||||
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setTexture:[X texture] atIndex:0];
|
||||
[encoder setTexture:[Y texture] atIndex:1];
|
||||
|
||||
const auto& launchParams =
|
||||
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
|
||||
MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
|
||||
implStorage.texture()->setImage(Y);
|
||||
return input;
|
||||
}
|
||||
|
||||
Tensor hardshrink(const at::Tensor& input, const at::Scalar& lambda=0.5) {
|
||||
float l = lambda.toFloat();
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
IntArrayRef outputSize = input.sizes();
|
||||
MetalTensorImplStorage mt{outputSize.vec()};
|
||||
MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
|
||||
mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
|
||||
MPSImage* Y = mt.texture()->image();
|
||||
id<MTLComputeCommandEncoder> encoder =
|
||||
[commandBuffer.buffer computeCommandEncoder];
|
||||
id<MTLComputePipelineState> state =
|
||||
[[MetalContext sharedInstance] specializedPipelineState:"hardshrink"
|
||||
Constants:@[
|
||||
@(X.numberOfImages),
|
||||
@(X.featureChannels),
|
||||
@(X.height),
|
||||
@(X.width),
|
||||
@(l)
|
||||
]];
|
||||
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setTexture:[X texture] atIndex:0];
|
||||
[encoder setTexture:[Y texture] atIndex:1];
|
||||
|
||||
const auto& launchParams =
|
||||
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
|
||||
auto output = makeTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, Metal, m) {
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink_"), TORCH_FN(hardshrink_));
|
||||
m.impl(TORCH_SELECTIVE_NAME("aten::hardshrink"), TORCH_FN(hardshrink));
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@ -59,6 +59,7 @@ METAL_SOURCE_LIST = [
|
||||
"aten/src/ATen/native/metal/ops/MetalConvolution.mm",
|
||||
"aten/src/ATen/native/metal/ops/MetalCopy.mm",
|
||||
"aten/src/ATen/native/metal/ops/MetalHardswish.mm",
|
||||
"aten/src/ATen/native/metal/ops/MetalHardshrink.mm",
|
||||
"aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm",
|
||||
"aten/src/ATen/native/metal/ops/MetalNeurons.mm",
|
||||
"aten/src/ATen/native/metal/ops/MetalPadding.mm",
|
||||
|
Reference in New Issue
Block a user