mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Leaky relu in metal shader (#78544)
Summary: Heavily referenced how Hardswish was implemented. This is a great intro task to get a taste of how a torch method is implemented in shader and tested. Test Plan: Compared in metal shader metal version and cpu version result in tests. https://pxl.cl/251kT Reviewed By: SS-JIA Differential Revision: D36732187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/78544 Approved by: https://github.com/SS-JIA
This commit is contained in:
committed by
PyTorch MergeBot
parent
849b08f14b
commit
9c8eb2cf1b
@ -421,6 +421,33 @@ kernel void hardswish(texture2d_array<half, access::read> in_arr[[texture(0), fu
|
||||
}
|
||||
}
|
||||
|
||||
// Function constants, resolved at pipeline-specialization time: a batch of
// more than one image or more than 4 channels requires the texture2d_array
// variant; otherwise the plain texture2d variant is compiled in.
constant bool leaky_relu_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool leaky_relu_is_tex = !leaky_relu_is_arr;

// Elementwise leaky_relu: y = x for x >= 0, y = negative_slope * x otherwise.
// ushort_arg_2 / ushort_arg_3 carry the output height / width and
// float_arg_0 carries the negative slope.
kernel void leaky_relu(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(leaky_relu_is_arr)]],
                       texture2d<half, access::read> in_tex[[texture(0), function_constant(leaky_relu_is_tex)]],
                       texture2d_array<half, access::write> out_arr[[texture(1), function_constant(leaky_relu_is_arr)]],
                       texture2d<half, access::write> out_tex[[texture(1), function_constant(leaky_relu_is_tex)]],
                       ushort3 gid[[thread_position_in_grid]]) {
  const ushort outputHeight = ushort_arg_2;
  const ushort outputWidth = ushort_arg_3;
  const half slope = (half)float_arg_0;
  // Threads dispatched past the image bounds have nothing to write.
  if (gid.x >= outputWidth || gid.y >= outputHeight) {
    return;
  }
  const ushort2 xy = gid.xy;
  if (leaky_relu_is_arr) {
    const half4 v = in_arr.read(xy, gid.z);
    // Per component: keep non-negative lanes, scale negative lanes by slope.
    out_arr.write(select(v, v * slope, v < 0.0h), xy, gid.z);
  } else {
    const half4 v = in_tex.read(xy);
    out_tex.write(select(v, v * slope, v < 0.0h), xy);
  }
}
|
||||
|
||||
constant bool out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
|
||||
constant bool out_is_tex = !out_is_arr;
|
||||
constant bool in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
|
||||
|
@ -42,6 +42,8 @@ bool test_sigmoid();
|
||||
bool test_hardsigmoid();
|
||||
bool test_hardswish_();
|
||||
bool test_hardswish();
|
||||
bool test_leaky_relu_();
|
||||
bool test_leaky_relu();
|
||||
bool test_upsampling_nearest2d_vec();
|
||||
bool test_upsampling_nearest2d_vec2();
|
||||
bool test_adaptive_avg_pool2d();
|
||||
|
@ -274,6 +274,30 @@ bool test_hardswish() {
|
||||
});
|
||||
}
|
||||
|
||||
// In-place leaky_relu: run the op on a CPU tensor and on its Metal copy,
// then verify that both backends produce (almost) the same result.
bool test_leaky_relu_() {
  __block std::vector<int64_t> size{3, 3, 44, 44};
  return TEST(size, __PRETTY_FUNCTION__, ^bool {
    // Random input in [-6, 6) so both signs of the activation are exercised.
    auto cpuInput =
        at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) * 12 - 6;
    // Copy to the GPU first: the CPU tensor is mutated in place below.
    auto metalInput = cpuInput.metal();
    auto expected = at::leaky_relu_(cpuInput, -0.0125);
    auto actual = at::leaky_relu_(metalInput, -0.0125).cpu();
    return almostEqual(expected, actual);
  });
}
|
||||
|
||||
// Out-of-place leaky_relu: compare the Metal kernel's output against the
// CPU reference implementation on identical random input.
bool test_leaky_relu() {
  __block std::vector<int64_t> size{1, 3, 44, 44};
  return TEST(size, __PRETTY_FUNCTION__, ^bool {
    // Random input in [-6, 6) so both signs of the activation are exercised.
    auto cpuInput =
        at::rand(size, at::TensorOptions(at::kCPU).dtype(at::kFloat)) * 12 - 6;
    auto metalInput = cpuInput.metal();
    auto expected = at::leaky_relu(cpuInput, 0.025);
    auto actual = at::leaky_relu(metalInput, 0.025).cpu();
    return almostEqual(expected, actual);
  });
}
|
||||
|
||||
bool test_addmm() {
|
||||
bool result = true;
|
||||
for (int i = 0; i < ITER_COUNT; ++i) {
|
||||
|
@ -70,6 +70,8 @@
|
||||
REG_TEST("test_hardsigmoid", test_hardsigmoid);
|
||||
REG_TEST("test_hardswish_", test_hardswish_);
|
||||
REG_TEST("test_hardswish", test_hardswish);
|
||||
REG_TEST("test_leaky_relu_", test_leaky_relu_);
|
||||
REG_TEST("test_leaky_relu", test_leaky_relu);
|
||||
REG_TEST("test_upsampling_nearest2d_vec", test_upsampling_nearest2d_vec);
|
||||
REG_TEST("test_upsampling_nearest2d_vec2", test_upsampling_nearest2d_vec2);
|
||||
REG_TEST("test_adaptive_avg_pool2d", test_adaptive_avg_pool2d);
|
||||
|
93
aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm
Normal file
93
aten/src/ATen/native/metal/ops/MetalLeakyReLU.mm
Normal file
@ -0,0 +1,93 @@
|
||||
#include <ATen/Tensor.h>
|
||||
#import <ATen/native/metal/MetalCommandBuffer.h>
|
||||
#import <ATen/native/metal/MetalContext.h>
|
||||
#import <ATen/native/metal/MetalTensorImpl.h>
|
||||
#import <ATen/native/metal/MetalTensorImplStorage.h>
|
||||
#import <ATen/native/metal/MetalTensorUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImageUtils.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
using MetalTensorImpl = at::MetalTensorImpl<MetalTensorImplStorage>;
|
||||
|
||||
// In-place Metal leaky_relu (y = x for x >= 0, negative_slope * x otherwise).
// Encodes the "leaky_relu" compute kernel into the tensor's command buffer,
// then swaps the freshly written image into the input tensor's storage so
// the input itself is mutated and returned.
Tensor& leaky_relu_(Tensor& input, const Scalar& negative_slope_val) {
  MPSImage* inputImage = imageFromTensor(input);
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  IntArrayRef outputSize = input.sizes();
  std::vector<int64_t> imageSize = computeImageSize(outputSize);
  float negativeSlope = negative_slope_val.toFloat();
  // Output image's lifetime is tied to the command buffer.
  MPSImage* outputImage = createTemporaryImage(commandBuffer, imageSize);
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  // Specialize the pipeline on the image geometry (the shader's function
  // constants select the texture vs. texture_array variant) and the slope.
  id<MTLComputePipelineState> state =
      [[MetalContext sharedInstance] specializedPipelineState:"leaky_relu"
                                                    Constants:@[
                                                      @(inputImage.numberOfImages),
                                                      @(inputImage.featureChannels),
                                                      @(inputImage.height),
                                                      @(inputImage.width),
                                                      @(negativeSlope)
                                                    ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[inputImage texture] atIndex:0];
  [encoder setTexture:[outputImage texture] atIndex:1];

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, inputImage);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  // Re-point the input tensor's storage at the result image (in-place op).
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensorImplStorage& implStorage = impl->unsafe_opaque_handle();
  implStorage.texture()->setImage(outputImage);
  return input;
}
|
||||
|
||||
// Out-of-place Metal leaky_relu (y = x for x >= 0, negative_slope * x
// otherwise). Allocates fresh temporary storage for the result, encodes the
// "leaky_relu" compute kernel into the input's command buffer, and returns a
// new tensor backed by that storage; the input is left untouched.
Tensor leaky_relu(const at::Tensor& input, const Scalar& negative_slope_val) {
  MPSImage* inputImage = imageFromTensor(input);
  IntArrayRef outputSize = input.sizes();
  MetalTensorImplStorage mt{outputSize.vec()};
  MetalCommandBuffer* commandBuffer = getCommandBuffer(input);
  mt.texture()->allocateTemporaryStorage(outputSize, commandBuffer);
  float negativeSlope = negative_slope_val.toFloat();
  MPSImage* outputImage = mt.texture()->image();
  id<MTLComputeCommandEncoder> encoder =
      [commandBuffer.buffer computeCommandEncoder];
  // Specialize the pipeline on the image geometry (the shader's function
  // constants select the texture vs. texture_array variant) and the slope.
  id<MTLComputePipelineState> state =
      [[MetalContext sharedInstance] specializedPipelineState:"leaky_relu"
                                                    Constants:@[
                                                      @(inputImage.numberOfImages),
                                                      @(inputImage.featureChannels),
                                                      @(inputImage.height),
                                                      @(inputImage.width),
                                                      @(negativeSlope)
                                                    ]];

  [encoder setComputePipelineState:state];
  [encoder setTexture:[inputImage texture] atIndex:0];
  [encoder setTexture:[outputImage texture] atIndex:1];

  const auto& launchParams =
      metal::mpscnn::spatialPointwiseKernelLaunchParams(state, inputImage);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];

  auto output = makeTensor(std::move(mt), input.options());
  return output;
}
|
||||
|
||||
// Register the Metal-backend kernels with the ATen dispatcher.
// Fix: dropped the stray ';' after the macro's closing brace — it was an
// empty declaration (legal but flagged by -Wextra-semi and inconsistent
// with the rest of the codebase).
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu_"), TORCH_FN(leaky_relu_));
  m.impl(TORCH_SELECTIVE_NAME("aten::leaky_relu"), TORCH_FN(leaky_relu));
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user