mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[iOS][GPU] Add Metal/MPSCNN support on iOS (#46112)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46112 ### Summary This PR adds the support of running torchscript models on iOS GPU via Metal (Inference only). The feature is currently in prototype state, API changes are expected. The tutorial and the documents will be added once it goes to beta. allow-large-files - Users API ``` auto module = torch::jit::load(model); module.eval(); at::Tensor input = at::ones({1,3,224,224}, at::ScalarType::Float).metal(); auto output = module.forward({input}).toTensor().cpu(); ``` - Supported Models - Person Segmentation v106 (FB Internal) - Mobilenetv2 - Supported Operators - aten::conv2d - aten::addmm - aten::add.Tensor - aten::sub.Tensor - aten::mul.Tensor - aten::relu - aten::hardtanh - aten::hardtanh_ - aten::sigmoid - aten::max_pool2d - aten::adaptive_avg_pool2d - aten::reshape - aten::t - aten::view - aten::log_softmax.int - aten::upsample_nearest2d.vec - Supported Devices - Apple A9 and above - iOS 10.2 and above - CMake scripts - `IOS_ARCH=arm64 ./scripts/build_ios.sh -DUSE_METAL=ON` ### Test Plan - Circle CI ghstack-source-id: 114155638 Test Plan: 1. Sandcastle CI 2. Circle CI Reviewed By: dreiss Differential Revision: D23236555 fbshipit-source-id: 98ffc48b837e308bc678c37a9a5fd8ae72d11625
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7f6a1b2bd5
commit
a277c097ac
@ -298,6 +298,11 @@ filegroup(
|
||||
srcs = glob(["aten/src/ATen/vulkan/*.cpp"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "aten_base_metal",
|
||||
srcs = glob(["aten/src/ATen/metal/*.cpp"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "ATen_QUANTIZED_SRCS",
|
||||
srcs = glob(
|
||||
@ -650,6 +655,7 @@ cc_library(
|
||||
":ATen_CORE_SRCS",
|
||||
":ATen_QUANTIZED_SRCS",
|
||||
":aten_base_cpp",
|
||||
":aten_base_metal",
|
||||
":aten_base_vulkan",
|
||||
":aten_native_cpp",
|
||||
":aten_native_mkl_cpp",
|
||||
|
@ -165,7 +165,7 @@ option(USE_GLOG "Use GLOG" OFF)
|
||||
option(USE_LEVELDB "Use LEVELDB" OFF)
|
||||
option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF)
|
||||
option(USE_LMDB "Use LMDB" OFF)
|
||||
option(USE_METAL "Use Metal for iOS build" ON)
|
||||
option(USE_METAL "Use Metal for iOS build" OFF)
|
||||
option(USE_NATIVE_ARCH "Use -march=native" OFF)
|
||||
cmake_dependent_option(
|
||||
USE_NCCL "Use NCCL" ON
|
||||
|
@ -66,6 +66,14 @@ file(GLOB native_mkl_cpp "native/mkl/*.cpp")
|
||||
file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
|
||||
file(GLOB vulkan_cpp "vulkan/*.cpp")
|
||||
file(GLOB native_vulkan_cpp "native/vulkan/api/*.cpp" "native/vulkan/*.cpp")
|
||||
|
||||
file(GLOB metal_h "metal/*.h")
|
||||
file(GLOB metal_cpp "metal/*.cpp")
|
||||
file(GLOB_RECURSE native_metal_h "native/metal/*.h")
|
||||
file(GLOB metal_test_srcs "native/metal/mpscnn/tests/*.mm")
|
||||
file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm", "native/metal/*.cpp")
|
||||
EXCLUDE(native_metal_srcs "${native_metal_srcs}" ${metal_test_srcs})
|
||||
|
||||
file(GLOB native_sparse_cpp "native/sparse/*.cpp")
|
||||
file(GLOB native_quantized_cpp
|
||||
"native/quantized/*.cpp"
|
||||
@ -117,6 +125,12 @@ else()
|
||||
set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
|
||||
endif()
|
||||
|
||||
if(USE_METAL)
|
||||
set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${native_metal_srcs})
|
||||
else()
|
||||
set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp})
|
||||
endif()
|
||||
|
||||
if(USE_CUDA AND USE_ROCM)
|
||||
message(FATAL_ERROR "ATen doesn't not currently support simultaneously building with CUDA and ROCM")
|
||||
endif()
|
||||
@ -375,6 +389,10 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
|
||||
set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS})
|
||||
if(NOT INTERN_BUILD_MOBILE)
|
||||
list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${miopen_h})
|
||||
else()
|
||||
if(USE_METAL)
|
||||
list(APPEND INSTALL_HEADERS ${metal_h} ${native_metal_h})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# https://stackoverflow.com/questions/11096471/how-can-i-install-a-hierarchy-of-files-using-cmake
|
||||
|
31
aten/src/ATen/metal/Context.cpp
Normal file
31
aten/src/ATen/metal/Context.cpp
Normal file
@ -0,0 +1,31 @@
|
||||
#include <atomic>
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/metal/Context.h>
|
||||
|
||||
namespace at {
|
||||
namespace metal {
|
||||
|
||||
std::atomic<const MetalInterface*> g_metal_impl_registry;
|
||||
|
||||
MetalImplRegistrar::MetalImplRegistrar(MetalInterface* impl) {
|
||||
g_metal_impl_registry.store(impl);
|
||||
}
|
||||
|
||||
at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src) {
|
||||
auto p = at::metal::g_metal_impl_registry.load();
|
||||
if (p) {
|
||||
return p->metal_copy_(self, src);
|
||||
}
|
||||
AT_ERROR("Metal backend was not linked to the build");
|
||||
}
|
||||
} // namespace metal
|
||||
|
||||
namespace native {
|
||||
bool is_metal_available() {
|
||||
auto p = at::metal::g_metal_impl_registry.load();
|
||||
return p ? p->is_metal_available() : false;
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
30
aten/src/ATen/metal/Context.h
Normal file
30
aten/src/ATen/metal/Context.h
Normal file
@ -0,0 +1,30 @@
|
||||
#ifndef MetalContext_h
|
||||
#define MetalContext_h
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
|
||||
namespace at {
|
||||
namespace metal {
|
||||
|
||||
struct MetalInterface {
|
||||
virtual ~MetalInterface() = default;
|
||||
virtual bool is_metal_available() const = 0;
|
||||
virtual at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src)
|
||||
const = 0;
|
||||
};
|
||||
|
||||
extern std::atomic<const MetalInterface*> g_metal_impl_registry;
|
||||
|
||||
class MetalImplRegistrar {
|
||||
public:
|
||||
explicit MetalImplRegistrar(MetalInterface*);
|
||||
};
|
||||
|
||||
at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace at
|
||||
|
||||
#endif /* MetalContext_h */
|
@ -7,6 +7,7 @@
|
||||
#include <ATen/native/quantized/Copy.h>
|
||||
#include <ATen/quantized/Quantizer.h>
|
||||
#include <ATen/vulkan/Context.h>
|
||||
#include <ATen/metal/Context.h>
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <torch/library.h>
|
||||
@ -79,7 +80,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
|
||||
// (e.g. XLA) may be supported by overriding copy_ and _copy_from.
|
||||
bool is_supported_device(Device device) {
|
||||
DeviceType device_type = device.type();
|
||||
return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan;
|
||||
return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -133,6 +134,10 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
|
||||
return at::vulkan::vulkan_copy_(self, src);
|
||||
}
|
||||
|
||||
if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) {
|
||||
return at::metal::metal_copy_(self, src);
|
||||
}
|
||||
|
||||
auto iter = TensorIteratorConfig()
|
||||
.add_output(self)
|
||||
.add_input(src)
|
||||
|
267
aten/src/ATen/native/metal/MetalAten.mm
Normal file
267
aten/src/ATen/native/metal/MetalAten.mm
Normal file
@ -0,0 +1,267 @@
|
||||
#import <ATen/native/metal/MetalTensor.h>
|
||||
#import <ATen/native/metal/MetalTensorImpl.h>
|
||||
#import <ATen/native/metal/MetalUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNOps.h>
|
||||
|
||||
#include <ATen/metal/Context.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
at::Tensor& copy_from_metal_(at::Tensor& dst, const at::Tensor& src) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
src.device().type() == DeviceType::Metal,
|
||||
"copy_from_metal input tensor's device is not metal");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dst.device().type() == DeviceType::CPU,
|
||||
"copy_from_metal is implemented only for CPU device output");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dst.layout() == Layout::Strided,
|
||||
"copy_from_metal is implemented only for Strided layout output");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dst.scalar_type() == ScalarType::Float,
|
||||
"copy_from_metal is implemented only for float dtype output, got:",
|
||||
dst.scalar_type());
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dst.is_contiguous(),
|
||||
"copy_from_metal is implemented only for contiguous output tensor");
|
||||
|
||||
MetalTensor& mtensor = MetalTensor::fromTensor(src);
|
||||
mtensor.copy_data_to_host(dst.data_ptr<float>());
|
||||
return dst;
|
||||
}
|
||||
|
||||
at::Tensor& copy_to_metal_(at::Tensor& dst, const at::Tensor& src) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
dst.device().type() == DeviceType::Metal,
|
||||
"copy_to_metal_ output tensor's device is not metal");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
src.device().type() == DeviceType::CPU,
|
||||
"copy_to_metal_ is implemented only for CPU device input");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
src.layout() == Layout::Strided,
|
||||
"copy_to_metal_ is implemented only for Strided layout input");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
src.scalar_type() == ScalarType::Float,
|
||||
"copy_to_metal_ is implemented only for float dtype");
|
||||
auto cpu_tensor_contiguous = src.contiguous();
|
||||
MetalTensor& mtensor = MetalTensor::fromTensor(dst);
|
||||
mtensor.set_data_from_host(cpu_tensor_contiguous.data_ptr<float>());
|
||||
return dst;
|
||||
}
|
||||
|
||||
at::Tensor& metal_copy_impl_(at::Tensor& dst, const at::Tensor& src) {
|
||||
if (src.device().type() == at::kMetal && dst.device().type() == at::kCPU) {
|
||||
return copy_from_metal_(dst, src);
|
||||
}
|
||||
if (src.device().type() == at::kCPU && dst.device().type() == at::kMetal) {
|
||||
return copy_to_metal_(dst, src);
|
||||
}
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
src.device().type() == DeviceType::Metal,
|
||||
"metal_copy_ is implemented only for CPU,Strided,float->Metal; Metal->CPU,Strided,float");
|
||||
return dst;
|
||||
}
|
||||
|
||||
#pragma mark - ATen Ops
|
||||
|
||||
Tensor empty(
|
||||
IntArrayRef size,
|
||||
const TensorOptions& options,
|
||||
c10::optional<MemoryFormat> memory_format) {
|
||||
TORCH_CHECK(
|
||||
!options.has_pinned_memory(),
|
||||
"'pin_memory' argument is incompatible with Metal tensor");
|
||||
TORCH_CHECK(
|
||||
!options.has_memory_format() && !memory_format,
|
||||
"'memory_format' argument is incompatible with Metal tensor");
|
||||
MetalTensor mt{size.vec()};
|
||||
return MetalTensor::toTensor(
|
||||
std::move(mt), at::device(at::kMetal).dtype(options.dtype()));
|
||||
};
|
||||
|
||||
at::Tensor empty_strided(
|
||||
IntArrayRef size,
|
||||
IntArrayRef stride,
|
||||
optional<ScalarType> dtype,
|
||||
optional<Layout> layout,
|
||||
optional<Device> device,
|
||||
optional<bool> pin_memory) {
|
||||
TORCH_CHECK(
|
||||
!pin_memory.has_value(),
|
||||
"'pin_memory' argument is incompatible with Metal tensor");
|
||||
MetalTensor mt{size.vec(), stride.vec()};
|
||||
return MetalTensor::toTensor(
|
||||
std::move(mt), at::device(at::kMetal).dtype(dtype));
|
||||
}
|
||||
|
||||
Tensor addmm(
|
||||
const Tensor& bias,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
Scalar beta,
|
||||
Scalar alpha) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
TORCH_CHECK(input.dim() == 2 && weight.dim() == 2);
|
||||
TORCH_CHECK(beta.toFloat() == 1.0f);
|
||||
TORCH_CHECK(alpha.toFloat() == 1.0f);
|
||||
auto&& sizes = weight.sizes();
|
||||
at::Tensor transposedWeight = weight.t().contiguous();
|
||||
at::Tensor mWeight =
|
||||
transposedWeight.view({sizes[1], sizes[0], 1, 1}).contiguous();
|
||||
return mpscnn::addmm(bias, input, mWeight);
|
||||
}
|
||||
|
||||
Tensor conv2d(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
int64_t groups) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
Conv2DParams params{
|
||||
input.sizes(), weight.sizes(), padding, stride, dilation, groups};
|
||||
TORCH_INTERNAL_ASSERT(input.dim() == 4, "Expected 4-dimensional input");
|
||||
TORCH_INTERNAL_ASSERT(weight.dim() == 4, "Expected 4-dimensional weight");
|
||||
TORCH_CHECK(weight.device().type() == kCPU);
|
||||
return mpscnn::conv2d(input, weight, bias, params);
|
||||
}
|
||||
|
||||
Tensor log_softmax_int(
|
||||
const Tensor& input,
|
||||
int64_t dim,
|
||||
c10::optional<ScalarType> dtype) {
|
||||
TORCH_CHECK(dim == 1);
|
||||
return mpscnn::log_softmax_int(input);
|
||||
}
|
||||
|
||||
Tensor max_pool2d(
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
TORCH_CHECK(
|
||||
dilation[0] == dilation[1] == 1, "dilation is not supported on MPSCNN");
|
||||
TORCH_CHECK(ceil_mode == false, "ceil_mode is not supported on MPSCNN");
|
||||
return mpscnn::max_pool2d(
|
||||
input, kernel_size, stride, padding, dilation, ceil_mode);
|
||||
}
|
||||
|
||||
Tensor relu(const Tensor& input) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
return mpscnn::relu(input);
|
||||
}
|
||||
|
||||
Tensor sigmoid(const Tensor& input) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
return mpscnn::sigmoid(input);
|
||||
}
|
||||
|
||||
Tensor t(const Tensor& input) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
TORCH_CHECK(input.dim() == 2);
|
||||
return mpscnn::t(input);
|
||||
}
|
||||
|
||||
Tensor view(const Tensor& input, IntArrayRef size) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
return mpscnn::view(input, size);
|
||||
}
|
||||
|
||||
Tensor upsample_nearest2d_vec(
|
||||
const Tensor& input,
|
||||
c10::optional<IntArrayRef> output_size,
|
||||
c10::optional<ArrayRef<double>> scale_factors) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
return mpscnn::upsample_nearest2d_vec(input, output_size, scale_factors);
|
||||
}
|
||||
|
||||
Tensor add_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
|
||||
TORCH_CHECK(input1.is_metal());
|
||||
TORCH_CHECK(input1.dim() == input2.dim());
|
||||
TORCH_CHECK(input1.sizes()[2] == input2.sizes()[2]);
|
||||
TORCH_CHECK(input1.sizes()[3] == input2.sizes()[3]);
|
||||
return mpscnn::add(input1, input2.is_metal() ? input2 : input2.metal());
|
||||
}
|
||||
|
||||
Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
|
||||
TORCH_CHECK(input1.is_metal());
|
||||
TORCH_CHECK(input1.dim() == input2.dim());
|
||||
TORCH_CHECK(input2.sizes()[2] == input2.sizes()[3] == 1);
|
||||
return mpscnn::sub(input1, input2.is_metal() ? input2 : input2.metal());
|
||||
}
|
||||
|
||||
Tensor mul_Tensor(const Tensor& input1, const Tensor& input2) {
|
||||
TORCH_CHECK(input1.is_metal());
|
||||
TORCH_CHECK(input1.dim() == input2.dim());
|
||||
TORCH_CHECK(input2.sizes()[2] == input2.sizes()[3] == 1);
|
||||
return mpscnn::mul(input1, input2.is_metal() ? input2 : input2.metal());
|
||||
}
|
||||
|
||||
Tensor adaptive_avg_pool2d(const Tensor& input, IntArrayRef output_size) {
|
||||
// averages across the width and height, and outputs a 1x1xC image.
|
||||
TORCH_CHECK(output_size[0] == 1 && output_size[1] == 1);
|
||||
TORCH_CHECK(input.is_metal());
|
||||
return mpscnn::global_avg_pool2d(input, output_size);
|
||||
}
|
||||
|
||||
Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
return mpscnn::hardtanh_(input, min_val, max_val);
|
||||
}
|
||||
|
||||
Tensor reshape(const Tensor& input, IntArrayRef shape) {
|
||||
TORCH_CHECK(input.is_metal());
|
||||
return mpscnn::reshape(input, shape);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, Metal, m) {
|
||||
m.impl("conv2d", TORCH_FN(conv2d));
|
||||
m.impl("add.Tensor", TORCH_FN(add_Tensor));
|
||||
m.impl("addmm", TORCH_FN(addmm));
|
||||
m.impl_UNBOXED("empty.memory_format", empty);
|
||||
m.impl("empty_strided", TORCH_FN(empty_strided));
|
||||
m.impl("log_softmax.int", TORCH_FN(log_softmax_int));
|
||||
m.impl("max_pool2d", TORCH_FN(max_pool2d));
|
||||
m.impl("mul.Tensor", TORCH_FN(mul_Tensor));
|
||||
m.impl("relu", TORCH_FN(relu));
|
||||
m.impl("sigmoid", TORCH_FN(sigmoid));
|
||||
m.impl("sub.Tensor", TORCH_FN(sub_Tensor));
|
||||
m.impl("upsample_nearest2d.vec", TORCH_FN(upsample_nearest2d_vec));
|
||||
m.impl("view", TORCH_FN(view));
|
||||
m.impl("adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d));
|
||||
m.impl("hardtanh_", TORCH_FN(hardtanh_));
|
||||
m.impl("reshape", TORCH_FN(reshape));
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
|
||||
struct MetalImpl : public at::metal::MetalInterface {
|
||||
bool is_metal_available() const override {
|
||||
#if defined(USE_METAL)
|
||||
return [[MPSCNNContext sharedInstance] available];
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
at::Tensor& metal_copy_(at::Tensor& input, const at::Tensor& src)
|
||||
const override {
|
||||
TORCH_CHECK(
|
||||
is_metal_available(), "Metal is not available on the current device");
|
||||
return native::metal::metal_copy_impl_(input, src);
|
||||
}
|
||||
};
|
||||
#if defined(USE_METAL)
|
||||
static at::metal::MetalImplRegistrar g_metal_impl(new MetalImpl());
|
||||
#endif
|
||||
|
||||
} // namespace at
|
16
aten/src/ATen/native/metal/MetalCommandBuffer.h
Normal file
16
aten/src/ATen/native/metal/MetalCommandBuffer.h
Normal file
@ -0,0 +1,16 @@
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
@interface MetalCommandBuffer : NSObject
|
||||
@property(nonatomic, strong, readonly) NSThread* thread;
|
||||
@property(nonatomic, strong, readonly) id<MTLCommandBuffer> buffer;
|
||||
|
||||
+ (MetalCommandBuffer*)newBuffer;
|
||||
+ (MetalCommandBuffer*)currentBuffer;
|
||||
- (void)synchronize;
|
||||
|
||||
- (void)add:(MPSTemporaryImage*)image;
|
||||
- (void)remove:(MPSTemporaryImage*)image;
|
||||
|
||||
@end
|
79
aten/src/ATen/native/metal/MetalCommandBuffer.mm
Normal file
79
aten/src/ATen/native/metal/MetalCommandBuffer.mm
Normal file
@ -0,0 +1,79 @@
|
||||
#import <ATen/native/metal/MetalCommandBuffer.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
|
||||
|
||||
#include <mutex>
|
||||
|
||||
NSString* cb_key = @"PTCommandBuffer";
|
||||
@implementation MetalCommandBuffer {
|
||||
NSMutableArray* _images;
|
||||
std::mutex _mutex;
|
||||
}
|
||||
|
||||
+ (MetalCommandBuffer*)newBuffer {
|
||||
MetalCommandBuffer* cb = [MetalCommandBuffer new];
|
||||
cb->_buffer = [[MPSCNNContext sharedInstance].commandQueue commandBuffer];
|
||||
cb->_thread = [NSThread currentThread];
|
||||
cb->_images = [NSMutableArray new];
|
||||
return cb;
|
||||
}
|
||||
|
||||
+ (MetalCommandBuffer*)currentBuffer {
|
||||
NSThread* thd = [NSThread currentThread];
|
||||
NSMutableDictionary* dict = [thd threadDictionary];
|
||||
MetalCommandBuffer* cb = dict[cb_key];
|
||||
if (!cb) {
|
||||
cb = [MetalCommandBuffer new];
|
||||
cb->_buffer = [[MPSCNNContext sharedInstance].commandQueue commandBuffer];
|
||||
cb->_thread = thd;
|
||||
cb->_images = [NSMutableArray new];
|
||||
dict[cb_key] = cb;
|
||||
}
|
||||
return cb;
|
||||
}
|
||||
|
||||
- (void)flush {
|
||||
[[_thread threadDictionary] removeObjectForKey:cb_key];
|
||||
}
|
||||
|
||||
- (void)add:(MPSTemporaryImage*)image {
|
||||
if (![image isTemporaryImage]) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> g(_mutex);
|
||||
[_images addObject:image];
|
||||
}
|
||||
|
||||
- (void)remove:(MPSTemporaryImage*)image {
|
||||
if (![image isTemporaryImage]) {
|
||||
return;
|
||||
}
|
||||
std::lock_guard<std::mutex> g(_mutex);
|
||||
[_images removeObject:image];
|
||||
}
|
||||
|
||||
- (void)synchronize {
|
||||
if (_buffer.status == 0) {
|
||||
// recycle all temporary images manually before flushing the command buffer
|
||||
[self recycle];
|
||||
[_buffer commit];
|
||||
[_buffer waitUntilCompleted];
|
||||
[[_thread threadDictionary] removeObjectForKey:cb_key];
|
||||
}
|
||||
}
|
||||
|
||||
- (void)recycle {
|
||||
for (MPSTemporaryImage* image in _images) {
|
||||
[image recycle];
|
||||
}
|
||||
}
|
||||
|
||||
- (BOOL)isEqual:(id)object {
|
||||
if (![object isKindOfClass:[MetalCommandBuffer class]]) {
|
||||
return NO;
|
||||
}
|
||||
MetalCommandBuffer* mc = (MetalCommandBuffer*)object;
|
||||
return (_thread == mc.thread && _buffer == mc.buffer);
|
||||
}
|
||||
|
||||
@end
|
56
aten/src/ATen/native/metal/MetalConvolution.h
Normal file
56
aten/src/ATen/native/metal/MetalConvolution.h
Normal file
@ -0,0 +1,56 @@
|
||||
#import <ATen/native/metal/MetalPrepackOpContext.h>
|
||||
#import <ATen/native/metal/MetalUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNOp.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
enum class NeuronType {
|
||||
None,
|
||||
Clamp,
|
||||
Relu,
|
||||
Sigmoid,
|
||||
Tanh,
|
||||
};
|
||||
|
||||
struct Conv2DParams final {
|
||||
Conv2DParams() = delete;
|
||||
Conv2DParams(
|
||||
c10::IntArrayRef inputSizes,
|
||||
c10::IntArrayRef weightSizes,
|
||||
c10::IntArrayRef padding,
|
||||
c10::IntArrayRef stride,
|
||||
c10::IntArrayRef dilation,
|
||||
int64_t groups);
|
||||
|
||||
std::vector<int64_t> output_sizes() const;
|
||||
bool isDepthwise() const;
|
||||
|
||||
int64_t N; // batch size
|
||||
int64_t C; // channels
|
||||
int64_t H; // input height
|
||||
int64_t W; // input width
|
||||
int64_t OC; // output channels
|
||||
int64_t IC; // input channels
|
||||
int64_t KH; // kernel height
|
||||
int64_t KW; // kernel width
|
||||
int64_t SY; // stride y (height)
|
||||
int64_t SX; // stride x (width)
|
||||
int64_t PY; // padding y (height)
|
||||
int64_t PX; // padding x (width)
|
||||
int64_t DY; // dilation y (height)
|
||||
int64_t DX; // dilation x (width)
|
||||
int64_t G; // groups
|
||||
int64_t OW; // output width
|
||||
int64_t OH; // output height
|
||||
};
|
||||
|
||||
NeuronType neuronType(const Conv2dOpContext& context);
|
||||
|
||||
Tensor conv2d_prepack_run_impl(Conv2dOpContext& context, const Tensor& input);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
69
aten/src/ATen/native/metal/MetalConvolution.mm
Normal file
69
aten/src/ATen/native/metal/MetalConvolution.mm
Normal file
@ -0,0 +1,69 @@
|
||||
#import <ATen/native/metal/MetalConvolution.h>
|
||||
#import <ATen/native/metal/MetalUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNOps.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
Conv2DParams::Conv2DParams(
|
||||
c10::IntArrayRef inputSizes,
|
||||
c10::IntArrayRef weightSizes,
|
||||
c10::IntArrayRef padding,
|
||||
c10::IntArrayRef stride,
|
||||
c10::IntArrayRef dilation,
|
||||
int64_t groups)
|
||||
: N(inputSizes[0]),
|
||||
C(inputSizes[1]),
|
||||
H(inputSizes[2]),
|
||||
W(inputSizes[3]),
|
||||
OC(weightSizes[0]),
|
||||
IC(weightSizes[1]),
|
||||
KH(weightSizes[2]),
|
||||
KW(weightSizes[3]),
|
||||
SY(stride[0]),
|
||||
SX(stride[1]),
|
||||
PY(padding[0]),
|
||||
PX(padding[1]),
|
||||
DY(dilation[0]),
|
||||
DX(dilation[1]),
|
||||
G(groups) {
|
||||
OH = std::floor((H + 2 * PY - DY * (KH - 1) - 1) / SY + 1);
|
||||
OW = std::floor((W + 2 * PX - DX * (KW - 1) - 1) / SX + 1);
|
||||
};
|
||||
|
||||
std::vector<int64_t> Conv2DParams::output_sizes() const {
|
||||
return {N, OC, OH, OW};
|
||||
}
|
||||
|
||||
bool Conv2DParams::isDepthwise() const {
|
||||
// Currently, only channel multipler of 1 is supported
|
||||
// i.e. inputFeatureChannels == outputFeatureChannels
|
||||
return G > 1 && IC == 1 && OC == G && OC == C;
|
||||
}
|
||||
|
||||
NeuronType neuronType(const Conv2dOpContext& context) {
|
||||
float inf_max = std::numeric_limits<float>::infinity();
|
||||
float inf_min = -std::numeric_limits<float>::infinity();
|
||||
float output_max = context.output_max.has_value()
|
||||
? context.output_max.value().toFloat()
|
||||
: inf_max;
|
||||
float output_min = context.output_min.has_value()
|
||||
? context.output_min.value().toFloat()
|
||||
: inf_min;
|
||||
if (output_max == inf_max && output_min == 0) {
|
||||
return NeuronType::Relu;
|
||||
} else if (output_max < inf_max && output_min > inf_min) {
|
||||
return NeuronType::Clamp;
|
||||
} else {
|
||||
return NeuronType::None;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor conv2d_prepack_run_impl(Conv2dOpContext& context, const Tensor& input) {
|
||||
return mpscnn::conv2d(input, context);
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
64
aten/src/ATen/native/metal/MetalGuardImpl.cpp
Normal file
64
aten/src/ATen/native/metal/MetalGuardImpl.cpp
Normal file
@ -0,0 +1,64 @@
|
||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace at {
|
||||
namespace detail {
|
||||
|
||||
struct MetalGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
MetalGuardImpl() {}
|
||||
|
||||
explicit MetalGuardImpl(DeviceType t) {
|
||||
TORCH_INTERNAL_ASSERT(t == DeviceType::Metal);
|
||||
}
|
||||
|
||||
DeviceType type() const override {
|
||||
return DeviceType::Metal;
|
||||
}
|
||||
Device exchangeDevice(Device) const override {
|
||||
// no-op
|
||||
return Device(DeviceType::Metal, -1);
|
||||
}
|
||||
Device getDevice() const override {
|
||||
return Device(DeviceType::Metal, -1);
|
||||
}
|
||||
void setDevice(Device) const override {
|
||||
// no-op
|
||||
}
|
||||
void uncheckedSetDevice(Device d) const noexcept override {
|
||||
// no-op
|
||||
}
|
||||
Stream getStream(Device d) const noexcept override {
|
||||
// no-op
|
||||
return Stream(Stream::DEFAULT, Device(DeviceType::Metal, -1));
|
||||
}
|
||||
// NB: These do NOT set the current device
|
||||
Stream exchangeStream(Stream s) const noexcept override {
|
||||
// no-op
|
||||
return Stream(Stream::DEFAULT, Device(DeviceType::Metal, -1));
|
||||
}
|
||||
DeviceIndex deviceCount() const noexcept override {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Event-related functions
|
||||
void record(
|
||||
void** event,
|
||||
const Stream& stream,
|
||||
const DeviceIndex device_index,
|
||||
const EventFlag flag) const override {
|
||||
TORCH_CHECK(false, "Metal backend doesn't support events.");
|
||||
}
|
||||
void block(void* event, const Stream& stream) const override {
|
||||
TORCH_CHECK(false, "Metal backend doesn't support events.")
|
||||
}
|
||||
bool queryEvent(void* event) const override {
|
||||
TORCH_CHECK(false, "Metal backend doesn't support events.")
|
||||
}
|
||||
void destroyEvent(void* event, const DeviceIndex device_index) const
|
||||
noexcept override {}
|
||||
};
|
||||
|
||||
C10_REGISTER_GUARD_IMPL(Metal, MetalGuardImpl);
|
||||
|
||||
} // namespace detail
|
||||
} // namespace at
|
91
aten/src/ATen/native/metal/MetalPrepackOpContext.h
Normal file
91
aten/src/ATen/native/metal/MetalPrepackOpContext.h
Normal file
@ -0,0 +1,91 @@
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <torch/custom_class.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
using SerializationTypeConv2dPrePack = std::tuple<
|
||||
Tensor,
|
||||
c10::optional<Tensor>,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
std::vector<int64_t>,
|
||||
int64_t,
|
||||
c10::optional<Scalar>,
|
||||
c10::optional<Scalar>>;
|
||||
|
||||
class Conv2dOpContext : public torch::jit::CustomClassHolder {
|
||||
public:
|
||||
SerializationTypeConv2dPrePack pack() {
|
||||
return std::make_tuple(
|
||||
weight,
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
output_min,
|
||||
output_max);
|
||||
}
|
||||
Conv2dOpContext() = delete;
|
||||
Conv2dOpContext(
|
||||
at::Tensor&& weight,
|
||||
c10::optional<at::Tensor>&& bias,
|
||||
const std::vector<int64_t>& stride,
|
||||
const std::vector<int64_t>& padding,
|
||||
const std::vector<int64_t>& dilation,
|
||||
int64_t groups,
|
||||
c10::optional<Scalar> output_min,
|
||||
c10::optional<Scalar> output_max)
|
||||
: weight(std::move(weight)),
|
||||
bias(std::move(bias)),
|
||||
stride(stride),
|
||||
padding(padding),
|
||||
dilation(dilation),
|
||||
groups(groups),
|
||||
output_min(output_min),
|
||||
output_max(output_max) {}
|
||||
|
||||
Tensor weight;
|
||||
c10::optional<Tensor> bias;
|
||||
std::vector<int64_t> stride;
|
||||
std::vector<int64_t> padding;
|
||||
std::vector<int64_t> dilation;
|
||||
int64_t groups;
|
||||
c10::optional<Scalar> output_min;
|
||||
c10::optional<Scalar> output_max;
|
||||
id extra = nil;
|
||||
};
|
||||
|
||||
c10::intrusive_ptr<Conv2dOpContext> unpack(
|
||||
Tensor&& weight,
|
||||
c10::optional<Tensor>&& bias,
|
||||
std::vector<int64_t>&& stride,
|
||||
std::vector<int64_t>&& padding,
|
||||
std::vector<int64_t>&& dilation,
|
||||
int64_t groups,
|
||||
c10::optional<Scalar> output_min,
|
||||
c10::optional<Scalar> output_max);
|
||||
|
||||
c10::intrusive_ptr<Conv2dOpContext> conv2d_prepack(
|
||||
Tensor&& weight,
|
||||
c10::optional<Tensor>&& bias,
|
||||
std::vector<int64_t>&& stride,
|
||||
std::vector<int64_t>&& padding,
|
||||
std::vector<int64_t>&& dilation,
|
||||
int64_t groups,
|
||||
c10::optional<Scalar> output_min,
|
||||
c10::optional<Scalar> output_max);
|
||||
|
||||
Tensor conv2d_prepack_run(
|
||||
const Tensor& input,
|
||||
const c10::intrusive_ptr<Conv2dOpContext>& op_context);
|
||||
|
||||
Tensor copy_to_host(const Tensor& input);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
71
aten/src/ATen/native/metal/MetalPrepackOpContext.mm
Normal file
71
aten/src/ATen/native/metal/MetalPrepackOpContext.mm
Normal file
@ -0,0 +1,71 @@
|
||||
#import <ATen/native/metal/MetalConvolution.h>
|
||||
#import <ATen/native/metal/MetalPrepackOpContext.h>
|
||||
#import <ATen/native/metal/MetalUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNOps.h>
|
||||
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
c10::intrusive_ptr<Conv2dOpContext> conv2d_prepack(
|
||||
at::Tensor&& weight,
|
||||
c10::optional<at::Tensor>&& bias,
|
||||
std::vector<int64_t>&& stride,
|
||||
std::vector<int64_t>&& padding,
|
||||
std::vector<int64_t>&& dilation,
|
||||
const int64_t groups,
|
||||
c10::optional<Scalar> output_min,
|
||||
c10::optional<Scalar> output_max) {
|
||||
TORCH_CHECK(weight.dim() == 4);
|
||||
return c10::make_intrusive<Conv2dOpContext>(
|
||||
std::move(weight),
|
||||
std::move(bias),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
output_min,
|
||||
output_max);
|
||||
}
|
||||
|
||||
c10::intrusive_ptr<Conv2dOpContext> unpack(
|
||||
Tensor&& weight,
|
||||
c10::optional<Tensor>&& bias,
|
||||
std::vector<int64_t>&& stride,
|
||||
std::vector<int64_t>&& padding,
|
||||
std::vector<int64_t>&& dilation,
|
||||
int64_t groups,
|
||||
c10::optional<Scalar> output_min,
|
||||
c10::optional<Scalar> output_max) {
|
||||
const Tensor weightContig = weight.contiguous();
|
||||
const auto ws = weightContig.sizes();
|
||||
auto packed_buffer = permuteWeights(weightContig.data_ptr<float>(), ws.vec());
|
||||
auto packedWeight = at::empty(ws);
|
||||
int64_t size_bytes = at::prod_intlist(ws) * sizeof(float);
|
||||
memcpy(packedWeight.data_ptr(), packed_buffer.data(), size_bytes);
|
||||
return c10::make_intrusive<Conv2dOpContext>(
|
||||
std::move(packedWeight),
|
||||
std::move(bias),
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
output_min,
|
||||
output_max);
|
||||
}
|
||||
|
||||
Tensor conv2d_prepack_run(
|
||||
const Tensor& input,
|
||||
const c10::intrusive_ptr<Conv2dOpContext>& op_context) {
|
||||
return conv2d_prepack_run_impl(*op_context, input);
|
||||
}
|
||||
|
||||
Tensor copy_to_host(const Tensor& input) {
|
||||
return mpscnn::copy_to_host(input);
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
55
aten/src/ATen/native/metal/MetalPrepackOpRegister.mm
Normal file
55
aten/src/ATen/native/metal/MetalPrepackOpRegister.mm
Normal file
@ -0,0 +1,55 @@
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#import <ATen/native/metal/MetalPrepackOpContext.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
TORCH_LIBRARY(metal, m) {
|
||||
m.class_<Conv2dOpContext>("Conv2dOpContext")
|
||||
.def_pickle(
|
||||
[](const c10::intrusive_ptr<Conv2dOpContext>& op_context)
|
||||
-> SerializationTypeConv2dPrePack { // __getstate__
|
||||
return op_context->pack();
|
||||
},
|
||||
[](SerializationTypeConv2dPrePack state)
|
||||
-> c10::intrusive_ptr<Conv2dOpContext> { // __setstate__
|
||||
return unpack(
|
||||
std::move(std::get<0>(state)),
|
||||
std::move(std::get<1>(state)),
|
||||
std::move(std::get<2>(state)),
|
||||
std::move(std::get<3>(state)),
|
||||
std::move(std::get<4>(state)),
|
||||
std::move(std::get<5>(state)),
|
||||
std::move(std::get<6>(state)),
|
||||
std::move(std::get<7>(state)));
|
||||
});
|
||||
m.def("copy_to_host(Tensor X) -> Tensor Y");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(metal_prepack, m) {
|
||||
m.def(
|
||||
"conv2d_prepack(Tensor W, Tensor? B, int[2] stride, "
|
||||
"int[2] padding, int[2] dilation, int groups, "
|
||||
"Scalar? output_min=None, Scalar? output_max=None) "
|
||||
"-> __torch__.torch.classes.metal.Conv2dOpContext");
|
||||
m.def(
|
||||
"conv2d_run(Tensor X, "
|
||||
"__torch__.torch.classes.metal.Conv2dOpContext W_prepack) -> Tensor Y");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) {
|
||||
m.impl("conv2d_prepack", TORCH_FN(conv2d_prepack));
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) {
|
||||
m.impl("conv2d_run", conv2d_prepack_run);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(metal, Metal, m) {
|
||||
m.impl("copy_to_host", copy_to_host);
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
270
aten/src/ATen/native/metal/MetalShaders.h
Normal file
270
aten/src/ATen/native/metal/MetalShaders.h
Normal file
@ -0,0 +1,270 @@
|
||||
#ifndef MPSCNNShaders_h
|
||||
#define MPSCNNShaders_h
|
||||
|
||||
static const char* METAL_SHADERS = R"METAL_SHADERS(
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
|
||||
constant ushort ushort_arg_0[[function_constant(0)]];
|
||||
constant ushort ushort_arg_1[[function_constant(1)]];
|
||||
constant ushort ushort_arg_2[[function_constant(2)]];
|
||||
constant ushort ushort_arg_3[[function_constant(3)]];
|
||||
constant ushort ushort_arg_4[[function_constant(4)]];
|
||||
constant ushort ushort_arg_5[[function_constant(5)]];
|
||||
constant ushort ushort_arg_6[[function_constant(6)]];
|
||||
constant ushort ushort_arg_7[[function_constant(7)]];
|
||||
constant ushort ushort_arg_8[[function_constant(8)]];
|
||||
constant ushort ushort_arg_9[[function_constant(9)]];
|
||||
constant float float_arg_0 [[function_constant(10)]];
|
||||
constant float float_arg_1 [[function_constant(11)]];
|
||||
|
||||
|
||||
inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }
|
||||
|
||||
kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
|
||||
texture2d<half, access::read> in1[[texture(1)]],
|
||||
texture2d<half, access::write> out[[texture(2)]],
|
||||
ushort2 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
out.write(in0.read(gid) + in1.read(gid), gid);
|
||||
}
|
||||
|
||||
kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
|
||||
texture2d_array<half, access::read> in1[[texture(1)]],
|
||||
texture2d_array<half, access::write> out[[texture(2)]],
|
||||
ushort3 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
ushort2 gid_ = gid.xy;
|
||||
out.write(in0.read(gid_, gid.z) + in1.read(gid_, gid.z), gid_, gid.z);
|
||||
}
|
||||
|
||||
kernel void elementwise_sub_nonarray(texture2d<half, access::read> in0[[texture(0)]],
|
||||
texture2d<half, access::read> in1[[texture(1)]],
|
||||
texture2d<half, access::write> out[[texture(2)]],
|
||||
ushort2 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
ushort2 gid2{0,0};
|
||||
out.write(in0.read(gid) - in1.read(gid2), gid);
|
||||
}
|
||||
|
||||
kernel void elementwise_sub(texture2d_array<half, access::read> in0[[texture(0)]],
|
||||
texture2d_array<half, access::read> in1[[texture(1)]],
|
||||
texture2d_array<half, access::write> out[[texture(2)]],
|
||||
ushort3 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
ushort2 gid1 = gid.xy;
|
||||
ushort2 gid2{0,0};
|
||||
out.write(in0.read(gid1, gid.z) - in1.read(gid2, gid.z), gid1, gid.z);
|
||||
}
|
||||
kernel void elementwise_mul_nonarray(texture2d<half, access::read> in0[[texture(0)]],
|
||||
texture2d<half, access::read> in1[[texture(1)]],
|
||||
texture2d<half, access::write> out[[texture(2)]],
|
||||
ushort2 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
ushort2 gid2{0,0};
|
||||
out.write(in0.read(gid) * in1.read(gid2), gid);
|
||||
}
|
||||
|
||||
kernel void elementwise_mul(texture2d_array<half, access::read> in0[[texture(0)]],
|
||||
texture2d_array<half, access::read> in1[[texture(1)]],
|
||||
texture2d_array<half, access::write> out[[texture(2)]],
|
||||
ushort3 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
ushort2 gid1 = gid.xy;
|
||||
ushort2 gid2{0,0};
|
||||
out.write(in0.read(gid1, gid.z) * in1.read(gid2, gid.z), gid1, gid.z);
|
||||
}
|
||||
|
||||
kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
|
||||
texture2d_array<half, access::write> out[[texture(0)]],
|
||||
ushort3 gid[[thread_position_in_grid]]) {
|
||||
const ushort C = ushort_arg_0;
|
||||
const ushort H = ushort_arg_1;
|
||||
const ushort W = ushort_arg_2;
|
||||
if (gid.x >= W || gid.y >= H) {
|
||||
return;
|
||||
}
|
||||
const ushort n = gid.z / divRoundUp(C, 4);
|
||||
const ushort c = gid.z - n * divRoundUp(C, 4);
|
||||
// TODO: are the `else` branches needed?
|
||||
// TODO: trick the optimizer for case where C == 4?
|
||||
#define CHW_TO_CHWP4(idx, n, c_, h, w) \
|
||||
if ((c_) < C) { \
|
||||
trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)]; \
|
||||
} else { \
|
||||
trns[idx] = 0.0h; \
|
||||
}
|
||||
half4 trns;
|
||||
CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x);
|
||||
CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x);
|
||||
CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x);
|
||||
CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x);
|
||||
#undef CHW_TO_CHWP4
|
||||
out.write(trns, gid.xy, gid.z);
|
||||
}
|
||||
|
||||
kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]],
|
||||
texture2d<half, access::write> out[[texture(0)]],
|
||||
ushort2 gid[[thread_position_in_grid]]) {
|
||||
const ushort C = ushort_arg_0;
|
||||
const ushort H = ushort_arg_1;
|
||||
const ushort W = ushort_arg_2;
|
||||
if (gid.x >= W || gid.y >= H) {
|
||||
return;
|
||||
}
|
||||
half4 trns;
|
||||
// TODO: are the `else` branches needed?
|
||||
// TODO: trick the optimizer for case where C % 4 == 0?
|
||||
#define CHW_TO_CHWP4(idx, c, h, w) \
|
||||
if ((c) < C) { \
|
||||
trns[idx] = in[int(c) * H * W + int(h) * W + int(w)]; \
|
||||
} else { \
|
||||
trns[idx] = 0.0h; \
|
||||
}
|
||||
CHW_TO_CHWP4(0, 0, gid.y, gid.x);
|
||||
CHW_TO_CHWP4(1, 1, gid.y, gid.x);
|
||||
CHW_TO_CHWP4(2, 2, gid.y, gid.x);
|
||||
CHW_TO_CHWP4(3, 3, gid.y, gid.x);
|
||||
#undef CHW_TO_CHWP4
|
||||
out.write(trns, gid.xy);
|
||||
}
|
||||
|
||||
kernel void copy_metal_to_nchw(texture2d_array<half, access::read> in[[texture(0)]],
|
||||
device float* out[[buffer(0)]],
|
||||
ushort3 gid[[thread_position_in_grid]]) {
|
||||
const ushort C = ushort_arg_0;
|
||||
const ushort H = ushort_arg_1;
|
||||
const ushort W = ushort_arg_2;
|
||||
if (gid.x >= W || gid.y >= H) {
|
||||
return;
|
||||
}
|
||||
const ushort n = gid.z / divRoundUp(C, 4);
|
||||
const ushort c = gid.z - n * divRoundUp(C, 4);
|
||||
half4 cs = in.read(gid.xy, gid.z);
|
||||
#define CHWP4_TO_CHW(idx, n, c_, h, w) \
|
||||
if ((c_) < C) { \
|
||||
out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx]; \
|
||||
}
|
||||
CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x);
|
||||
CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x);
|
||||
CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x);
|
||||
CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x);
|
||||
#undef CHWP4_TO_CHW
|
||||
}
|
||||
|
||||
kernel void copy_metal_to_nchw_nonarray(texture2d<half, access::read> in[[texture(0)]],
|
||||
device float* out[[buffer(0)]],
|
||||
ushort2 gid[[thread_position_in_grid]]) {
|
||||
const ushort C = ushort_arg_0;
|
||||
const ushort H = ushort_arg_1;
|
||||
const ushort W = ushort_arg_2;
|
||||
if (gid.x >= W || gid.y >= H) {
|
||||
return;
|
||||
}
|
||||
half4 cs = in.read(gid.xy);
|
||||
#define CHWP4_TO_CHW(idx, c, h, w) \
|
||||
if ((c) < C) { \
|
||||
out[int(c) * H * W + int(h) * W + int(w)] = cs[idx]; \
|
||||
}
|
||||
CHWP4_TO_CHW(0, 0, gid.y, gid.x);
|
||||
CHWP4_TO_CHW(1, 1, gid.y, gid.x);
|
||||
CHWP4_TO_CHW(2, 2, gid.y, gid.x);
|
||||
CHWP4_TO_CHW(3, 3, gid.y, gid.x);
|
||||
#undef CHWP4_TO_CHW
|
||||
}
|
||||
|
||||
kernel void copy(texture2d_array<half, access::read> in[[texture(0)]],
|
||||
texture2d_array<half, access::write> out[[texture(1)]],
|
||||
ushort3 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
ushort2 gid_ = gid.xy;
|
||||
out.write(in.read(gid_, gid.z), gid_, gid.z);
|
||||
}
|
||||
|
||||
kernel void copy_nonarray(texture2d<half, access::read> in[[texture(0)]],
|
||||
texture2d<half, access::write> out[[texture(1)]],
|
||||
ushort2 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
out.write(in.read(gid), gid);
|
||||
}
|
||||
|
||||
kernel void clamp_half4(texture2d_array<half, access::read> in[[texture(0)]],
|
||||
texture2d_array<half, access::write> out[[texture(1)]],
|
||||
constant half* clamp_buf[[buffer(0)]],
|
||||
ushort3 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
const half4 min_(clamp_buf[0], clamp_buf[0], clamp_buf[0], clamp_buf[0]);
|
||||
const half4 max_(clamp_buf[1], clamp_buf[1], clamp_buf[1], clamp_buf[1]);
|
||||
ushort2 gid_ = gid.xy;
|
||||
half4 value = in.read(gid_, gid.z);
|
||||
half4 clamped = clamp(value, min_, max_);
|
||||
out.write(clamped, gid_, gid.z);
|
||||
}
|
||||
|
||||
kernel void clamp_half4_nonarray(texture2d<half, access::read> in[[texture(0)]],
|
||||
texture2d<half, access::write> out[[texture(1)]],
|
||||
constant half* clamp_buf[[buffer(0)]],
|
||||
ushort2 gid[[thread_position_in_grid]]) {
|
||||
if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
|
||||
return;
|
||||
}
|
||||
const half4 min_(clamp_buf[0], clamp_buf[0], clamp_buf[0], clamp_buf[0]);
|
||||
const half4 max_(clamp_buf[1], clamp_buf[1], clamp_buf[1], clamp_buf[1]);
|
||||
half4 value = in.read(gid);
|
||||
half4 clamped = clamp(value, min_, max_);
|
||||
out.write(clamped, gid);
|
||||
}
|
||||
|
||||
kernel void resize_nearest(texture2d_array<half, access::sample> in[[texture(0)]],
|
||||
texture2d_array<half, access::write> out[[texture(1)]],
|
||||
ushort3 gid[[thread_position_in_grid]]) {
|
||||
const ushort oH = ushort_arg_0;
|
||||
const ushort oW = ushort_arg_1;
|
||||
if (gid.x >= oW || gid.y >= oH) {
|
||||
return;
|
||||
}
|
||||
const float height_scale = float(ushort_arg_2) / 10000;
|
||||
const float width_scale = float(ushort_arg_3) / 10000;
|
||||
constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
|
||||
const int in_y = (int)(gid.y / height_scale);
|
||||
const int in_x = (int)(gid.x / width_scale);
|
||||
out.write(in.sample(s, float2(in_x, in_y), gid.z), gid.xy, gid.z);
|
||||
}
|
||||
|
||||
kernel void resize_nearest_nonarray(texture2d<half, access::sample> in[[texture(0)]],
|
||||
texture2d<half, access::write> out[[texture(1)]],
|
||||
ushort2 gid[[thread_position_in_grid]]) {
|
||||
const ushort oH = ushort_arg_0;
|
||||
const ushort oW = ushort_arg_1;
|
||||
if (gid.x >= oW || gid.y >= oH) {
|
||||
return;
|
||||
}
|
||||
const float height_scale = float(ushort_arg_2) / 10000;
|
||||
const float width_scale = float(ushort_arg_3) / 10000;
|
||||
constexpr sampler s(coord::pixel, address::clamp_to_edge, filter::nearest);
|
||||
const int in_y = (int)(gid.y / height_scale);
|
||||
const int in_x = (int)(gid.x / width_scale);
|
||||
out.write(in.sample(s, float2(in_x, in_y)), gid.xy);
|
||||
}
|
||||
|
||||
)METAL_SHADERS";
|
||||
|
||||
#endif /* MPSCNNShaders_h */
|
47
aten/src/ATen/native/metal/MetalTensor.h
Normal file
47
aten/src/ATen/native/metal/MetalTensor.h
Normal file
@ -0,0 +1,47 @@
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
class MPSImageWrapper;
|
||||
class MetalTensor final {
|
||||
class Impl;
|
||||
|
||||
public:
|
||||
MetalTensor(){};
|
||||
explicit MetalTensor(const std::vector<int64_t>& sizes);
|
||||
explicit MetalTensor(
|
||||
const std::vector<int64_t>& sizes,
|
||||
const std::vector<int64_t>& strides);
|
||||
~MetalTensor() = default;
|
||||
|
||||
MetalTensor(MetalTensor&&) = default;
|
||||
MetalTensor& operator=(MetalTensor&&) = default;
|
||||
|
||||
MetalTensor(const MetalTensor&) = default;
|
||||
MetalTensor& operator=(const MetalTensor&) = default;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& output, const MetalTensor& mt);
|
||||
|
||||
static at::Tensor toTensor(MetalTensor&& mt, const TensorOptions& options);
|
||||
static MetalTensor& fromTensor(const at::Tensor& tensor);
|
||||
|
||||
bool defined() const;
|
||||
IntArrayRef sizes() const;
|
||||
IntArrayRef strides() const;
|
||||
int64_t dim() const;
|
||||
int64_t numel() const;
|
||||
void set_data_from_host(const float* inputData);
|
||||
void copy_data_to_host(float* host);
|
||||
MPSImageWrapper* texture() const;
|
||||
|
||||
private:
|
||||
std::shared_ptr<Impl> impl();
|
||||
std::shared_ptr<const Impl> impl() const;
|
||||
std::shared_ptr<Impl> _impl;
|
||||
};
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
147
aten/src/ATen/native/metal/MetalTensor.mm
Normal file
147
aten/src/ATen/native/metal/MetalTensor.mm
Normal file
@ -0,0 +1,147 @@
|
||||
#import <ATen/native/metal/MetalTensor.h>
|
||||
#import <ATen/native/metal/MetalTensorImpl.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImageWrapper.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
class API_AVAILABLE(ios(10.0), macos(10.13)) MetalTensor::Impl {
|
||||
public:
|
||||
Impl(const std::vector<int64_t>& sizes)
|
||||
: Impl(sizes, std::vector<int64_t>(sizes.size())) {}
|
||||
|
||||
Impl(const std::vector<int64_t>& sizes, const std::vector<int64_t>& strides)
|
||||
: _sizes(sizes),
|
||||
_strides(strides),
|
||||
_numel(std::accumulate(
|
||||
std::begin(_sizes),
|
||||
std::end(_sizes),
|
||||
1,
|
||||
std::multiplies<int64_t>())),
|
||||
_textureImpl(std::make_unique<MPSImageWrapper>(sizes)) {}
|
||||
|
||||
IntArrayRef sizes() const {
|
||||
return _sizes;
|
||||
}
|
||||
IntArrayRef strides() const {
|
||||
return _strides;
|
||||
}
|
||||
int64_t dim() const {
|
||||
return _sizes.size();
|
||||
}
|
||||
int64_t numel() const {
|
||||
return _numel;
|
||||
}
|
||||
void set_data_from_host(const float* inputData) {
|
||||
_textureImpl->copyDataFromHost(inputData);
|
||||
}
|
||||
void copy_data_to_host(float* host) {
|
||||
_textureImpl->copyDataToHost(host);
|
||||
}
|
||||
MPSImageWrapper* texture() const {
|
||||
return _textureImpl.get();
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int64_t> _sizes;
|
||||
std::vector<int64_t> _strides;
|
||||
int64_t _numel;
|
||||
std::unique_ptr<MPSImageWrapper> _textureImpl;
|
||||
};
|
||||
|
||||
MetalTensor::MetalTensor(const std::vector<int64_t>& sizes)
|
||||
: MetalTensor(sizes, std::vector<int64_t>(sizes.size())) {} // fake strides
|
||||
|
||||
MetalTensor::MetalTensor(
|
||||
const std::vector<int64_t>& sizes,
|
||||
const std::vector<int64_t>& strides)
|
||||
: _impl(std::make_shared<Impl>(std::move(sizes), std::move(strides))) {}
|
||||
|
||||
bool MetalTensor::defined() const {
|
||||
return static_cast<bool>(_impl);
|
||||
}
|
||||
|
||||
at::Tensor MetalTensor::toTensor(
|
||||
MetalTensor&& mt,
|
||||
const TensorOptions& options) {
|
||||
using MetalTensorImpl = at::MetalTensorImpl<MetalTensor>;
|
||||
auto sizes = mt.sizes(); // sizes is stored in TensorImpl
|
||||
auto strides = mt.strides(); // strides is stored in MetalTensorImpl
|
||||
return detail::make_tensor<MetalTensorImpl>(
|
||||
DispatchKeySet(DispatchKey::Metal),
|
||||
options.dtype(),
|
||||
at::Device(at::kMetal),
|
||||
std::move(mt),
|
||||
std::vector<int64_t>(sizes.begin(), sizes.end()),
|
||||
std::vector<int64_t>(strides.begin(), strides.end()));
|
||||
}
|
||||
|
||||
MetalTensor& MetalTensor::fromTensor(const at::Tensor& tensor) {
|
||||
using MetalTensorImpl = at::MetalTensorImpl<MetalTensor>;
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
tensor.is_metal(), "unbox expects Metal tensor as inputs");
|
||||
MetalTensorImpl* impl =
|
||||
static_cast<MetalTensorImpl*>(tensor.unsafeGetTensorImpl());
|
||||
return impl->unsafe_opaque_handle();
|
||||
}
|
||||
|
||||
std::shared_ptr<MetalTensor::Impl> MetalTensor::impl() {
|
||||
return _impl;
|
||||
}
|
||||
|
||||
std::shared_ptr<const MetalTensor::Impl> MetalTensor::impl() const {
|
||||
return _impl;
|
||||
}
|
||||
|
||||
IntArrayRef MetalTensor::sizes() const {
|
||||
return impl()->sizes();
|
||||
}
|
||||
|
||||
IntArrayRef MetalTensor::strides() const {
|
||||
return impl()->strides();
|
||||
}
|
||||
|
||||
int64_t MetalTensor::dim() const {
|
||||
return impl()->dim();
|
||||
}
|
||||
|
||||
int64_t MetalTensor::numel() const {
|
||||
return impl()->numel();
|
||||
}
|
||||
|
||||
void MetalTensor::set_data_from_host(const float* inputData) {
|
||||
impl()->set_data_from_host(inputData);
|
||||
}
|
||||
|
||||
void MetalTensor::copy_data_to_host(float* hostData) {
|
||||
impl()->copy_data_to_host(hostData);
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0))
|
||||
MPSImageWrapper* MetalTensor::texture() const {
|
||||
return impl()->texture();
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& output, const MetalTensor& mt) {
|
||||
auto&& sizes = mt.sizes();
|
||||
auto&& strides = mt.strides();
|
||||
output << "[MetalTensor] | Size:{";
|
||||
std::ostringstream oss;
|
||||
std::copy(
|
||||
sizes.begin(), sizes.end() - 1, std::ostream_iterator<int>(oss, ","));
|
||||
oss << sizes.back();
|
||||
output << oss.str() << "}, Stride:{";
|
||||
std::string sizesStr = oss.str();
|
||||
oss.str("");
|
||||
oss.clear();
|
||||
std::copy(
|
||||
strides.begin(), strides.end() - 1, std::ostream_iterator<int>(oss, ","));
|
||||
oss << sizes.back();
|
||||
output << oss.str() << "}";
|
||||
return output;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
55
aten/src/ATen/native/metal/MetalTensorImpl.h
Normal file
55
aten/src/ATen/native/metal/MetalTensorImpl.h
Normal file
@ -0,0 +1,55 @@
|
||||
#ifndef MetalTensorImpl_h
|
||||
#define MetalTensorImpl_h
|
||||
|
||||
#include <ATen/OpaqueTensorImpl.h>
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#import <ATen/native/metal/MetalTensor.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImageWrapper.h>
|
||||
|
||||
namespace at {
|
||||
template <typename OpaqueHandle>
|
||||
struct TORCH_API MetalTensorImpl : public OpaqueTensorImpl<OpaqueHandle> {
|
||||
MetalTensorImpl(
|
||||
at::DispatchKeySet key_set,
|
||||
const caffe2::TypeMeta& data_type,
|
||||
c10::Device device,
|
||||
OpaqueHandle opaque_handle,
|
||||
c10::IntArrayRef sizes,
|
||||
c10::IntArrayRef strides)
|
||||
: OpaqueTensorImpl<OpaqueHandle>(
|
||||
key_set,
|
||||
data_type,
|
||||
device,
|
||||
opaque_handle,
|
||||
sizes),
|
||||
strides_(strides.vec()) {}
|
||||
|
||||
IntArrayRef strides() const override {
|
||||
return strides_;
|
||||
}
|
||||
|
||||
bool is_contiguous(
|
||||
c10::MemoryFormat memory_format =
|
||||
c10::MemoryFormat::Contiguous) const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
int64_t stride(int64_t d) const override {
|
||||
d = at::maybe_wrap_dim(d, this->dim(), false);
|
||||
return strides_[d];
|
||||
}
|
||||
|
||||
void release_resources() override {
|
||||
using MetalTensor = at::native::metal::MetalTensor;
|
||||
auto&& handle = (MetalTensor)this->opaque_handle();
|
||||
handle.texture()->recycleImage();
|
||||
OpaqueTensorImpl<OpaqueHandle>::release_resources();
|
||||
}
|
||||
|
||||
private:
|
||||
SmallVector<int64_t, 5> strides_;
|
||||
};
|
||||
} // namespace at
|
||||
|
||||
#endif /* MetalTensorImpl_h*/
|
23
aten/src/ATen/native/metal/MetalUtils.h
Normal file
23
aten/src/ATen/native/metal/MetalUtils.h
Normal file
@ -0,0 +1,23 @@
|
||||
#include <vector>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
std::vector<uint16_t> fp32_to_fp16(const std::vector<float>& src);
|
||||
std::vector<float> fp16_to_fp32(const std::vector<uint16_t>& src);
|
||||
std::vector<float> NCHW_to_NC4(
|
||||
const float* src,
|
||||
const std::vector<int64_t>& sizes);
|
||||
std::vector<float> NC4_to_NCHW(
|
||||
const float* src,
|
||||
const std::vector<int64_t>& sizes);
|
||||
// The MPSCNNConvolution class takes weights in the order
|
||||
// [outputChannels][kernelHeight][kernelWidth][inputChannels/groups].
|
||||
std::vector<float> permuteWeights(
|
||||
const float* src,
|
||||
const std::vector<int64_t>& sizes);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
132
aten/src/ATen/native/metal/MetalUtils.mm
Normal file
132
aten/src/ATen/native/metal/MetalUtils.mm
Normal file
@ -0,0 +1,132 @@
|
||||
#import <Accelerate/Accelerate.h>
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
#import <ATen/native/metal/MetalUtils.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Logging.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
std::vector<uint16_t> fp32_to_fp16(const std::vector<float>& src) {
|
||||
unsigned long count = src.size();
|
||||
std::vector<uint16_t> output(count, 0);
|
||||
vImage_Buffer float32{(void*)src.data(), 1, count, count * sizeof(float)};
|
||||
vImage_Buffer float16{
|
||||
(void*)output.data(), 1, count, count * sizeof(uint16_t)};
|
||||
if (vImageConvert_PlanarFtoPlanar16F(&float32, &float16, 0) !=
|
||||
kvImageNoError) {
|
||||
TORCH_CHECK(false, "fp32_to_fp16 failed");
|
||||
return {};
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<float> fp16_to_fp32(const std::vector<uint16_t>& src) {
|
||||
unsigned long count = src.size();
|
||||
std::vector<float> output(count, 0);
|
||||
vImage_Buffer float16{(void*)src.data(), 1, count, count * sizeof(uint16_t)};
|
||||
vImage_Buffer float32{(void*)output.data(), 1, count, count * sizeof(float)};
|
||||
if (vImageConvert_Planar16FtoPlanarF(&float16, &float32, 0) !=
|
||||
kvImageNoError) {
|
||||
TORCH_CHECK(false, "fp16_to_fp32 failed");
|
||||
return {};
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<float> NCHW_to_NC4(
|
||||
const float* src,
|
||||
const std::vector<int64_t>& sizes) {
|
||||
int64_t N = sizes[0];
|
||||
int64_t C = sizes[1];
|
||||
int64_t H = sizes[2];
|
||||
int64_t W = sizes[3];
|
||||
int64_t src_image_count = C * H * W;
|
||||
int64_t src_count = N * src_image_count;
|
||||
int64_t slices = (C + 3) / 4;
|
||||
int64_t numComponents = C < 3 ? C : 4;
|
||||
int64_t dst_image_count = slices * numComponents * W * H;
|
||||
int64_t dst_count = N * dst_image_count;
|
||||
std::vector<float> output(dst_count, 0.0f);
|
||||
for (int n = 0; n < N; ++n) {
|
||||
int64_t src_image = n * src_image_count;
|
||||
int64_t dst_image = n * dst_image_count;
|
||||
for (int i = 0; i < slices; ++i) {
|
||||
int64_t slice = i * W * H * numComponents;
|
||||
for (int j = 0; j < W * H; ++j) {
|
||||
for (int k = 0; k < numComponents; ++k) {
|
||||
int ii = src_image + slice + k * W * H + j;
|
||||
int oi = dst_image + slice + j * numComponents + k;
|
||||
if (k < C && ii < src_count) {
|
||||
output[oi] = src[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<float> NC4_to_NCHW(
|
||||
const float* src,
|
||||
const std::vector<int64_t>& sizes) {
|
||||
int64_t N = sizes[0];
|
||||
int64_t C = sizes[1];
|
||||
int64_t H = sizes[2];
|
||||
int64_t W = sizes[3];
|
||||
int64_t slices = (C + 3) / 4;
|
||||
int64_t numComponents = C < 3 ? C : 4;
|
||||
int64_t src_image_count = slices * numComponents * W * H;
|
||||
int64_t dst_image_count = C * H * W;
|
||||
int64_t dst_count = N * dst_image_count;
|
||||
std::vector<float> output(dst_count, 0.0f);
|
||||
for (int n = 0; n < N; ++n) {
|
||||
int64_t src_image = n * src_image_count;
|
||||
int64_t dst_image = n * dst_image_count;
|
||||
for (int i = 0; i < slices; ++i) {
|
||||
int64_t slice = i * W * H * numComponents;
|
||||
for (int j = 0; j < numComponents; ++j) {
|
||||
for (int k = 0; k < W * H; ++k) {
|
||||
int ii = src_image + slice + k * numComponents + j;
|
||||
int oi = dst_image + slice + j * W * H + k;
|
||||
if (j < C && oi < dst_count) {
|
||||
output[oi] = src[ii];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
std::vector<float> permuteWeights(
|
||||
const float* src,
|
||||
const std::vector<int64_t>& sizes) {
|
||||
const int64_t M = sizes[0];
|
||||
const int64_t Cf = sizes[1];
|
||||
const int64_t kH = sizes[2];
|
||||
const int64_t kW = sizes[3];
|
||||
std::vector<float> packedWeights(M * kH * kW * Cf);
|
||||
for (auto m = 0; m < M; ++m) {
|
||||
for (auto c = 0; c < Cf; ++c) {
|
||||
for (auto kh = 0; kh < kH; ++kh) {
|
||||
for (auto kw = 0; kw < kW; ++kw) {
|
||||
int64_t oc = m * kH * kW * Cf + kh * kW * Cf + kw * Cf + c;
|
||||
int64_t ic = m * Cf * kH * kW + c * kH * kW + kh * kW + kw;
|
||||
packedWeights[oc] = src[ic];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return packedWeights;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
31
aten/src/ATen/native/metal/mpscnn/MPSCNN.h
Normal file
31
aten/src/ATen/native/metal/mpscnn/MPSCNN.h
Normal file
@ -0,0 +1,31 @@
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
namespace mpscnn {
|
||||
|
||||
struct LaunchParams {
|
||||
MTLSize threadsPerThreadgroup;
|
||||
MTLSize threadgroupsPerGrid;
|
||||
MTLSize threadsPerGrid; // iOS 11.0
|
||||
};
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
LaunchParams spatialPointwiseKernelLaunchParams(
|
||||
id<MTLComputePipelineState> pipeline,
|
||||
MPSImage* im);
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
NSString* kernelFor(
|
||||
MPSImage* image,
|
||||
NSString* arrayKernel,
|
||||
NSString* nonArrayKernel);
|
||||
|
||||
int computeMPSAlignOffset(int kernel, int pad);
|
||||
|
||||
}
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
59
aten/src/ATen/native/metal/mpscnn/MPSCNN.mm
Normal file
59
aten/src/ATen/native/metal/mpscnn/MPSCNN.mm
Normal file
@ -0,0 +1,59 @@
|
||||
#import <ATen/native/metal/mpscnn/MPSCNN.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
namespace mpscnn {
|
||||
|
||||
auto divRoundUp(uint x, uint y) -> uint {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
int computeMPSAlignOffset(int kernel, int pad) {
|
||||
// To set the offset, we can just match the top-left pixel (in the input
|
||||
// image, with negative values for padding) that we look at. For 3x3s1p1, we
|
||||
// look at the (-1, -1) pixel in the original impl. For 3x3s1p0, we look at
|
||||
// (0, 0) pixel. For 3x3s1p2, look at (-2, -2) MPSCNN always looks at
|
||||
// (-floor(kernel_size - 1 / 2), -floor(kernel_size - 1 / 2)) Thus, we just
|
||||
// need to match this up.
|
||||
|
||||
// For 3x3s1p1, offset should be (0, 0)
|
||||
// For 3x3s1p0, offset should be (1, 1)
|
||||
// For 3x3s1p2, offset should be (-1, -1)
|
||||
const int mps_offset = kernel / 2;
|
||||
const int c2_offset = pad;
|
||||
return mps_offset - c2_offset;
|
||||
}
|
||||
|
||||
NSString* kernelFor(
|
||||
MPSImage* X,
|
||||
NSString* arrayKernel,
|
||||
NSString* nonArrayKernel) {
|
||||
if (X.featureChannels > 4 || X.numberOfImages > 1) {
|
||||
return arrayKernel;
|
||||
}
|
||||
return nonArrayKernel;
|
||||
}
|
||||
|
||||
LaunchParams spatialPointwiseKernelLaunchParams(
|
||||
id<MTLComputePipelineState> pipeline,
|
||||
MPSImage* im) {
|
||||
const auto threadsPerThreadgroup = MTLSizeMake(
|
||||
8 /* threadExecutionWidth */,
|
||||
4 /* maxThreadsPerThreadgroup / threadExecutionWidth */,
|
||||
1);
|
||||
const auto threadgroupsPerGrid = MTLSizeMake(
|
||||
divRoundUp(im.width, threadsPerThreadgroup.width),
|
||||
divRoundUp(im.height, threadsPerThreadgroup.height),
|
||||
im.numberOfImages * divRoundUp(im.featureChannels, 4));
|
||||
const auto threadsPerGrid = MTLSizeMake(
|
||||
im.width,
|
||||
im.height,
|
||||
im.numberOfImages * divRoundUp(im.featureChannels, 4));
|
||||
return {threadsPerThreadgroup, threadgroupsPerGrid, threadsPerGrid};
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
5
aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.h
Normal file
5
aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.h
Normal file
@ -0,0 +1,5 @@
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNOp.h>
|
||||
|
||||
@interface MPSCNNClampOp : NSObject<MPSCNNShaderOp>
|
||||
|
||||
@end
|
53
aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm
Normal file
53
aten/src/ATen/native/metal/mpscnn/MPSCNNClampOp.mm
Normal file
@ -0,0 +1,53 @@
|
||||
#import <ATen/native/metal/mpscnn/MPSCNN.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
|
||||
|
||||
@implementation MPSCNNClampOp {
|
||||
MPSImage* _X;
|
||||
MPSImage* _Y;
|
||||
NSNumber* _min;
|
||||
NSNumber* _max;
|
||||
}
|
||||
|
||||
+ (id<MPSCNNShaderOp>)newWithTextures:(NSArray<MPSImage*>*)textures
|
||||
Args:(NSArray<NSNumber*>*)args {
|
||||
MPSCNNClampOp* op = [MPSCNNClampOp new];
|
||||
op->_X = textures[0];
|
||||
op->_Y = textures[1];
|
||||
op->_min = args[0];
|
||||
op->_max = args[1];
|
||||
|
||||
return op;
|
||||
}
|
||||
|
||||
- (void)encode:(id<MTLCommandBuffer>)cb {
|
||||
/*
|
||||
`clamp(vector<half4>, float, float)` is not available on iOS 10.0,
|
||||
have to use `clamp(vector<half4>, half4, half4)` instead.
|
||||
*/
|
||||
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
pipelineState:at::native::metal::mpscnn::kernelFor(
|
||||
_X, @"clamp_half4", @"clamp_half4_nonarray")];
|
||||
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setTexture:[_X texture] atIndex:0];
|
||||
[encoder setTexture:[_Y texture] atIndex:1];
|
||||
id<MTLBuffer> clampBuffer = [[MPSCNNContext sharedInstance].device
|
||||
newBufferWithLength:2 * sizeof(fp16)
|
||||
options:MTLResourceOptionCPUCacheModeWriteCombined];
|
||||
fp16* clampBufferPtr = (fp16*)[clampBuffer contents];
|
||||
clampBufferPtr[0] = _min.floatValue;
|
||||
clampBufferPtr[1] = _max.floatValue;
|
||||
[encoder setBuffer:clampBuffer offset:0 atIndex:0];
|
||||
const auto& launchParams =
|
||||
at::native::metal::mpscnn::spatialPointwiseKernelLaunchParams(state, _Y);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
[_X markRead];
|
||||
[_Y markRead];
|
||||
}
|
||||
|
||||
@end
|
18
aten/src/ATen/native/metal/mpscnn/MPSCNNContext.h
Normal file
18
aten/src/ATen/native/metal/mpscnn/MPSCNNContext.h
Normal file
@ -0,0 +1,18 @@
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
@interface MPSCNNContext : NSObject
|
||||
@property(nonatomic, strong, readonly) id<MTLDevice> device;
|
||||
@property(nonatomic, strong, readonly) id<MTLCommandQueue> commandQueue;
|
||||
@property(nonatomic, strong, readonly) id<MTLLibrary> library;
|
||||
|
||||
+ (instancetype)sharedInstance;
|
||||
- (BOOL)available;
|
||||
- (id<MTLComputePipelineState>)pipelineState:(NSString*)kernel;
|
||||
- (id<MTLComputePipelineState>)specializedPipelineState:(NSString*)kernel
|
||||
Constants:(NSArray<NSNumber*>*)
|
||||
constants;
|
||||
|
||||
@end
|
121
aten/src/ATen/native/metal/mpscnn/MPSCNNContext.mm
Normal file
121
aten/src/ATen/native/metal/mpscnn/MPSCNNContext.mm
Normal file
@ -0,0 +1,121 @@
|
||||
#import <ATen/native/metal/MetalShaders.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
|
||||
#include <torch/script.h>
|
||||
#include <mutex>
|
||||
|
||||
#if defined(C10_IOS)
|
||||
#import <UIKit/UIKit.h>
|
||||
#endif
|
||||
|
||||
@implementation MPSCNNContext {
|
||||
std::mutex _pipelineCacheMutex;
|
||||
NSMutableDictionary<NSString*, id<MTLComputePipelineState>>* _pipelineCache;
|
||||
}
|
||||
|
||||
+ (instancetype)sharedInstance {
|
||||
static dispatch_once_t onceToken;
|
||||
static MPSCNNContext* instance = nil;
|
||||
dispatch_once(&onceToken, ^{
|
||||
instance = [[MPSCNNContext alloc] init];
|
||||
instance->_device = MTLCreateSystemDefaultDevice();
|
||||
instance->_library = [instance.device
|
||||
newLibraryWithSource:[NSString stringWithUTF8String:METAL_SHADERS]
|
||||
options:nil
|
||||
error:nil];
|
||||
instance->_commandQueue = [instance.device newCommandQueue];
|
||||
instance->_pipelineCache =
|
||||
[NSMutableDictionary<NSString*, id<MTLComputePipelineState>> new];
|
||||
});
|
||||
return instance;
|
||||
}
|
||||
|
||||
- (BOOL)available {
|
||||
#if defined(C10_IOS)
|
||||
#if TARGET_IPHONE_SIMULATOR
|
||||
return false;
|
||||
#else
|
||||
if (!MPSSupportsMTLDevice(_device)) {
|
||||
return false;
|
||||
}
|
||||
if ([UIDevice currentDevice].systemVersion.floatValue < 10.2) {
|
||||
return false;
|
||||
}
|
||||
if (![MTLCreateSystemDefaultDevice()
|
||||
supportsFeatureSet:MTLFeatureSet_iOS_GPUFamily3_v2]) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
return _device && _library && _commandQueue;
|
||||
}
|
||||
|
||||
- (id<MTLComputePipelineState>)pipelineState:(NSString*)kernel {
|
||||
TORCH_CHECK(_library, "Failed to load kernels");
|
||||
std::lock_guard<std::mutex> g(_pipelineCacheMutex);
|
||||
id<MTLComputePipelineState> state = _pipelineCache[kernel];
|
||||
if (state) {
|
||||
return state;
|
||||
}
|
||||
id<MTLFunction> func = [_library newFunctionWithName:kernel];
|
||||
TORCH_CHECK(func != nil, "Failed to load the kernel function", kernel);
|
||||
NSError* errors;
|
||||
state = [_device newComputePipelineStateWithFunction:func error:&errors];
|
||||
TORCH_CHECK(state != nil, errors.localizedDescription.UTF8String);
|
||||
_pipelineCache[kernel] = state;
|
||||
return state;
|
||||
}
|
||||
|
||||
- (id<MTLComputePipelineState>)specializedPipelineState:(NSString*)kernel
|
||||
Constants:(NSArray<NSNumber*>*)
|
||||
constants {
|
||||
TORCH_CHECK(_library, "Failed to load kernels");
|
||||
std::string kernelStr = std::string([kernel UTF8String]);
|
||||
for (auto i = 0; i < constants.count; ++i) {
|
||||
kernelStr += "_" + std::string([constants[i] stringValue].UTF8String);
|
||||
}
|
||||
std::lock_guard<std::mutex> g(_pipelineCacheMutex);
|
||||
id<MTLComputePipelineState> state = _pipelineCache[kernel];
|
||||
if (state) {
|
||||
return state;
|
||||
}
|
||||
MTLFunctionConstantValues* constantValues = [MTLFunctionConstantValues new];
|
||||
NSUInteger ushortArgIndex = 0;
|
||||
NSUInteger floatArgIndex = 10;
|
||||
for (auto i = 0; i < constants.count; ++i) {
|
||||
NSNumber* constant = constants[i];
|
||||
const char* type = constant.objCType;
|
||||
if (strcmp(type, @encode(NSUInteger)) == 0 ||
|
||||
strcmp(type, @encode(NSInteger)) == 0) {
|
||||
TORCH_CHECK(ushortArgIndex <= 10);
|
||||
ushort value = ushort([constant unsignedIntegerValue]);
|
||||
[constantValues setConstantValue:&value
|
||||
type:MTLDataTypeUShort
|
||||
atIndex:ushortArgIndex];
|
||||
ushortArgIndex++;
|
||||
}
|
||||
if (strcmp(type, @encode(float)) == 0 ||
|
||||
strcmp(type, @encode(double)) == 0) {
|
||||
TORCH_CHECK(floatArgIndex <= 2);
|
||||
float value = [constant floatValue];
|
||||
[constantValues setConstantValue:&value
|
||||
type:MTLDataTypeFloat
|
||||
atIndex:floatArgIndex];
|
||||
floatArgIndex++;
|
||||
}
|
||||
}
|
||||
NSError* errors;
|
||||
id<MTLFunction> func = [_library newFunctionWithName:kernel
|
||||
constantValues:constantValues
|
||||
error:&errors];
|
||||
TORCH_CHECK(
|
||||
func, "Couldn't get function: ", errors.localizedDescription.UTF8String);
|
||||
state = [_device newComputePipelineStateWithFunction:func error:&errors];
|
||||
TORCH_CHECK(state != nil, errors.localizedDescription.UTF8String);
|
||||
kernel = [NSString stringWithCString:kernelStr.c_str()
|
||||
encoding:NSUTF8StringEncoding];
|
||||
_pipelineCache[kernel] = state;
|
||||
return state;
|
||||
}
|
||||
|
||||
@end
|
23
aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.h
Normal file
23
aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.h
Normal file
@ -0,0 +1,23 @@
|
||||
#import <ATen/native/metal/MetalConvolution.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNOp.h>
|
||||
#import <Foundation/Foundation.h>
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
@interface MPSCNNConvDataSource : NSObject<MPSCNNConvolutionDataSource>
|
||||
@property(nonatomic, assign) void* weights;
|
||||
@property(nonatomic, assign) float* bias;
|
||||
|
||||
- (id)initWithWeights:(void*)weights
|
||||
Bias:(float*)bias
|
||||
Desc:(MPSCNNConvolutionDescriptor*)desc;
|
||||
|
||||
@end
|
||||
|
||||
using namespace at::native::metal;
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
@interface MPSCNNConvOp : NSObject<MPSCNNOp>
|
||||
+ (MPSCNNConvOp*)conv2d:(const Conv2DParams&)params
|
||||
weights:(float*)w
|
||||
bias:(float*)b
|
||||
neuronFilter:(NeuronType)t;
|
||||
@end
|
173
aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm
Normal file
173
aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm
Normal file
@ -0,0 +1,173 @@
|
||||
#import <ATen/native/metal/mpscnn/MPSCNN.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNConvOp.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
@implementation MPSCNNConvDataSource {
|
||||
void* _weights;
|
||||
float* _bias;
|
||||
MPSCNNConvolutionDescriptor* _descriptor;
|
||||
}
|
||||
|
||||
- (id)initWithWeights:(void*)weights
|
||||
Bias:(float*)bias
|
||||
Desc:(MPSCNNConvolutionDescriptor*)desc
|
||||
API_AVAILABLE(ios(10.0), macos(10.13)) {
|
||||
self = [super init];
|
||||
if (self) {
|
||||
_weights = (float*)weights;
|
||||
_bias = (float*)bias;
|
||||
_descriptor = desc;
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
- (nonnull id)copyWithZone:(nullable NSZone*)zone {
|
||||
MPSCNNConvDataSource* dataSource = [MPSCNNConvDataSource allocWithZone:zone];
|
||||
dataSource->_weights = _weights;
|
||||
dataSource->_bias = _bias;
|
||||
dataSource->_descriptor = _descriptor;
|
||||
return dataSource;
|
||||
}
|
||||
|
||||
- (float* _Nullable)biasTerms {
|
||||
return _bias;
|
||||
}
|
||||
|
||||
- (MPSDataType)dataType API_AVAILABLE(ios(10.0), macos(10.13)) {
|
||||
return MPSDataTypeFloat32;
|
||||
}
|
||||
|
||||
- (NSString* _Nullable)label {
|
||||
return @"";
|
||||
}
|
||||
|
||||
- (BOOL)load {
|
||||
return true;
|
||||
}
|
||||
|
||||
- (void)purge {
|
||||
_bias = nullptr;
|
||||
_weights = nullptr;
|
||||
}
|
||||
|
||||
- (void*)weights {
|
||||
return _weights;
|
||||
}
|
||||
|
||||
- (MPSCNNConvolutionDescriptor* _Nonnull)descriptor {
|
||||
return _descriptor;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation MPSCNNConvOp {
|
||||
}
|
||||
|
||||
@synthesize kernel = _kernel;
|
||||
|
||||
+ (MPSCNNConvOp*)conv2d:(const Conv2DParams&)params
|
||||
weights:(float*)w
|
||||
bias:(float*)b
|
||||
neuronFilter:(NeuronType)t API_AVAILABLE(ios(10.0), macos(10.13)) {
|
||||
using namespace at::native::metal::mpscnn;
|
||||
TORCH_CHECK(
|
||||
params.DX == params.DY == 1, "Dilated convolution is not supported yet.");
|
||||
const int64_t oC = params.OC;
|
||||
const int64_t iC = params.C;
|
||||
const int64_t kH = params.KH;
|
||||
const int64_t kW = params.KW;
|
||||
MPSCNNNeuron* neuron = [MPSCNNConvOp neuron:t];
|
||||
MPSCNNConvolutionDescriptor* desc = nil;
|
||||
if (params.isDepthwise()) {
|
||||
if (@available(iOS 11.0, *)) {
|
||||
desc = [MPSCNNDepthWiseConvolutionDescriptor
|
||||
cnnConvolutionDescriptorWithKernelWidth:kW
|
||||
kernelHeight:kH
|
||||
inputFeatureChannels:iC
|
||||
outputFeatureChannels:oC
|
||||
neuronFilter:neuron];
|
||||
desc.groups = 1;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"MPSCNNDepthWiseConvolutionDescriptor is only available on iOS 11.0 and above");
|
||||
}
|
||||
} else {
|
||||
if (params.G > 1) {
|
||||
TORCH_CHECK(
|
||||
params.IC % 4 == 0,
|
||||
"MPSCNNConvolution requires number of input \
|
||||
channels in each group to be multiple of 4 for \
|
||||
group > 1.");
|
||||
}
|
||||
desc = [MPSCNNConvolutionDescriptor
|
||||
cnnConvolutionDescriptorWithKernelWidth:kW
|
||||
kernelHeight:kH
|
||||
inputFeatureChannels:iC
|
||||
outputFeatureChannels:oC
|
||||
neuronFilter:neuron];
|
||||
desc.groups = params.G;
|
||||
}
|
||||
desc.strideInPixelsX = params.SX;
|
||||
desc.strideInPixelsY = params.SY;
|
||||
id<MPSCNNConvolutionDataSource> dataSource =
|
||||
[[MPSCNNConvDataSource alloc] initWithWeights:(float*)w
|
||||
Bias:(float*)b
|
||||
Desc:desc];
|
||||
MPSCNNConvolution* conv = nil;
|
||||
if (@available(iOS 11.0, *)) {
|
||||
conv = [[MPSCNNConvolution alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
weights:dataSource];
|
||||
|
||||
} else {
|
||||
#if TARGET_OS_IPHONE
|
||||
// Fallback on earlier versions
|
||||
conv = [[MPSCNNConvolution alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
convolutionDescriptor:desc
|
||||
kernelWeights:w
|
||||
biasTerms:b
|
||||
flags:MPSCNNConvolutionFlagsNone];
|
||||
#endif
|
||||
}
|
||||
[conv setEdgeMode:MPSImageEdgeModeZero];
|
||||
MPSOffset offset;
|
||||
offset.x = computeMPSAlignOffset(kW, params.PX);
|
||||
offset.y = computeMPSAlignOffset(kH, params.PY);
|
||||
offset.z = 0;
|
||||
[conv setOffset:offset];
|
||||
|
||||
TORCH_CHECK(conv.inputFeatureChannels == params.IC * params.G);
|
||||
TORCH_CHECK(oC % conv.groups == 0);
|
||||
TORCH_CHECK(conv.outputFeatureChannels == oC);
|
||||
TORCH_CHECK(conv.kernelWidth == kW);
|
||||
TORCH_CHECK(conv.kernelHeight == kH);
|
||||
|
||||
MPSCNNConvOp* op = [MPSCNNConvOp new];
|
||||
op->_kernel = conv;
|
||||
return op;
|
||||
}
|
||||
|
||||
- (void)encode:(id<MTLCommandBuffer>)cb
|
||||
sourceImage:(MPSImage*)src
|
||||
destinationImage:(MPSImage*)dst {
|
||||
[_kernel encodeToCommandBuffer:cb sourceImage:src destinationImage:dst];
|
||||
}
|
||||
|
||||
+ (MPSCNNNeuron*)neuron:(NeuronType)type {
|
||||
if (type == NeuronType::Relu) {
|
||||
return [MPSCNNNeuronOp relu];
|
||||
} else if (type == NeuronType::Sigmoid) {
|
||||
return [MPSCNNNeuronOp sigmoid];
|
||||
} else if (type == NeuronType::Tanh) {
|
||||
return [MPSCNNNeuronOp tanh];
|
||||
} else {
|
||||
return nil;
|
||||
}
|
||||
}
|
||||
|
||||
@end
|
12
aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h
Normal file
12
aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.h
Normal file
@ -0,0 +1,12 @@
|
||||
#import <ATen/native/metal/MetalConvolution.h>
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
using namespace at::native::metal;
|
||||
@interface MPSCNNNeuronOp : NSObject
|
||||
|
||||
+ (MPSCNNNeuronReLU*)relu;
|
||||
+ (MPSCNNNeuronSigmoid*)sigmoid;
|
||||
+ (MPSCNNNeuronTanH*)tanh;
|
||||
|
||||
@end
|
39
aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm
Normal file
39
aten/src/ATen/native/metal/mpscnn/MPSCNNNeuronOp.mm
Normal file
@ -0,0 +1,39 @@
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
|
||||
|
||||
@implementation MPSCNNNeuronOp
|
||||
|
||||
+ (MPSCNNNeuronReLU*)relu {
|
||||
static MPSCNNNeuronReLU* relu = nil;
|
||||
static dispatch_once_t onceToken;
|
||||
dispatch_once(&onceToken, ^{
|
||||
relu = [[MPSCNNNeuronReLU alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
a:0];
|
||||
});
|
||||
return relu;
|
||||
}
|
||||
|
||||
+ (MPSCNNNeuronSigmoid*)sigmoid {
|
||||
static dispatch_once_t onceToken;
|
||||
static MPSCNNNeuronSigmoid* sigmoid = nil;
|
||||
dispatch_once(&onceToken, ^{
|
||||
sigmoid = [[MPSCNNNeuronSigmoid alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device];
|
||||
});
|
||||
return sigmoid;
|
||||
}
|
||||
|
||||
+ (MPSCNNNeuronTanH*)tanh {
|
||||
static dispatch_once_t onceToken;
|
||||
static MPSCNNNeuronTanH* tanh = nil;
|
||||
dispatch_once(&onceToken, ^{
|
||||
tanh = [[MPSCNNNeuronTanH alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
a:1
|
||||
b:1];
|
||||
});
|
||||
return tanh;
|
||||
}
|
||||
|
||||
@end
|
27
aten/src/ATen/native/metal/mpscnn/MPSCNNOp.h
Normal file
27
aten/src/ATen/native/metal/mpscnn/MPSCNNOp.h
Normal file
@ -0,0 +1,27 @@
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
#if (defined(__ARM_NEON__) || defined(__ARM_NEON))
|
||||
typedef float16_t fp16;
|
||||
#else
|
||||
typedef uint16_t fp16;
|
||||
#endif
|
||||
|
||||
@protocol MPSCNNOp<NSObject>
|
||||
|
||||
@property(nonatomic, strong) MPSCNNKernel* kernel;
|
||||
|
||||
- (void)encode:(id<MTLCommandBuffer>)cb
|
||||
sourceImage:(MPSImage*)src
|
||||
destinationImage:(MPSImage*)dst;
|
||||
|
||||
@end
|
||||
|
||||
@protocol MPSCNNShaderOp<NSObject>
|
||||
|
||||
+ (id<MPSCNNShaderOp>)newWithTextures:(NSArray<MPSImage*>*)textures
|
||||
Args:(NSArray<NSNumber*>*)args;
|
||||
- (void)encode:(id<MTLCommandBuffer>)cb;
|
||||
|
||||
@end
|
62
aten/src/ATen/native/metal/mpscnn/MPSCNNOps.h
Normal file
62
aten/src/ATen/native/metal/mpscnn/MPSCNNOps.h
Normal file
@ -0,0 +1,62 @@
|
||||
#import <ATen/native/metal/MetalConvolution.h>
|
||||
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
namespace mpscnn {
|
||||
|
||||
Tensor conv2d(
|
||||
const Tensor& input, // metal
|
||||
const Tensor& weight, // cpu
|
||||
const c10::optional<at::Tensor>& bias, // cpu
|
||||
const Conv2DParams& params,
|
||||
NeuronType t = NeuronType::None);
|
||||
|
||||
// conv2d with prepacked weights
|
||||
Tensor conv2d(const Tensor& input, Conv2dOpContext& context);
|
||||
|
||||
Tensor max_pool2d(
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode);
|
||||
|
||||
Tensor global_avg_pool2d(const Tensor& input, IntArrayRef output_size);
|
||||
|
||||
Tensor relu(const Tensor& input);
|
||||
|
||||
Tensor sigmoid(const Tensor& input);
|
||||
|
||||
Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val);
|
||||
|
||||
Tensor t(const Tensor& input);
|
||||
|
||||
Tensor view(const Tensor& input, IntArrayRef size);
|
||||
|
||||
Tensor reshape(const Tensor& input, IntArrayRef shape);
|
||||
|
||||
Tensor addmm(const Tensor& bias, const Tensor& input, const Tensor& weight);
|
||||
|
||||
Tensor add(const Tensor& input1, const Tensor& input2);
|
||||
|
||||
Tensor sub(const Tensor& input1, const Tensor& input2);
|
||||
|
||||
Tensor mul(const Tensor& input1, const Tensor& input2);
|
||||
|
||||
Tensor log_softmax_int(const Tensor& input);
|
||||
|
||||
Tensor upsample_nearest2d_vec(
|
||||
const Tensor& input,
|
||||
c10::optional<IntArrayRef> output_size,
|
||||
c10::optional<ArrayRef<double>> scale_factors);
|
||||
|
||||
Tensor copy_to_host(const Tensor& input);
|
||||
|
||||
} // namespace mpscnn
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
544
aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm
Normal file
544
aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm
Normal file
@ -0,0 +1,544 @@
|
||||
#import <ATen/native/metal/MetalCommandBuffer.h>
|
||||
#import <ATen/native/metal/MetalTensor.h>
|
||||
#import <ATen/native/metal/MetalTensorImpl.h>
|
||||
#import <ATen/native/metal/MetalUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNN.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNClampOp.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNConvOp.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNOps.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImageWrapper.h>
|
||||
|
||||
#include <ATen/InferSize.h>
|
||||
#include <ATen/native/Pool.h>
|
||||
#include <ATen/native/UpSample.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
namespace mpscnn {
|
||||
|
||||
using MetalTensor = at::native::metal::MetalTensor;
|
||||
using MetalTensorImpl = at::MetalTensorImpl<MetalTensor>;
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
static inline MPSImage* imageFromMetalTensor(const MetalTensor& tensor) {
|
||||
return tensor.texture()->image();
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
static inline MPSImage* imageFromTensor(const Tensor& tensor) {
|
||||
TORCH_CHECK(tensor.is_metal());
|
||||
MetalTensorImpl* impl = (MetalTensorImpl*)tensor.unsafeGetTensorImpl();
|
||||
MetalTensor& metalTensor = impl->unsafe_opaque_handle();
|
||||
return imageFromMetalTensor(metalTensor);
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
static inline MetalCommandBuffer* commandBufferFromInputTensor(
|
||||
const Tensor& tensor) {
|
||||
TORCH_CHECK(tensor.is_metal());
|
||||
MetalTensorImpl* impl = (MetalTensorImpl*)tensor.unsafeGetTensorImpl();
|
||||
MetalTensor& metalTensor = impl->unsafe_opaque_handle();
|
||||
MetalCommandBuffer* cmdBuffer = metalTensor.texture()->commandBuffer();
|
||||
TORCH_CHECK(cmdBuffer, @"Command Buffer can't be nil!");
|
||||
return cmdBuffer;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor conv2d(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const c10::optional<at::Tensor>& bias,
|
||||
const Conv2DParams& params,
|
||||
NeuronType t) {
|
||||
TORCH_CHECK(weight.device().type() == kCPU);
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
const int64_t oC = weight.sizes()[0];
|
||||
const int64_t iC = weight.sizes()[1];
|
||||
const int64_t kH = weight.sizes()[2];
|
||||
const int64_t kW = weight.sizes()[3];
|
||||
auto packedWeights = at::native::metal::permuteWeights(
|
||||
weight.data_ptr<float>(), {oC, iC, kH, kW});
|
||||
// MPSCNN Convolution
|
||||
float* w = packedWeights.data();
|
||||
float* b = bias.has_value() ? bias->data_ptr<float>() : nullptr;
|
||||
MPSCNNConvOp* op = [MPSCNNConvOp conv2d:params
|
||||
weights:w
|
||||
bias:b
|
||||
neuronFilter:t];
|
||||
auto outputSize = params.output_sizes();
|
||||
MetalTensor mt{outputSize};
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
[op encode:commandBuffer.buffer sourceImage:X destinationImage:Y];
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor conv2d(const Tensor& input, Conv2dOpContext& context) {
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
Conv2DParams params{input.sizes(),
|
||||
context.weight.sizes(),
|
||||
context.padding,
|
||||
context.stride,
|
||||
context.dilation,
|
||||
context.groups};
|
||||
MPSCNNConvOp* op = (MPSCNNConvOp*)context.extra;
|
||||
NeuronType nt = neuronType(context);
|
||||
if (!op) {
|
||||
float* w = context.weight.data_ptr<float>();
|
||||
float* b = context.bias.has_value() ? ((*context.bias).data_ptr<float>())
|
||||
: nullptr;
|
||||
op = [MPSCNNConvOp conv2d:params weights:w bias:b neuronFilter:nt];
|
||||
context.extra = op;
|
||||
}
|
||||
|
||||
auto outputSize = params.output_sizes();
|
||||
MetalTensor mt{outputSize};
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer);
|
||||
MPSImage* Y1 = imageFromMetalTensor(mt);
|
||||
[op encode:commandBuffer.buffer sourceImage:X destinationImage:Y1];
|
||||
// fuse hardtanh with convolution
|
||||
if (nt == NeuronType::Clamp) {
|
||||
MPSImage* Y2 = [MPSImage temporaryImageFromSize:[Y1 sizes]
|
||||
commandBuffer:commandBuffer];
|
||||
float min = context.output_min.value().toFloat();
|
||||
float max = context.output_max.value().toFloat();
|
||||
MPSCNNClampOp* clampOp =
|
||||
[MPSCNNClampOp newWithTextures:@[ Y1, Y2 ] Args:@[ @(min), @(max) ]];
|
||||
[clampOp encode:commandBuffer.buffer];
|
||||
mt.texture()->copyFromTexture(Y2);
|
||||
}
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor max_pool2d(
|
||||
const Tensor& input,
|
||||
IntArrayRef kernel_size,
|
||||
IntArrayRef stride,
|
||||
IntArrayRef padding,
|
||||
IntArrayRef dilation,
|
||||
bool ceil_mode) {
|
||||
const int64_t iN = input.sizes()[0];
|
||||
const int64_t iC = input.sizes()[1];
|
||||
const int64_t iH = input.sizes()[2];
|
||||
const int64_t iW = input.sizes()[3];
|
||||
const int64_t kH = kernel_size[0];
|
||||
const int64_t kW = kernel_size[1];
|
||||
const int64_t sH = stride[0];
|
||||
const int64_t sW = stride[1];
|
||||
const int64_t pH = padding[0];
|
||||
const int64_t pW = padding[1];
|
||||
const int64_t dH = dilation[0];
|
||||
const int64_t dW = dilation[1];
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
MPSCNNPoolingMax* pool = [[MPSCNNPoolingMax alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
kernelWidth:kernel_size[0]
|
||||
kernelHeight:kernel_size[1]
|
||||
strideInPixelsX:stride[0]
|
||||
strideInPixelsY:stride[1]];
|
||||
[pool setEdgeMode:MPSImageEdgeModeClamp];
|
||||
[pool setOffset:{.x = kernel_size[0] / 2, .y = kernel_size[1] / 2, .z = 0}];
|
||||
|
||||
int64_t oN = iN;
|
||||
int64_t oC = iC;
|
||||
int64_t oH = pooling_output_shape(iH, kH, pH, sH, dH, ceil_mode);
|
||||
int64_t oW = pooling_output_shape(iW, kW, pW, sW, dW, ceil_mode);
|
||||
|
||||
std::vector<int64_t> outputSize{oN, oC, oH, oW};
|
||||
MetalTensor mt{outputSize};
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
[pool encodeToCommandBuffer:commandBuffer.buffer
|
||||
sourceImage:X
|
||||
destinationImage:Y];
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor global_avg_pool2d(const Tensor& input, IntArrayRef output_size) {
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
MPSCNNPoolingAverage* pool = [[MPSCNNPoolingAverage alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
kernelWidth:X.width
|
||||
kernelHeight:X.height
|
||||
strideInPixelsX:X.width
|
||||
strideInPixelsY:X.height];
|
||||
[pool setEdgeMode:MPSImageEdgeModeClamp];
|
||||
[pool setOffset:{.x = static_cast<NSInteger>(X.width / 2),
|
||||
.y = static_cast<NSInteger>(X.height / 2),
|
||||
.z = 0}];
|
||||
std::vector<int64_t> outputSize{
|
||||
input.sizes()[0], input.sizes()[1], output_size[0], output_size[1]};
|
||||
MetalTensor mt{outputSize};
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
[pool encodeToCommandBuffer:commandBuffer.buffer
|
||||
sourceImage:X
|
||||
destinationImage:Y];
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) {
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
std::vector<int64_t> outputSize = input.sizes().vec();
|
||||
std::vector<int64_t> textureSize = outputSize;
|
||||
if (input.dim() == 2) {
|
||||
textureSize = {outputSize[0], outputSize[1], 1, 1};
|
||||
}
|
||||
MetalTensor mt{outputSize};
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
mt.texture()->allocateTemporaryTextureStorage(textureSize, commandBuffer);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
[neuron encodeToCommandBuffer:commandBuffer.buffer
|
||||
sourceImage:X
|
||||
destinationImage:Y];
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor relu(const Tensor& input) {
|
||||
return neuronKernel(input, [MPSCNNNeuronOp relu]);
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor sigmoid(const Tensor& input) {
|
||||
return neuronKernel(input, [MPSCNNNeuronOp sigmoid]);
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor tanh(const Tensor& input) {
|
||||
return neuronKernel(input, [MPSCNNNeuronOp tanh]);
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val) {
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
MPSImage* Y = [MPSImage temporaryImageFromSize:input.sizes().vec()
|
||||
commandBuffer:commandBuffer];
|
||||
float min = min_val.toFloat();
|
||||
float max = max_val.toFloat();
|
||||
MPSCNNClampOp* clampOp = [MPSCNNClampOp newWithTextures:@[ X, Y ]
|
||||
Args:@[ @(min), @(max) ]];
|
||||
[clampOp encode:commandBuffer.buffer];
|
||||
MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
|
||||
MetalTensor& metalTensor = impl->unsafe_opaque_handle();
|
||||
metalTensor.texture()->copyFromTexture(Y);
|
||||
return input;
|
||||
}
|
||||
|
||||
/*
|
||||
A fully connected layer takes an MPSImage object with dimensions source.width x
|
||||
source.height x Ni, convolves it with
|
||||
Weights[No][source.width][source.height][Ni],and produces a 1 x 1 x No output.
|
||||
|
||||
Thus, the following conditions must be true:
|
||||
kernelWidth == source.width
|
||||
kernelHeight == source.height
|
||||
clipRect.size.width == 1
|
||||
clipRect.size.height == 1
|
||||
|
||||
You can think of a fully connected layer as a matrix multiplication
|
||||
where the image is flattened into a vector of length
|
||||
source.width*source.height*Ni, and the weights are arranged in a matrix of
|
||||
dimension No x (source.width*source.height*Ni) to produce an output vector of
|
||||
length No
|
||||
|
||||
The value of the strideInPixelsX, strideInPixelsY, and groups properties must
|
||||
be 1. The offset property is not applicable and it is ignored. Because the clip
|
||||
rectangle is clamped to the destination image bounds, if the destination is 1 x
|
||||
1, you do not need to set the clipRect property.
|
||||
*/
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor addmm(const Tensor& bias, const Tensor& input, const Tensor& weight) {
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
const int64_t N = X.numberOfImages;
|
||||
const int64_t oC = weight.sizes()[0];
|
||||
const int64_t kH = X.height;
|
||||
const int64_t kW = X.width;
|
||||
const int64_t iC = weight.sizes()[1] / kH / kW;
|
||||
auto packedWeights = at::native::metal::permuteWeights(
|
||||
weight.data_ptr<float>(), {oC, iC, kH, kW});
|
||||
MPSCNNConvolutionDescriptor* desc =
|
||||
[MPSCNNConvolutionDescriptor cnnConvolutionDescriptorWithKernelWidth:kW
|
||||
kernelHeight:kH
|
||||
inputFeatureChannels:iC
|
||||
outputFeatureChannels:oC
|
||||
neuronFilter:nil];
|
||||
desc.strideInPixelsX = 1;
|
||||
desc.strideInPixelsY = 1;
|
||||
MPSCNNConvDataSource* ds = [[MPSCNNConvDataSource alloc]
|
||||
initWithWeights:packedWeights.data()
|
||||
Bias:bias.defined() ? bias.data_ptr<float>() : nil
|
||||
Desc:desc];
|
||||
MPSCNNFullyConnected* fc = nil;
|
||||
if (@available(iOS 11.0, *)) {
|
||||
fc = [[MPSCNNFullyConnected alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
weights:ds];
|
||||
} else {
|
||||
#if TARGET_OS_IPHONE
|
||||
fc = [[MPSCNNFullyConnected alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
convolutionDescriptor:desc
|
||||
kernelWeights:(float*)packedWeights.data()
|
||||
biasTerms:bias.defined() ? bias.data_ptr<float>() : nil
|
||||
flags:MPSCNNConvolutionFlagsNone];
|
||||
#endif
|
||||
}
|
||||
[fc setClipRect:MTLRegionMake3D(0, 0, 0, 1, 1, N)];
|
||||
[fc setOffset:{.x = static_cast<NSInteger>(X.width / 2),
|
||||
.y = static_cast<NSInteger>(X.height / 2),
|
||||
.z = 0}];
|
||||
std::vector<int64_t> outputSize = {N, oC, 1, 1};
|
||||
MetalTensor mt{{N, oC}};
|
||||
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
mt.texture()->allocateTemporaryTextureStorage(outputSize, commandBuffer);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
[fc encodeToCommandBuffer:commandBuffer.buffer
|
||||
sourceImage:X
|
||||
destinationImage:Y];
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor binaryElementwiseKernel(
|
||||
const Tensor& input1,
|
||||
const Tensor& input2,
|
||||
NSString* arrayKernel,
|
||||
NSString* nonarrayKernal) {
|
||||
MPSImage* X1 = imageFromTensor(input1);
|
||||
MPSImage* X2 = imageFromTensor(input2);
|
||||
std::vector<int64_t> outputSize = input1.sizes().vec();
|
||||
MetalTensor mt{outputSize};
|
||||
MetalCommandBuffer* cb1 = commandBufferFromInputTensor(input1);
|
||||
MetalCommandBuffer* cb2 = commandBufferFromInputTensor(input2);
|
||||
TORCH_CHECK([cb1 isEqual:cb2], @"inputs have different command buffer");
|
||||
mt.texture()->allocateTemporaryTextureStorage(outputSize, cb1);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
pipelineState:kernelFor(X1, arrayKernel, nonarrayKernal)];
|
||||
id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setTexture:[X1 texture] atIndex:0];
|
||||
[encoder setTexture:[X2 texture] atIndex:1];
|
||||
[encoder setTexture:[Y texture] atIndex:2];
|
||||
const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
[X1 markRead];
|
||||
[X2 markRead];
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input1.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor add(const Tensor& input1, const Tensor& input2) {
|
||||
return binaryElementwiseKernel(
|
||||
input1, input2, @"elementwise_add", @"elementwise_add_nonarray");
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor sub(const Tensor& input1, const Tensor& input2) {
|
||||
return binaryElementwiseKernel(
|
||||
input1, input2, @"elementwise_sub", @"elementwise_sub_nonarray");
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor mul(const Tensor& input1, const Tensor& input2) {
|
||||
return binaryElementwiseKernel(
|
||||
input1, input2, @"elementwise_mul", @"elementwise_mul_nonarray");
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor t(const Tensor& input) {
|
||||
auto strides = input.strides().vec();
|
||||
auto sizes = input.sizes().vec();
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
TORCH_CHECK(X.numberOfImages == 1);
|
||||
TORCH_CHECK(X.featureChannels == 1);
|
||||
MetalTensor mt({sizes[1], sizes[0]}, {strides[1], strides[0]});
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
mt.texture()->allocateTemporaryTextureStorage(
|
||||
{1, 1, sizes[1], sizes[0]}, commandBuffer);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
MPSImageTranspose* transpose = [[MPSImageTranspose alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device];
|
||||
[transpose encodeToCommandBuffer:commandBuffer.buffer
|
||||
sourceImage:X
|
||||
destinationImage:Y];
|
||||
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor view(const Tensor& input, IntArrayRef size) {
|
||||
auto inferred_size = at::infer_size(size, input.numel());
|
||||
auto stride =
|
||||
at::detail::computeStride(input.sizes(), input.strides(), inferred_size);
|
||||
TORCH_CHECK(
|
||||
stride.has_value(),
|
||||
"view size is "
|
||||
"not compatible with input tensor's size and stride (at least one dimension"
|
||||
" spans across two contiguous subspaces). Use .reshape(...) instead.");
|
||||
auto stride_value = *stride;
|
||||
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
MetalTensor mt{inferred_size, stride_value};
|
||||
mt.texture()->setCommandBuffer(commandBuffer);
|
||||
mt.texture()->copyFromTexture(X);
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor reshape(const Tensor& input, IntArrayRef shape) {
|
||||
return view(input, shape);
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor log_softmax_int(const Tensor& input) {
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
TORCH_CHECK(X.height == 1 && X.width == 1);
|
||||
std::vector<int64_t> outputSize = input.sizes().vec();
|
||||
MPSCNNLogSoftMax* logSoftmax = [[MPSCNNLogSoftMax alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device];
|
||||
|
||||
MetalTensor mt{outputSize};
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
mt.texture()->allocateTemporaryTextureStorage(
|
||||
{outputSize[0], outputSize[1], 1, 1}, commandBuffer);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
[logSoftmax encodeToCommandBuffer:commandBuffer.buffer
|
||||
sourceImage:X
|
||||
destinationImage:Y];
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
API_AVAILABLE(ios(10.0), macos(10.13))
|
||||
Tensor upsample_nearest2d_vec(
|
||||
const Tensor& input,
|
||||
c10::optional<IntArrayRef> output_size,
|
||||
c10::optional<ArrayRef<double>> scale_factors) {
|
||||
auto osize =
|
||||
upsample::compute_output_size(input.sizes(), output_size, scale_factors);
|
||||
auto scale_h = upsample::get_scale_value(scale_factors, 0);
|
||||
auto scale_w = upsample::get_scale_value(scale_factors, 1);
|
||||
int64_t output_height = osize[0];
|
||||
int64_t output_width = osize[1];
|
||||
int64_t nbatch = input.size(0);
|
||||
int64_t channels = input.size(1);
|
||||
int64_t input_height = input.size(2);
|
||||
int64_t input_width = input.size(3);
|
||||
upsample_2d_shape_check(
|
||||
input,
|
||||
Tensor(),
|
||||
nbatch,
|
||||
channels,
|
||||
input_height,
|
||||
input_width,
|
||||
output_height,
|
||||
output_width);
|
||||
std::vector<int64_t> outputSizes{
|
||||
nbatch, channels, output_height, output_width};
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
MetalTensor mt{outputSizes};
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
mt.texture()->allocateTemporaryTextureStorage(outputSizes, commandBuffer);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
if (@available(iOS 11.0, *)) {
|
||||
MPSCNNUpsamplingNearest* kernel = [[MPSCNNUpsamplingNearest alloc]
|
||||
initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
integerScaleFactorX:(NSUInteger)scale_w.value()
|
||||
integerScaleFactorY:(NSUInteger)scale_h.value()];
|
||||
[kernel encodeToCommandBuffer:commandBuffer.buffer
|
||||
sourceImage:X
|
||||
destinationImage:Y];
|
||||
} else {
|
||||
NSUInteger sh = scale_h.value() * 10000;
|
||||
NSUInteger sw = scale_w.value() * 10000;
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
specializedPipelineState:kernelFor(
|
||||
Y,
|
||||
@"resize_nearest",
|
||||
@"resize_nearest_nonarray")
|
||||
Constants:@[
|
||||
@(output_height),
|
||||
@(output_width),
|
||||
@(sh),
|
||||
@(sw)
|
||||
]];
|
||||
id<MTLComputeCommandEncoder> encoder =
|
||||
[commandBuffer.buffer computeCommandEncoder];
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setTexture:[X texture] atIndex:0];
|
||||
[encoder setTexture:[Y texture] atIndex:1];
|
||||
const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
[X markRead];
|
||||
[Y markRead];
|
||||
}
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor copy_to_host(const Tensor& input) {
|
||||
MPSImage* X = imageFromTensor(input);
|
||||
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
|
||||
auto&& sizes = [X sizes];
|
||||
MetalTensor mt{sizes};
|
||||
mt.texture()->setCommandBuffer(commandBuffer);
|
||||
mt.texture()->allocateTextureStorage(sizes);
|
||||
MPSImage* Y = imageFromMetalTensor(mt);
|
||||
id<MTLComputeCommandEncoder> encoder =
|
||||
[commandBuffer.buffer computeCommandEncoder];
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
specializedPipelineState:metal::mpscnn::kernelFor(
|
||||
X, @"copy", @"copy_nonarray")
|
||||
Constants:@[
|
||||
@(X.featureChannels),
|
||||
@(X.height),
|
||||
@(X.width)
|
||||
]];
|
||||
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setTexture:[X texture] atIndex:0];
|
||||
[encoder setTexture:[Y texture] atIndex:1];
|
||||
|
||||
const auto& launchParams =
|
||||
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
[X markRead];
|
||||
auto output = MetalTensor::toTensor(std::move(mt), input.options());
|
||||
return output;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
47
aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.h
Normal file
47
aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.h
Normal file
@ -0,0 +1,47 @@
|
||||
#include <ATen/Tensor.h>
|
||||
#import <ATen/native/metal/MetalCommandBuffer.h>
|
||||
#import <Metal/Metal.h>
|
||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
@interface MPSImage (Tensor)
|
||||
|
||||
+ (MPSImage*)imageFromCPUTensor:(const at::Tensor&)tensor;
|
||||
- (at::Tensor)toCPUTensor;
|
||||
|
||||
+ (MPSImage*)imageFromFp16Array:(const uint16_t*)src
|
||||
Sizes:(const std::vector<int64_t>&)sizes;
|
||||
- (std::vector<uint16_t>)toFp16Array;
|
||||
|
||||
+ (MPSImage*)imageFromSize:(const std::vector<int64_t>&)size;
|
||||
+ (MPSTemporaryImage*)temporaryImageFromSize:(const std::vector<int64_t>&)size
|
||||
commandBuffer:(MetalCommandBuffer*)cmdBuffer;
|
||||
|
||||
- (std::vector<int64_t>)sizes;
|
||||
- (int64_t)readCount;
|
||||
- (BOOL)isTemporaryImage;
|
||||
- (void)markRead;
|
||||
- (void)recycle;
|
||||
|
||||
@end
|
||||
|
||||
@interface MPSImage (Shaders)
|
||||
|
||||
+ (MPSImage*)imageFromImage:(MPSImage*)image;
|
||||
|
||||
+ (MPSTemporaryImage*)temporaryImageFromImage:(MPSImage*)image
|
||||
CommandBuffer:(MetalCommandBuffer*)cb;
|
||||
|
||||
+ (MPSImage*)imageFromTemporaryImage:(MPSTemporaryImage*)image
|
||||
CommandBuffer:(MetalCommandBuffer*)cb
|
||||
waitUntilCompleted:(BOOL)b;
|
||||
|
||||
+ (MPSImage*)imageFromHost:(const float*)src
|
||||
Sizes:(const std::vector<int64_t>&)sizes;
|
||||
|
||||
+ (MPSTemporaryImage*)temporaryImageFromHost:(const float*)src
|
||||
Sizes:(const std::vector<int64_t>&)sizes
|
||||
CommandBuffer:(MetalCommandBuffer*)cb;
|
||||
|
||||
+ (void)copyToHost:(float*)dst FromImage:(MPSImage*)image;
|
||||
|
||||
@end
|
337
aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.mm
Normal file
337
aten/src/ATen/native/metal/mpscnn/MPSImage+Tensor.mm
Normal file
@ -0,0 +1,337 @@
|
||||
#include <ATen/native/metal/MetalUtils.h>
|
||||
#include <ATen/native/metal/mpscnn/MPSCNN.h>
|
||||
#include <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
#include <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
|
||||
|
||||
#include <torch/script.h>
|
||||
|
||||
using namespace at::native;
|
||||
@implementation MPSImage (Tensor)
|
||||
|
||||
+ (MPSImage*)imageFromCPUTensor:(const at::Tensor&)tensor {
|
||||
TORCH_CHECK(tensor.device().is_cpu());
|
||||
TORCH_CHECK(tensor.dim() == 4);
|
||||
auto contiguousTensor = tensor.contiguous();
|
||||
float* src = tensor.data_ptr<float>();
|
||||
std::vector<int64_t> sizes = tensor.sizes().vec();
|
||||
auto c4 = metal::NCHW_to_NC4(src, sizes);
|
||||
auto c4fp16 = metal::fp32_to_fp16(c4);
|
||||
return [self imageFromFp16Array:c4fp16.data() Sizes:sizes];
|
||||
}
|
||||
|
||||
+ (MPSImage*)imageFromFp16Array:(const uint16_t*)src
|
||||
Sizes:(const std::vector<int64_t>&)sizes {
|
||||
int64_t N = sizes[0];
|
||||
int64_t C = sizes[1];
|
||||
int64_t H = sizes[2];
|
||||
int64_t W = sizes[3];
|
||||
MPSImageDescriptor* desc = [MPSImageDescriptor
|
||||
imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16
|
||||
width:W
|
||||
height:H
|
||||
featureChannels:C
|
||||
numberOfImages:N
|
||||
usage:MTLTextureUsageShaderRead |
|
||||
MTLTextureUsageShaderWrite];
|
||||
MPSImage* image =
|
||||
[[MPSImage alloc] initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
imageDescriptor:desc];
|
||||
|
||||
int64_t slices = (C + 3) / 4 * N;
|
||||
int64_t numComponents = image.featureChannels < 3 ? image.featureChannels : 4;
|
||||
int64_t bytesPerRow = W * numComponents * sizeof(uint16_t);
|
||||
uint8_t* ptr = (uint8_t*)src;
|
||||
for (int i = 0; i < slices; ++i) {
|
||||
[image.texture replaceRegion:MTLRegionMake2D(0, 0, W, H)
|
||||
mipmapLevel:0
|
||||
slice:i
|
||||
withBytes:ptr
|
||||
bytesPerRow:bytesPerRow
|
||||
bytesPerImage:0];
|
||||
ptr += H * bytesPerRow;
|
||||
}
|
||||
return image;
|
||||
}
|
||||
|
||||
+ (MPSImage*)imageFromSize:(const std::vector<int64_t>&)size {
|
||||
MPSImageDescriptor* desc = [MPSImageDescriptor
|
||||
imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16
|
||||
width:size[3]
|
||||
height:size[2]
|
||||
featureChannels:size[1]
|
||||
numberOfImages:size[0]
|
||||
usage:MTLTextureUsageShaderRead |
|
||||
MTLTextureUsageShaderWrite];
|
||||
return [[MPSImage alloc] initWithDevice:[MPSCNNContext sharedInstance].device
|
||||
imageDescriptor:desc];
|
||||
}
|
||||
|
||||
- (std::vector<uint16_t>)toFp16Array {
|
||||
if (self.pixelFormat == MTLPixelFormatR16Float ||
|
||||
self.pixelFormat == MTLPixelFormatRG16Float ||
|
||||
self.pixelFormat == MTLPixelFormatRGBA16Float) {
|
||||
int64_t slices = (self.featureChannels + 3) / 4;
|
||||
int64_t C = self.featureChannels < 3 ? self.featureChannels : slices * 4;
|
||||
int64_t numComponents = self.featureChannels < 3 ? self.featureChannels : 4;
|
||||
int64_t count = self.width * self.height * self.numberOfImages * C;
|
||||
std::vector<uint16_t> output(count, 0);
|
||||
int64_t bytesPerRow = self.width * numComponents * sizeof(uint16_t);
|
||||
uint8_t* buffer = (uint8_t*)output.data();
|
||||
for (int i = 0; i < slices * self.numberOfImages; ++i) {
|
||||
[self.texture getBytes:buffer
|
||||
bytesPerRow:bytesPerRow
|
||||
bytesPerImage:0
|
||||
fromRegion:MTLRegionMake2D(0, 0, self.width, self.height)
|
||||
mipmapLevel:0
|
||||
slice:i];
|
||||
buffer += self.height * bytesPerRow;
|
||||
}
|
||||
return output;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
false, "Copy to float buffer failed: The pixel format didn't match");
|
||||
return {};
|
||||
}
|
||||
|
||||
- (at::Tensor)toCPUTensor {
|
||||
auto outputSize = [self sizes];
|
||||
std::vector<uint16_t> fp16 = [self toFp16Array];
|
||||
auto fp32 = metal::fp16_to_fp32(fp16);
|
||||
std::vector<float> fp32_nchw = metal::NC4_to_NCHW(fp32.data(), outputSize);
|
||||
auto tensor = at::empty(outputSize);
|
||||
int64_t size_bytes = at::prod_intlist(outputSize) * sizeof(float);
|
||||
memcpy(tensor.data_ptr(), fp32_nchw.data(), size_bytes);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
- (std::vector<int64_t>)sizes {
|
||||
int64_t N = self.numberOfImages;
|
||||
int64_t C = self.featureChannels;
|
||||
int64_t H = self.height;
|
||||
int64_t W = self.width;
|
||||
return {N, C, H, W};
|
||||
}
|
||||
|
||||
+ (MPSTemporaryImage*)temporaryImageFromSize:(const std::vector<int64_t>&)size
|
||||
commandBuffer:(MetalCommandBuffer*)cmdBuffer {
|
||||
NSCAssert(cmdBuffer, @"CommandBuffer is nil!");
|
||||
MPSImageDescriptor* desc = [MPSImageDescriptor
|
||||
imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat16
|
||||
width:size[3]
|
||||
height:size[2]
|
||||
featureChannels:size[1]
|
||||
numberOfImages:size[0]
|
||||
usage:MTLTextureUsageShaderRead |
|
||||
MTLTextureUsageShaderWrite];
|
||||
MPSTemporaryImage* image =
|
||||
[MPSTemporaryImage temporaryImageWithCommandBuffer:cmdBuffer.buffer
|
||||
imageDescriptor:desc];
|
||||
image.readCount = INT_MAX;
|
||||
[cmdBuffer add:image];
|
||||
return image;
|
||||
}
|
||||
|
||||
- (BOOL)isTemporaryImage {
|
||||
return [self isKindOfClass:[MPSTemporaryImage class]];
|
||||
}
|
||||
|
||||
- (void)markRead {
|
||||
if ([self isTemporaryImage]) {
|
||||
MPSTemporaryImage* tmpImage = (MPSTemporaryImage*)self;
|
||||
if (tmpImage.readCount > 0) {
|
||||
tmpImage.readCount -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- (void)recycle {
|
||||
if ([self isTemporaryImage]) {
|
||||
MPSTemporaryImage* tmpImage = (MPSTemporaryImage*)self;
|
||||
if (tmpImage.readCount > 0) {
|
||||
tmpImage.readCount = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
- (int64_t)readCount {
|
||||
if ([self isTemporaryImage]) {
|
||||
MPSTemporaryImage* tmpImage = (MPSTemporaryImage*)self;
|
||||
return (int64_t)tmpImage.readCount;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
@end
|
||||
|
||||
@implementation MPSImage (Shaders)
|
||||
|
||||
+ (MPSImage*)imageFromImage:(MPSImage*)X {
|
||||
auto&& sizes = [X sizes];
|
||||
MPSImage* Y = [MPSImage imageFromSize:sizes];
|
||||
MetalCommandBuffer* cb = [MetalCommandBuffer newBuffer];
|
||||
id<MTLComputeCommandEncoder> encoder = [cb.buffer computeCommandEncoder];
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
pipelineState:metal::mpscnn::kernelFor(X, @"copy", @"copy_nonarray")];
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setTexture:[X texture] atIndex:0];
|
||||
[encoder setTexture:[Y texture] atIndex:1];
|
||||
|
||||
const auto& launchParams =
|
||||
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
[cb synchronize];
|
||||
return Y;
|
||||
}
|
||||
|
||||
+ (MPSTemporaryImage*)temporaryImageFromImage:(MPSImage*)X
|
||||
CommandBuffer:(MetalCommandBuffer*)cb {
|
||||
NSCAssert(cb, @"CommandBuffer is nil!");
|
||||
MPSTemporaryImage* Y = [MPSImage temporaryImageFromSize:[X sizes]
|
||||
commandBuffer:cb];
|
||||
id<MTLComputeCommandEncoder> encoder = [cb.buffer computeCommandEncoder];
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
pipelineState:metal::mpscnn::kernelFor(X, @"copy", @"copy_nonarray")];
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setTexture:[X texture] atIndex:0];
|
||||
[encoder setTexture:[Y texture] atIndex:1];
|
||||
|
||||
const auto& launchParams =
|
||||
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
return Y;
|
||||
}
|
||||
|
||||
+ (MPSImage*)imageFromTemporaryImage:(MPSTemporaryImage*)X
|
||||
CommandBuffer:(MetalCommandBuffer*)cb
|
||||
waitUntilCompleted:(BOOL)b {
|
||||
NSCAssert(cb, @"CommandBuffer is nil!");
|
||||
auto&& sizes = [X sizes];
|
||||
MPSImage* Y = [MPSImage imageFromSize:sizes];
|
||||
id<MTLComputeCommandEncoder> encoder = [cb.buffer computeCommandEncoder];
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
pipelineState:metal::mpscnn::kernelFor(X, @"copy", @"copy_nonarray")];
|
||||
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setTexture:[X texture] atIndex:0];
|
||||
[encoder setTexture:[Y texture] atIndex:1];
|
||||
|
||||
const auto& launchParams =
|
||||
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, X);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
[X markRead];
|
||||
if (b) {
|
||||
[cb synchronize];
|
||||
}
|
||||
return Y;
|
||||
}
|
||||
|
||||
+ (MPSImage*)imageFromHost:(const float*)src
|
||||
Sizes:(const std::vector<int64_t>&)sizes {
|
||||
int64_t size_bytes = at::prod_intlist(sizes) * sizeof(float);
|
||||
// allocte buffer on CPU
|
||||
id<MTLBuffer> buff = [[MPSCNNContext sharedInstance].device
|
||||
newBufferWithLength:size_bytes
|
||||
options:MTLResourceOptionCPUCacheModeWriteCombined];
|
||||
memcpy(buff.contents, src, size_bytes);
|
||||
MPSImage* output = [MPSImage imageFromSize:sizes];
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
specializedPipelineState:metal::mpscnn::kernelFor(
|
||||
output,
|
||||
@"copy_nchw_to_metal",
|
||||
@"copy_nchw_to_metal_nonarray")
|
||||
Constants:@[
|
||||
@(output.featureChannels),
|
||||
@(output.height),
|
||||
@(output.width)
|
||||
]];
|
||||
MetalCommandBuffer* cb = [MetalCommandBuffer newBuffer];
|
||||
id<MTLComputeCommandEncoder> encoder = [cb.buffer computeCommandEncoder];
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setBuffer:buff offset:0 atIndex:0];
|
||||
[encoder setTexture:[output texture] atIndex:0];
|
||||
const auto& launchParams =
|
||||
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, output);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
[cb synchronize];
|
||||
return output;
|
||||
}
|
||||
|
||||
+ (MPSTemporaryImage*)temporaryImageFromHost:(const float*)src
|
||||
Sizes:(const std::vector<int64_t>&)sizes
|
||||
CommandBuffer:(MetalCommandBuffer*)cb {
|
||||
NSCAssert(cb, @"CommandBuffer is nil!");
|
||||
int64_t size_bytes = at::prod_intlist(sizes) * sizeof(float);
|
||||
// allocte buffer on CPU
|
||||
id<MTLBuffer> buff = [[MPSCNNContext sharedInstance].device
|
||||
newBufferWithLength:size_bytes
|
||||
options:MTLResourceOptionCPUCacheModeWriteCombined];
|
||||
memcpy(buff.contents, src, size_bytes);
|
||||
MPSTemporaryImage* output = [MPSImage temporaryImageFromSize:sizes
|
||||
commandBuffer:cb];
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
specializedPipelineState:metal::mpscnn::kernelFor(
|
||||
output,
|
||||
@"copy_nchw_to_metal",
|
||||
@"copy_nchw_to_metal_nonarray")
|
||||
Constants:@[
|
||||
@(output.featureChannels),
|
||||
@(output.height),
|
||||
@(output.width)
|
||||
]];
|
||||
id<MTLComputeCommandEncoder> encoder = [cb.buffer computeCommandEncoder];
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setBuffer:buff offset:0 atIndex:0];
|
||||
[encoder setTexture:[output texture] atIndex:0];
|
||||
const auto& launchParams =
|
||||
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, output);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
[output markRead];
|
||||
return output;
|
||||
}
|
||||
|
||||
+ (void)copyToHost:(float*)dst FromImage:(MPSImage*)image {
|
||||
auto&& sizes = [image sizes];
|
||||
int64_t size_bytes = at::prod_intlist(sizes) * sizeof(float);
|
||||
id<MTLBuffer> buffer = [[MPSCNNContext sharedInstance].device
|
||||
newBufferWithLength:size_bytes
|
||||
options:MTLResourceOptionCPUCacheModeDefault];
|
||||
|
||||
id<MTLCommandBuffer> cb =
|
||||
[MPSCNNContext sharedInstance].commandQueue.commandBuffer;
|
||||
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
|
||||
id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
|
||||
specializedPipelineState:metal::mpscnn::kernelFor(
|
||||
image,
|
||||
@"copy_metal_to_nchw",
|
||||
@"copy_metal_to_nchw_nonarray")
|
||||
Constants:@[
|
||||
@(image.featureChannels),
|
||||
@(image.height),
|
||||
@(image.width)
|
||||
]];
|
||||
|
||||
[encoder setComputePipelineState:state];
|
||||
[encoder setBuffer:buffer offset:0 atIndex:0];
|
||||
[encoder setTexture:[image texture] atIndex:0];
|
||||
|
||||
const auto& launchParams =
|
||||
metal::mpscnn::spatialPointwiseKernelLaunchParams(state, image);
|
||||
[encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
|
||||
threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
|
||||
[encoder endEncoding];
|
||||
[cb commit];
|
||||
[cb waitUntilCompleted];
|
||||
memcpy(dst, buffer.contents, buffer.length);
|
||||
}
|
||||
|
||||
@end
|
49
aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.h
Normal file
49
aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.h
Normal file
@ -0,0 +1,49 @@
|
||||
#ifndef MPSImageWrapper_h
|
||||
#define MPSImageWrapper_h
|
||||
|
||||
#import <ATen/native/metal/MetalCommandBuffer.h>
|
||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
#include <torch/script.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
enum class TextureType {
|
||||
TextureNone,
|
||||
TextureType2D,
|
||||
TextureType2DArray,
|
||||
};
|
||||
|
||||
class API_AVAILABLE(ios(10.0), macos(10.13)) MPSImageWrapper {
|
||||
public:
|
||||
MPSImageWrapper(IntArrayRef sizes);
|
||||
operator bool() const {
|
||||
return _image;
|
||||
}
|
||||
void copyDataFromHost(const float* inputData);
|
||||
void copyDataToHost(float* hostData);
|
||||
void allocateTextureStorage(IntArrayRef sizes);
|
||||
void allocateTemporaryTextureStorage(
|
||||
IntArrayRef sizes,
|
||||
MetalCommandBuffer* commandBuffer);
|
||||
void copyFromTexture(MPSImage* image);
|
||||
void setCommandBuffer(MetalCommandBuffer* buffer);
|
||||
MetalCommandBuffer* commandBuffer() const;
|
||||
TextureType textureType() const;
|
||||
IntArrayRef textureSizes() const;
|
||||
MPSImage* image() const;
|
||||
void recycleImage();
|
||||
void synchronize();
|
||||
|
||||
private:
|
||||
std::vector<int64_t> _textureSizes;
|
||||
MPSImage* _image = nullptr;
|
||||
MetalCommandBuffer* _commandBuffer;
|
||||
};
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif /* MPSImageWrapper_h */
|
116
aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm
Normal file
116
aten/src/ATen/native/metal/mpscnn/MPSImageWrapper.mm
Normal file
@ -0,0 +1,116 @@
|
||||
#import <ATen/native/metal/MetalCommandBuffer.h>
|
||||
#import <ATen/native/metal/MetalUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNN.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNContext.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImageWrapper.h>
|
||||
|
||||
#include <numeric>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
std::vector<int64_t> textureSizeFromSizes(IntArrayRef sizes, TextureType type) {
|
||||
if (sizes.size() == 2) {
|
||||
if (type == TextureType::TextureType2DArray) {
|
||||
return {sizes[0], sizes[1], 1, 1};
|
||||
} else if (type == TextureType::TextureType2D) {
|
||||
return {1, 1, sizes[0], sizes[1]};
|
||||
} else {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
return sizes.vec();
|
||||
}
|
||||
MPSImageWrapper::MPSImageWrapper(IntArrayRef sizes) {
|
||||
_textureSizes = textureSizeFromSizes(sizes, TextureType::TextureType2D);
|
||||
}
|
||||
|
||||
void MPSImageWrapper::copyDataFromHost(const float* inputData) {
|
||||
TORCH_CHECK(inputData);
|
||||
TORCH_CHECK(_textureSizes.size() == 4);
|
||||
_commandBuffer = [MetalCommandBuffer currentBuffer];
|
||||
_image = [MPSImage temporaryImageFromHost:inputData
|
||||
Sizes:_textureSizes
|
||||
CommandBuffer:_commandBuffer];
|
||||
}
|
||||
|
||||
void MPSImageWrapper::copyDataToHost(float* hostData) {
|
||||
TORCH_CHECK(_image);
|
||||
synchronize();
|
||||
[MPSImage copyToHost:hostData FromImage:_image];
|
||||
}
|
||||
|
||||
MPSImage* MPSImageWrapper::image() const {
|
||||
return _image;
|
||||
}
|
||||
|
||||
void MPSImageWrapper::recycleImage() {
|
||||
if ([_image isTemporaryImage]) {
|
||||
[_image recycle];
|
||||
[_commandBuffer remove:(MPSTemporaryImage*)_image];
|
||||
}
|
||||
}
|
||||
|
||||
void MPSImageWrapper::setCommandBuffer(MetalCommandBuffer* cb) {
|
||||
_commandBuffer = cb;
|
||||
}
|
||||
MetalCommandBuffer* MPSImageWrapper::commandBuffer() const {
|
||||
return _commandBuffer;
|
||||
}
|
||||
|
||||
IntArrayRef MPSImageWrapper::textureSizes() const {
|
||||
return _textureSizes;
|
||||
}
|
||||
|
||||
TextureType MPSImageWrapper::textureType() const {
|
||||
if (!_image) {
|
||||
return TextureType::TextureNone;
|
||||
}
|
||||
MTLTextureType textureType = _image.textureType;
|
||||
if (textureType == MTLTextureType2D) {
|
||||
return TextureType::TextureType2D;
|
||||
} else if (textureType == MTLTextureType2DArray) {
|
||||
return TextureType::TextureType2DArray;
|
||||
}
|
||||
return TextureType::TextureNone;
|
||||
}
|
||||
|
||||
void MPSImageWrapper::allocateTextureStorage(IntArrayRef sizes) {
|
||||
_textureSizes = sizes.vec();
|
||||
_image = [MPSImage imageFromSize:_textureSizes];
|
||||
}
|
||||
|
||||
void MPSImageWrapper::allocateTemporaryTextureStorage(
|
||||
IntArrayRef sizes,
|
||||
MetalCommandBuffer* commandBuffer) {
|
||||
TORCH_CHECK(commandBuffer)
|
||||
_textureSizes = sizes.vec();
|
||||
_commandBuffer = commandBuffer;
|
||||
_image = [MPSImage temporaryImageFromSize:_textureSizes
|
||||
commandBuffer:commandBuffer];
|
||||
}
|
||||
|
||||
void MPSImageWrapper::copyFromTexture(MPSImage* image) {
|
||||
if ([image isTemporaryImage]) {
|
||||
_image = [MPSImage temporaryImageFromImage:image
|
||||
CommandBuffer:_commandBuffer];
|
||||
} else {
|
||||
_image = [MPSImage imageFromImage:image];
|
||||
}
|
||||
}
|
||||
|
||||
void MPSImageWrapper::synchronize() {
|
||||
if ([_image isTemporaryImage]) {
|
||||
_image = [MPSImage imageFromTemporaryImage:(MPSTemporaryImage*)_image
|
||||
CommandBuffer:_commandBuffer
|
||||
waitUntilCompleted:NO];
|
||||
}
|
||||
[_commandBuffer synchronize];
|
||||
_commandBuffer = nil;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
35
aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
Normal file
35
aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.h
Normal file
@ -0,0 +1,35 @@
|
||||
#ifndef MPSCNNTests_h
|
||||
#define MPSCNNTests_h
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
bool test_aten();
|
||||
bool test_NC4();
|
||||
bool test_MPSImage();
|
||||
bool test_MPSImageCopy();
|
||||
bool test_MPSTemporaryImageCopy();
|
||||
bool test_conv2d();
|
||||
bool test_depthwiseConv();
|
||||
bool test_max_pool2d();
|
||||
bool test_relu();
|
||||
bool test_addmm();
|
||||
bool test_add();
|
||||
bool test_sub();
|
||||
bool test_mul();
|
||||
bool test_t();
|
||||
bool test_view();
|
||||
bool test_softmax();
|
||||
bool test_sigmoid();
|
||||
bool test_upsampling_nearest2d_vec();
|
||||
bool test_adaptive_avg_pool2d();
|
||||
bool test_hardtanh_();
|
||||
bool test_reshape();
|
||||
bool test_mobilenetv2();
|
||||
|
||||
} // namespace metal
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
||||
#endif
|
552
aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
Normal file
552
aten/src/ATen/native/metal/mpscnn/tests/MPSCNNTests.mm
Normal file
@ -0,0 +1,552 @@
|
||||
#import <ATen/native/metal/MetalConvolution.h>
|
||||
#import <ATen/native/metal/MetalUtils.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSCNNOps.h>
|
||||
#import <ATen/native/metal/mpscnn/MPSImage+Tensor.h>
|
||||
#import <ATen/native/metal/mpscnn/tests/MPSCNNTests.h>
|
||||
#import <Foundation/Foundation.h>
|
||||
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/TypeDefault.h>
|
||||
#import <ATen/native/metal/mpscnn/tests/MPSCNNTests.h>
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <torch/script.h>
|
||||
#include <sstream>
|
||||
|
||||
#define ITER_COUNT 10
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace metal {
|
||||
|
||||
int64_t rand(int64_t min, int64_t max) {
|
||||
return min + (std::rand() % static_cast<int64_t>(max - min + 1));
|
||||
}
|
||||
|
||||
bool checkRtol(const at::Tensor& diff, const std::vector<at::Tensor> inputs) {
|
||||
double maxValue = 0.0;
|
||||
for (auto& tensor : inputs) {
|
||||
maxValue = fmax(tensor.abs().max().item<float>(), maxValue);
|
||||
}
|
||||
return diff.abs().max().item<float>() < (0.01 + 2e-2 * maxValue);
|
||||
}
|
||||
bool almostEqual(const at::Tensor& a, const at::Tensor& b) {
|
||||
return checkRtol(a - b, {a, b});
|
||||
}
|
||||
|
||||
bool almostEqualTensor(const at::Tensor& a, const at::Tensor& b, float t) {
|
||||
if (a.sizes() != b.sizes()) {
|
||||
return false;
|
||||
}
|
||||
if (a.numel() != b.numel()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < a.numel(); ++i) {
|
||||
float x1 = a.data_ptr<float>()[i];
|
||||
float x2 = b.data_ptr<float>()[i];
|
||||
if (std::abs(x1 - x2) > t) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool almostEqualVec(
|
||||
const std::vector<float> vec1,
|
||||
const std::vector<float> vec2,
|
||||
float t) {
|
||||
if (vec1.size() != vec2.size()) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < vec1.size(); ++i) {
|
||||
if (std::abs(vec1[i] - vec2[i]) > t) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void print(bool cond, NSString* name, const std::vector<int64_t>& sizes) {
|
||||
NSMutableString* strSizes = [NSMutableString new];
|
||||
std::for_each(sizes.begin(), sizes.end(), ^(int64_t n) {
|
||||
[strSizes appendString:[NSString stringWithFormat:@"%lld,", n]];
|
||||
});
|
||||
void (^print)(NSString*) = ^(NSString* str) {
|
||||
NSLog(@"[TEST_%@], [%@], [%@]", name, strSizes, str);
|
||||
};
|
||||
cond ? print(@"SUCCEED") : print(@"FAILED");
|
||||
}
|
||||
|
||||
bool test_aten() {
|
||||
auto x1 =
|
||||
at::rand({1, 2, 2, 2}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto mx1 = x1.metal();
|
||||
TORCH_CHECK(mx1.device().type() == at::kMetal);
|
||||
auto x2 = mx1.cpu();
|
||||
TORCH_CHECK(x2.device().type() == at::kCPU);
|
||||
bool b = almostEqual(x1, x2);
|
||||
print(b, @"ATEN", {1, 2, 2, 2});
|
||||
return b;
|
||||
}
|
||||
|
||||
bool test_NC4() {
|
||||
#define TEST_NC4(n, c, h, w) \
|
||||
{ \
|
||||
auto t = \
|
||||
at::rand({n, c, h, w}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto b = std::vector<float>{t.data_ptr<float>(), \
|
||||
t.data_ptr<float>() + n * c * h * w}; \
|
||||
auto c4 = NCHW_to_NC4((float*)t.data_ptr<float>(), t.sizes().vec()); \
|
||||
auto n4 = NC4_to_NCHW((float*)c4.data(), t.sizes().vec()); \
|
||||
if (n4 == b) { \
|
||||
print(true, @"NC4", {n, c, h, w}); \
|
||||
} else { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
for (int i = 0; i < ITER_COUNT; ++i) {
|
||||
int64_t N = rand(1, 24);
|
||||
int64_t C = rand(1, 48);
|
||||
int64_t H = rand(1, 320);
|
||||
int64_t W = rand(1, 320);
|
||||
std::vector<int64_t> x{N, C, H, W};
|
||||
auto t = at::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto b = std::vector<float>{t.data_ptr<float>(),
|
||||
t.data_ptr<float>() + N * C * H * W};
|
||||
auto c4 = NCHW_to_NC4((float*)t.data_ptr<float>(), t.sizes().vec());
|
||||
auto n4 = NC4_to_NCHW((float*)c4.data(), t.sizes().vec());
|
||||
if (n4 == b) {
|
||||
print(true, @"NC4", x);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_MPSImage() API_AVAILABLE(ios(10.0), macos(10.13)) {
|
||||
#define TEST_MPS_IMAGE(n, c, h, w) \
|
||||
{ \
|
||||
auto t1 = \
|
||||
at::rand({n, c, h, w}, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto b = std::vector<float>{t1.data_ptr<float>(), \
|
||||
t1.data_ptr<float>() + n * c * h * w}; \
|
||||
MPSImage* img = [MPSImage imageFromCPUTensor:t1]; \
|
||||
auto t2 = [img toCPUTensor]; \
|
||||
bool result = almostEqual(t1, t2); \
|
||||
if (result) { \
|
||||
print(result, @"MPS_IMAGE", {n, c, h, w}); \
|
||||
} else { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
for (int i = 0; i < ITER_COUNT; ++i) {
|
||||
int64_t N = rand(1, 24);
|
||||
int64_t C = rand(1, 48);
|
||||
int64_t H = rand(1, 320);
|
||||
int64_t W = rand(1, 320);
|
||||
std::vector<int64_t> x{N, C, H, W};
|
||||
auto t1 = at::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto b = std::vector<float>{t1.data_ptr<float>(),
|
||||
t1.data_ptr<float>() + N * C * H * W};
|
||||
MPSImage* img = [MPSImage imageFromCPUTensor:t1];
|
||||
auto t2 = [img toCPUTensor];
|
||||
bool result = almostEqual(t1, t2);
|
||||
if (result) {
|
||||
print(result, @"MPS_IMAGE", {N, C, H, W});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_MPSImageCopy() {
|
||||
std::vector<int64_t> sz{2, 3, 1, 1};
|
||||
auto t1 = at::rand(sz, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
float* src = t1.data_ptr<float>();
|
||||
MPSImage* im = [MPSImage imageFromHost:src Sizes:t1.sizes().vec()];
|
||||
MPSImage* cim = [MPSImage imageFromImage:im];
|
||||
auto t2 = [cim toCPUTensor];
|
||||
bool b = almostEqual(t1, t2);
|
||||
print(b, @"MPSImageCopy", sz);
|
||||
return b;
|
||||
}
|
||||
|
||||
bool test_MPSTemporaryImageCopy() {
|
||||
std::vector<int64_t> sz{2, 3, 1, 1};
|
||||
auto t1 = at::rand(sz, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
MetalCommandBuffer* cb = [MetalCommandBuffer newBuffer];
|
||||
float* src = t1.data_ptr<float>();
|
||||
MPSTemporaryImage* tim = [MPSImage temporaryImageFromHost:src
|
||||
Sizes:t1.sizes().vec()
|
||||
CommandBuffer:cb];
|
||||
MPSImage* im = [MPSImage imageFromTemporaryImage:tim
|
||||
CommandBuffer:cb
|
||||
waitUntilCompleted:YES];
|
||||
auto t2 = [im toCPUTensor];
|
||||
bool b = almostEqual(t1, t2);
|
||||
print(b, @"MPSTemporaryImageCopy", sz);
|
||||
return b;
|
||||
}
|
||||
|
||||
bool test_conv2d() {
|
||||
#define ARRAY(...) __VA_ARGS__
|
||||
#define TEST_CONV2D(x, w, b, pad) \
|
||||
{ \
|
||||
auto X = torch::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto W = torch::rand(w, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto B = torch::rand(b, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto S = c10::IntArrayRef{1, 1}; \
|
||||
auto P = c10::IntArrayRef(pad); \
|
||||
auto D = c10::IntArrayRef{1, 1}; \
|
||||
int64_t groups = 1; \
|
||||
auto Y1 = at::native::conv2d(X, W, B, S, P, D, groups); \
|
||||
auto X2 = X.metal(); \
|
||||
at::native::metal::Conv2DParams params{ \
|
||||
X.sizes(), W.sizes(), P, S, D, groups}; \
|
||||
auto Y2 = at::native::metal::mpscnn::conv2d(X2, W, B, params).cpu(); \
|
||||
bool check = almostEqual(Y1, Y2); \
|
||||
if (check) { \
|
||||
print(check, @"CONV2D", x); \
|
||||
} else { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
for (int i = 0; i < ITER_COUNT; ++i) {
|
||||
int64_t N = rand(1, 10);
|
||||
int64_t C = rand(1, 48);
|
||||
int64_t IH = rand(1, 300);
|
||||
int64_t IW = rand(1, 300);
|
||||
int64_t OC = rand(1, 48);
|
||||
int64_t IC = C;
|
||||
int64_t KH = rand(1, MIN(10, IH));
|
||||
int64_t KW = rand(1, MIN(10, IW));
|
||||
int64_t PH = rand(1, 10);
|
||||
int64_t PW = rand(1, 10);
|
||||
int64_t SH = rand(1, 10);
|
||||
int64_t SW = rand(1, 10);
|
||||
auto X = torch::rand(
|
||||
{N, C, IH, IW}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto W = torch::rand(
|
||||
{OC, IC, KH, KW}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto B = torch::rand({OC}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto S = c10::IntArrayRef({SH, SW});
|
||||
auto P = c10::IntArrayRef({PH, PW});
|
||||
auto D =
|
||||
c10::IntArrayRef({1, 1}); // Dilated convolution is not supported yet
|
||||
int64_t groups = 1;
|
||||
auto Y1 = at::native::conv2d(X, W, B, S, P, D, groups);
|
||||
auto X2 = X.metal();
|
||||
at::native::metal::Conv2DParams params{
|
||||
X.sizes(), W.sizes(), P, S, D, groups};
|
||||
auto Y2 = at::native::metal::mpscnn::conv2d(X2, W, B, params).cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
if (check) {
|
||||
print(check, @"CONV2D", {N, C, IH, IW});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_depthwiseConv() {
|
||||
#define ARRAY(...) __VA_ARGS__
|
||||
#define TEST_DEPTHWISECONV(x, w, b, p, g) \
|
||||
{ \
|
||||
auto S = c10::IntArrayRef{1, 1}; \
|
||||
auto D = c10::IntArrayRef{1, 1}; \
|
||||
auto OP = c10::IntArrayRef({0, 0}); \
|
||||
auto X = torch::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto W = torch::rand(w, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto B = torch::rand(b, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto Y1 = at::native::_convolution( \
|
||||
X, W, B, S, p, D, false, OP, g, false, false, true, true); \
|
||||
auto X2 = X.metal(); \
|
||||
at::native::metal::Conv2DParams params{X.sizes(), W.sizes(), p, S, D, g}; \
|
||||
auto Y2 = at::native::metal::mpscnn::conv2d(X2, W, B, params).cpu(); \
|
||||
bool check = almostEqual(Y1, Y2); \
|
||||
if (check) { \
|
||||
print(check, @"DEPTHWISECONV", x); \
|
||||
} else { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
|
||||
TEST_DEPTHWISECONV(
|
||||
ARRAY({1, 32, 112, 112}),
|
||||
ARRAY({32, 1, 3, 3}),
|
||||
ARRAY({32}),
|
||||
ARRAY({1, 1}),
|
||||
32);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_max_pool2d() {
|
||||
auto X =
|
||||
torch::rand({1, 3, 4, 4}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = at::native::max_pool2d(X, {2, 2}, {2, 2}, {0, 0}, {1, 1}, false);
|
||||
auto X2 = X.metal();
|
||||
auto Y2 = at::native::metal::mpscnn::max_pool2d(
|
||||
X2, {2, 2}, {2, 2}, {0, 0}, {1, 1}, false)
|
||||
.cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
if (check) {
|
||||
print(check, @"MAX_POOL2D", {1, 3, 4, 4});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_relu() {
|
||||
auto X =
|
||||
torch::rand({1, 3, 4, 4}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = torch::native::relu(X);
|
||||
auto X2 = X.metal();
|
||||
auto Y2 = torch::native::metal::mpscnn::relu(X2).cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
if (check) {
|
||||
print(check, @"RELU", {1, 3, 4, 4});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_sigmoid() {
|
||||
auto X =
|
||||
torch::rand({1, 3, 4, 4}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = torch::native::sigmoid(X);
|
||||
auto X2 = X.metal();
|
||||
auto Y2 = torch::native::metal::mpscnn::sigmoid(X2).cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
print(check, @"SIGMOID", {1, 3, 4, 4});
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_addmm() {
|
||||
#define ARRAY(...) __VA_ARGS__
|
||||
#define TEST_ADDMM(x, w, b) \
|
||||
{ \
|
||||
auto X1 = torch::rand(x, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto W1 = torch::rand(w, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto B = torch::rand(b, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto Y1 = at::native::addmm_cpu(B, X1, W1); \
|
||||
auto X2 = X1.metal(); \
|
||||
auto W2 = W1.t().view({W1.sizes()[1], W1.sizes()[0], 1, 1}).contiguous(); \
|
||||
auto Y2 = at::native::metal::mpscnn::addmm(B, X2, W2).cpu(); \
|
||||
bool check = almostEqual(Y1, Y2); \
|
||||
if (check) { \
|
||||
print(check, @"ADDMM", x); \
|
||||
} else { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
for (int i = 0; i < ITER_COUNT; ++i) {
|
||||
int64_t N = rand(1, 10);
|
||||
int64_t IC = rand(1, 128);
|
||||
int64_t OC = rand(1, 128);
|
||||
auto X1 =
|
||||
torch::rand({N, IC}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto W1 =
|
||||
torch::rand({IC, OC}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto B1 =
|
||||
torch::rand({1, OC}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = at::native::addmm_cpu(B1, X1, W1);
|
||||
// MPSCNNFullyConnected
|
||||
auto X2 = X1.view({N, IC, 1, 1}).contiguous().metal();
|
||||
auto W2 = W1.t()
|
||||
.view({W1.sizes()[1], W1.sizes()[0], 1, 1})
|
||||
.contiguous(); // W2 lives in CPU
|
||||
auto Y2 = mpscnn::addmm(B1, X2, W2).cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
if (check) {
|
||||
print(check, @"ADDMM", {N, IC});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_add() {
|
||||
#define ARRAY(...) __VA_ARGS__
|
||||
#define TEST_ADD(a1, a2) \
|
||||
{ \
|
||||
auto X1 = torch::rand(a1, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto X2 = torch::rand(a2, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto Y1 = at::native::add(X1, X2); \
|
||||
auto MX1 = X1.metal(); \
|
||||
auto MX2 = X2.metal(); \
|
||||
auto Y2 = at::native::metal::mpscnn::add(MX1, MX2).cpu(); \
|
||||
bool check = almostEqual(Y1, Y2); \
|
||||
if (check) { \
|
||||
print(check, @"ADD", a1); \
|
||||
} else { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
TEST_ADD(ARRAY({1, 180, 12, 12}), ARRAY({1, 180, 12, 12}));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_sub() {
|
||||
#define ARRAY(...) __VA_ARGS__
|
||||
#define TEST_SUB(a1, a2) \
|
||||
{ \
|
||||
auto X1 = torch::rand(a1, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto X2 = torch::rand(a2, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto Y1 = at::native::sub(X1, X2); \
|
||||
auto MX1 = X1.metal(); \
|
||||
auto MX2 = X2.metal(); \
|
||||
auto Y2 = at::native::metal::mpscnn::sub(MX1, MX2).cpu(); \
|
||||
bool check = almostEqual(Y1, Y2); \
|
||||
if (check) { \
|
||||
print(check, @"SUB", a1); \
|
||||
} else { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
TEST_SUB(ARRAY({1, 3, 192, 192}), ARRAY({1, 3, 1, 1}));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_mul() {
|
||||
#define ARRAY(...) __VA_ARGS__
|
||||
#define TEST_MUL(a1, a2) \
|
||||
{ \
|
||||
auto X1 = torch::rand(a1, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto X2 = torch::rand(a2, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto Y1 = at::native::mul(X1, X2); \
|
||||
auto MX1 = X1.metal(); \
|
||||
auto MX2 = X2.metal(); \
|
||||
auto Y2 = at::native::metal::mpscnn::mul(MX1, MX2).cpu(); \
|
||||
bool check = almostEqual(Y1, Y2); \
|
||||
if (check) { \
|
||||
print(check, @"MUL", a1); \
|
||||
} else { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
TEST_MUL(ARRAY({1, 3, 192, 192}), ARRAY({1, 3, 1, 1}));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_t() {
|
||||
#define ARRAY(...) __VA_ARGS__
|
||||
#define TEST_TRANSPOSE(a1) \
|
||||
{ \
|
||||
auto X1 = torch::rand(a1, at::TensorOptions(at::kCPU).dtype(at::kFloat)); \
|
||||
auto Y1 = at::native::t(X1).contiguous(); \
|
||||
auto X2 = X1.metal(); \
|
||||
auto Y2 = at::native::metal::mpscnn::t(X2).cpu(); \
|
||||
bool check = almostEqual(Y1, Y2); \
|
||||
if (check) { \
|
||||
print(check, @"TRANSPOSE_2D", a1); \
|
||||
} else { \
|
||||
return false; \
|
||||
} \
|
||||
}
|
||||
for (int i = 0; i < ITER_COUNT; ++i) {
|
||||
int64_t H = rand(1, 256);
|
||||
int64_t W = rand(1, 256);
|
||||
auto X1 =
|
||||
torch::rand({H, W}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = at::native::t(X1).contiguous();
|
||||
auto X2 = X1.metal();
|
||||
auto Y2 = at::native::metal::mpscnn::t(X2).cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
if (check) {
|
||||
print(check, @"TRANSPOSE_2D", {H, W});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_view() {
|
||||
auto X1 =
|
||||
torch::rand({1, 3, 2, 2}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = X1.view({3, 4}).contiguous();
|
||||
auto X2 = X1.metal();
|
||||
auto Y2 = at::native::metal::mpscnn::view(X2, {3, 4}).cpu();
|
||||
bool b1 = (Y1.sizes() == Y2.sizes());
|
||||
bool b2 = (Y1.strides() == Y2.strides());
|
||||
if (b1 && b2) {
|
||||
print(true, @"VIEW", {1, 3, 2, 2});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool test_softmax() {
|
||||
auto X1 =
|
||||
torch::rand({2, 3, 1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = torch::native::log_softmax(X1, 1);
|
||||
auto X2 = X1.metal();
|
||||
auto Y2 = torch::native::metal::mpscnn::log_softmax_int(X2).cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
print(check, @"SOFTMAX", {2, 3, 1, 1});
|
||||
return check;
|
||||
}
|
||||
|
||||
bool test_upsampling_nearest2d_vec() {
|
||||
auto X1 = torch::rand(
|
||||
{1, 48, 24, 24}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = torch::native::upsample_nearest2d_cpu(
|
||||
X1,
|
||||
c10::optional<IntArrayRef>({}),
|
||||
c10::optional<ArrayRef<double>>({2, 2}));
|
||||
auto X2 = X1.metal();
|
||||
auto Y2 = torch::native::metal::mpscnn::upsample_nearest2d_vec(
|
||||
X2,
|
||||
c10::optional<IntArrayRef>({}),
|
||||
c10::optional<ArrayRef<double>>({2, 2}))
|
||||
.cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
print(check, @"UPSAMPLING_NEAREST2D", {1, 48, 24, 24});
|
||||
return check;
|
||||
}
|
||||
|
||||
bool test_adaptive_avg_pool2d() {
|
||||
auto X1 = torch::rand(
|
||||
{1, 48, 24, 24}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = at::native::adaptive_avg_pool2d(X1, {1, 1});
|
||||
auto X2 = X1.metal();
|
||||
auto Y2 = torch::native::metal::mpscnn::global_avg_pool2d(X2, {1, 1}).cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
print(check, @"ADAPTIVE_AVG_POOL2D", {1, 48, 24, 24});
|
||||
return check;
|
||||
}
|
||||
|
||||
bool test_reshape() {
|
||||
auto X1 = torch::rand(
|
||||
{1, 1280, 1, 1}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = at::native::reshape(X1, {1, -1});
|
||||
auto X2 = X1.metal();
|
||||
auto Y2 = torch::native::metal::mpscnn::reshape(X2, {1, -1}).cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
print(check, @"RESHAPE", {1, 1280, 1, 1});
|
||||
return check;
|
||||
}
|
||||
|
||||
bool test_hardtanh_() {
|
||||
auto X1 = torch::rand(
|
||||
{1, 32, 112, 112}, at::TensorOptions(at::kCPU).dtype(at::kFloat));
|
||||
auto Y1 = at::native::hardtanh_(X1, 0, 6.0);
|
||||
auto X2 = X1.metal();
|
||||
auto Y2 = at::native::metal::mpscnn::hardtanh_(X2, 0, 6.0).cpu();
|
||||
bool check = almostEqual(Y1, Y2);
|
||||
print(check, @"HARDTANH_", {1, 32, 112, 112});
|
||||
return check;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
@ -329,6 +329,9 @@ class CAFFE2_API Tensor {
|
||||
/// Returns if a `Tensor` is vulkan tensor.
|
||||
bool is_vulkan() const;
|
||||
|
||||
/// Returns if a `Tensor` is metal tensor.
|
||||
bool is_metal() const;
|
||||
|
||||
/// Returns if a `Tensor` has quantized backend.
|
||||
bool is_quantized() const;
|
||||
|
||||
@ -447,6 +450,7 @@ class CAFFE2_API Tensor {
|
||||
Tensor cuda() const;
|
||||
Tensor hip() const;
|
||||
Tensor vulkan() const;
|
||||
Tensor metal() const;
|
||||
|
||||
// ~~~~~ Autograd API ~~~~~
|
||||
|
||||
|
@ -31,6 +31,10 @@ Tensor Tensor::vulkan() const {
|
||||
return to(options().device(DeviceType::Vulkan), /*non_blocking*/ false, /*copy*/ false);
|
||||
}
|
||||
|
||||
Tensor Tensor::metal() const {
|
||||
return to(options().device(DeviceType::Metal), /*non_blocking*/ false, /*copy*/ false);
|
||||
}
|
||||
|
||||
Tensor Tensor::toType(ScalarType t) const {
|
||||
return to(options().dtype(t), /*non_blocking*/ false, /*copy*/ false);
|
||||
}
|
||||
@ -127,10 +131,20 @@ bool Tensor::is_vulkan() const {
|
||||
return impl_->is_vulkan();
|
||||
}
|
||||
|
||||
bool Tensor::is_metal() const {
|
||||
// NB: this is not a native function to avoid dispatching overhead.
|
||||
return impl_->is_metal();
|
||||
}
|
||||
|
||||
|
||||
bool is_vulkan(Tensor self) {
|
||||
return self.is_vulkan();
|
||||
}
|
||||
|
||||
bool is_metal(Tensor self) {
|
||||
return self.is_metal();
|
||||
}
|
||||
|
||||
bool Tensor::is_quantized() const {
|
||||
// NB: this is not a native function to avoid dispatching overhead.
|
||||
return impl_->is_quantized();
|
||||
|
@ -15,9 +15,10 @@
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include <sstream>
|
||||
#include "torch/script.h"
|
||||
#include "torch/csrc/jit/api/module.h"
|
||||
#include <torch/csrc/jit/passes/metal_rewrite.h>
|
||||
#include "torch/csrc/jit/passes/vulkan_rewrite.h"
|
||||
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
|
||||
#include "torch/csrc/jit/serialization/import.h"
|
||||
@ -29,6 +30,7 @@ C10_DEFINE_string(
|
||||
"",
|
||||
"Name of the output model to be saved.");
|
||||
C10_DEFINE_string(backend, "", "The backend to be optimized");
|
||||
C10_DEFINE_string(preserved_methods, "", "Methods to be preserved")
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
c10::SetUsageMessage(
|
||||
@ -36,7 +38,8 @@ int main(int argc, char** argv) {
|
||||
"./optimize_for_mobile"
|
||||
" --model=<model_file>"
|
||||
" [--output=<output_file_name>]"
|
||||
" [--backend=<cpu|vulkan>]"
|
||||
" [--backend=<cpu|vulkan|metal>]"
|
||||
" [--preserved_methods=<method_names>]"
|
||||
);
|
||||
|
||||
if (!c10::ParseCommandLineFlags(&argc, &argv)) {
|
||||
@ -54,6 +57,21 @@ int main(int argc, char** argv) {
|
||||
output_model_name = FLAGS_output;
|
||||
}
|
||||
|
||||
std::vector<std::string> preserved_methods;
|
||||
if(FLAGS_preserved_methods != ""){
|
||||
std::stringstream ss(FLAGS_preserved_methods);
|
||||
std::string m;
|
||||
while(std::getline(ss, m, ';')){
|
||||
if(m != ""){
|
||||
preserved_methods.emplace_back(std::move(m));
|
||||
}
|
||||
}
|
||||
std::cout<<"The following methods will be preserved:"<<std::endl;
|
||||
for(auto& str : preserved_methods){
|
||||
std::cout<<str<<std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
auto module = torch::jit::load(FLAGS_model);
|
||||
auto ops = torch::jit::export_opnames(module);
|
||||
std::cout << "\npt_operator_library(" << std::endl;
|
||||
@ -68,9 +86,10 @@ int main(int argc, char** argv) {
|
||||
if (FLAGS_backend == "" || FLAGS_backend == "cpu") {
|
||||
optimized_module = torch::jit::optimizeForMobile(module);
|
||||
} else if (FLAGS_backend == "vulkan") {
|
||||
std::vector<std::string> empty_preserved_methods;
|
||||
optimized_module = torch::jit::vulkanOptimizeForMobile(module, empty_preserved_methods);
|
||||
} else {
|
||||
optimized_module = torch::jit::vulkanOptimizeForMobile(module, preserved_methods);
|
||||
} else if (FLAGS_backend == "metal"){
|
||||
optimized_module = torch::jit::metalOptimizeForMobile(module, preserved_methods);
|
||||
}else{
|
||||
CAFFE_ENFORCE(false, "Unknown backend: " + FLAGS_backend);
|
||||
}
|
||||
auto new_ops = torch::jit::export_opnames(optimized_module);
|
||||
|
@ -37,6 +37,7 @@ enum class Backend {
|
||||
MSNPU,
|
||||
XLA,
|
||||
Vulkan,
|
||||
Metal,
|
||||
QuantizedCPU,
|
||||
QuantizedCUDA,
|
||||
Undefined,
|
||||
@ -107,6 +108,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
|
||||
return Backend::XLA;
|
||||
} else if (t == DispatchKey::Vulkan) {
|
||||
return Backend::Vulkan;
|
||||
} else if (t == DispatchKey::Metal) {
|
||||
return Backend::Metal;
|
||||
} else if (t == DispatchKey::SparseCPU) {
|
||||
return Backend::SparseCPU;
|
||||
} else if (t == DispatchKey::SparseCUDA) {
|
||||
@ -150,6 +153,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
|
||||
return DispatchKey::MkldnnCPU;
|
||||
case Backend::Vulkan:
|
||||
return DispatchKey::Vulkan;
|
||||
case Backend::Metal:
|
||||
return DispatchKey::Metal;
|
||||
case Backend::QuantizedCPU:
|
||||
return DispatchKey::QuantizedCPU;
|
||||
case Backend::QuantizedCUDA:
|
||||
@ -188,6 +193,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
|
||||
return DeviceType::CUDA;
|
||||
case Backend::Vulkan:
|
||||
return DeviceType::Vulkan;
|
||||
case Backend::Metal:
|
||||
return DeviceType::Metal;
|
||||
case Backend::Undefined:
|
||||
AT_ERROR("Undefined backend is not a valid device type");
|
||||
default:
|
||||
@ -292,6 +299,8 @@ static inline const char* toString(Backend b) {
|
||||
return "MkldnnCPU";
|
||||
case Backend::Vulkan:
|
||||
return "Vulkan";
|
||||
case Backend::Metal:
|
||||
return "Metal";
|
||||
case Backend::QuantizedCPU:
|
||||
return "QuantizedCPU";
|
||||
case Backend::QuantizedCUDA:
|
||||
|
@ -29,6 +29,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
|
||||
return lower_case ? "xla" : "XLA";
|
||||
case DeviceType::Vulkan:
|
||||
return lower_case ? "vulkan" : "VULKAN";
|
||||
case DeviceType::Metal:
|
||||
return lower_case ? "metal" : "METAL";
|
||||
default:
|
||||
AT_ERROR(
|
||||
"Unknown device: ",
|
||||
@ -62,6 +64,7 @@ bool isValidDeviceType(DeviceType d) {
|
||||
case DeviceType::MSNPU:
|
||||
case DeviceType::XLA:
|
||||
case DeviceType::Vulkan:
|
||||
case DeviceType::Metal:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
@ -24,6 +24,7 @@ enum class DeviceType : int16_t {
|
||||
MSNPU = 8, // MSNPU
|
||||
XLA = 9, // XLA / TPU
|
||||
Vulkan = 10, // Vulkan
|
||||
Metal = 11, //Metal
|
||||
// NB: If you add more devices:
|
||||
// - Change the implementations of DeviceTypeName and isValidDeviceType
|
||||
// in DeviceType.cpp
|
||||
@ -39,6 +40,7 @@ constexpr DeviceType kFPGA = DeviceType::FPGA;
|
||||
constexpr DeviceType kMSNPU = DeviceType::MSNPU;
|
||||
constexpr DeviceType kXLA = DeviceType::XLA;
|
||||
constexpr DeviceType kVulkan = DeviceType::Vulkan;
|
||||
constexpr DeviceType kMetal = DeviceType::Metal;
|
||||
|
||||
// define explicit int constant
|
||||
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
|
||||
|
@ -21,7 +21,8 @@ const char* toString(DispatchKey t) {
|
||||
return "XLA";
|
||||
case DispatchKey::Vulkan:
|
||||
return "Vulkan";
|
||||
|
||||
case DispatchKey::Metal:
|
||||
return "Metal";
|
||||
case DispatchKey::MKLDNN:
|
||||
return "MKLDNN";
|
||||
case DispatchKey::OpenGL:
|
||||
@ -30,7 +31,6 @@ const char* toString(DispatchKey t) {
|
||||
return "OpenCL";
|
||||
case DispatchKey::IDEEP:
|
||||
return "IDEEP";
|
||||
|
||||
case DispatchKey::QuantizedCPU:
|
||||
return "QuantizedCPU";
|
||||
case DispatchKey::QuantizedCUDA:
|
||||
|
@ -60,6 +60,7 @@ enum class DispatchKey : uint8_t {
|
||||
// test/cpp_extensions/msnpu_extension.cpp
|
||||
XLA, // lives out of tree at https://github.com/pytorch/xla
|
||||
Vulkan,
|
||||
Metal,
|
||||
|
||||
// These are Caffe2 device types which we grandfathered into
|
||||
// DispatchKey.
|
||||
|
@ -469,6 +469,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
|
||||
return key_set_.has(DispatchKey::Vulkan);
|
||||
}
|
||||
|
||||
bool is_metal() const {
|
||||
return key_set_.has(DispatchKey::Metal);
|
||||
}
|
||||
|
||||
// TODO: remove this once we don't automatically enabled Autograd dispatch keys
|
||||
// in TensorImpl constructor.
|
||||
// DON'T USE THIS API!! It's only created for testing purpose in
|
||||
|
@ -613,6 +613,8 @@ inline DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::opti
|
||||
return DispatchKey::XLA;
|
||||
case DeviceType::Vulkan:
|
||||
return DispatchKey::Vulkan;
|
||||
case DeviceType::Metal:
|
||||
return DispatchKey::Metal;
|
||||
default:
|
||||
AT_ERROR("Unsupported device type for dense layout: ", device_.type());
|
||||
}
|
||||
@ -675,6 +677,8 @@ inline DeviceType computeDeviceType(DispatchKey tid) {
|
||||
return DeviceType::CPU;
|
||||
} else if (tid == DispatchKey::Vulkan) {
|
||||
return DeviceType::Vulkan;
|
||||
} else if (tid == DispatchKey::Metal) {
|
||||
return DeviceType::Metal;
|
||||
} else {
|
||||
AT_ASSERTM(false, "Unknown DispatchKey: ", tid);
|
||||
}
|
||||
|
@ -92,6 +92,8 @@ CMAKE_ARGS+=("-DUSE_LEVELDB=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_MPI=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_NUMPY=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_NNPACK=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_METAL=OFF")
|
||||
CMAKE_ARGS+=("-DUSE_MKLDNN=OFF")
|
||||
|
||||
# pthreads
|
||||
CMAKE_ARGS+=("-DCMAKE_THREAD_LIBS_INIT=-lpthread")
|
||||
|
@ -199,6 +199,7 @@ core_sources_full = [
|
||||
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
|
||||
"torch/csrc/jit/passes/xnnpack_rewrite.cpp",
|
||||
"torch/csrc/jit/passes/vulkan_rewrite.cpp",
|
||||
"torch/csrc/jit/passes/metal_rewrite.cpp",
|
||||
"torch/csrc/jit/passes/quantization/helper.cpp",
|
||||
"torch/csrc/jit/passes/quantization/quantization_type.cpp",
|
||||
"torch/csrc/jit/passes/quantization/insert_observers.cpp",
|
||||
|
225
torch/csrc/jit/passes/metal_rewrite.cpp
Normal file
225
torch/csrc/jit/passes/metal_rewrite.cpp
Normal file
@ -0,0 +1,225 @@
|
||||
#include <ATen/core/jit_type.h>
|
||||
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/ir/subgraph_matcher.h>
|
||||
#include <torch/csrc/jit/passes/constant_pooling.h>
|
||||
#include <torch/csrc/jit/passes/fold_conv_bn.h>
|
||||
#include <torch/csrc/jit/passes/freeze_module.h>
|
||||
#include <torch/csrc/jit/passes/fuse_linear.h>
|
||||
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
|
||||
#include <torch/csrc/jit/passes/metal_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/prepack_folding.h>
|
||||
#include <torch/csrc/jit/passes/remove_dropout.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
#ifdef USE_METAL
|
||||
|
||||
namespace {
|
||||
|
||||
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph) {
|
||||
graph_rewrite_helper::replaceConvolutionWithAtenConv(graph);
|
||||
|
||||
std::string conv_2d_pattern = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int):
|
||||
%r = aten::conv2d(%input, %weight, %bias, %stride, %padding, %dilation, %groups)
|
||||
return (%r) )";
|
||||
|
||||
std::string prepacked_ops_conv2d_pattern = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int):
|
||||
%output_min_max : None = prim::Constant()
|
||||
%packed_weight_bias = metal_prepack::conv2d_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%output_min_max, %output_min_max)
|
||||
%r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
|
||||
return (%r) )";
|
||||
|
||||
SubgraphRewriter rewriter;
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv_2d_pattern, prepacked_ops_conv2d_pattern);
|
||||
rewriter.runOnGraph(graph);
|
||||
}
|
||||
|
||||
void fuseReluWithPackedOps(std::shared_ptr<Graph>& graph) {
|
||||
SubgraphRewriter rewriter;
|
||||
|
||||
std::string conv2d_prepack_run_relu = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %dummy_min_max):
|
||||
%packed_weight_bias = metal_prepack::conv2d_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%dummy_min_max, %dummy_min_max)
|
||||
%r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
|
||||
%r = aten::relu(%r)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv2d_prepack_run_relu_fused = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %dummy_min_max):
|
||||
%output_min: float = prim::Constant[value=0.0]()
|
||||
%output_max: None = prim::Constant()
|
||||
%packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%output_min, %output_max)
|
||||
%r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
|
||||
return (%r) )";
|
||||
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_relu, conv2d_prepack_run_relu_fused);
|
||||
|
||||
std::string conv2d_prepack_run_relu_inplace = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %dummy_min_max):
|
||||
%packed_weight_bias = metal_prepack::conv2d_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%dummy_min_max, %dummy_min_max)
|
||||
%r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
|
||||
%r = aten::relu_(%r)
|
||||
return (%r) )";
|
||||
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_relu_inplace, conv2d_prepack_run_relu_fused);
|
||||
|
||||
rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
|
||||
}
|
||||
|
||||
void fuseHardtanhWithPackedOps(std::shared_ptr<Graph>& graph) {
|
||||
SubgraphRewriter rewriter;
|
||||
|
||||
std::string conv2d_prepack_run_hardtanh_fused = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
|
||||
%packed_weight_bias: __torch__.torch.classes.metal.Conv2dOpContext = metal_prepack::conv2d_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%output_min, %output_max)
|
||||
%r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
|
||||
return (%r) )";
|
||||
|
||||
std::string conv2d_prepack_run_hardtanh = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
|
||||
%packed_weight_bias = metal_prepack::conv2d_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%dummy_min_max, %dummy_min_max)
|
||||
%r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
|
||||
%r = aten::hardtanh(%r, %output_min, %output_max)
|
||||
return (%r) )";
|
||||
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_hardtanh, conv2d_prepack_run_hardtanh_fused);
|
||||
|
||||
std::string conv2d_prepack_run_hardtanh_inplace = R"(
|
||||
graph(%input, %weight, %bias, %stride:int[], %padding:int[],
|
||||
%dilation:int[], %groups:int, %output_min, %output_max, %dummy_min_max):
|
||||
%packed_weight_bias = metal_prepack::conv2d_prepack(
|
||||
%weight, %bias, %stride, %padding, %dilation, %groups,
|
||||
%dummy_min_max, %dummy_min_max)
|
||||
%r = metal_prepack::conv2d_run(%input, %packed_weight_bias)
|
||||
%r = aten::hardtanh_(%r, %output_min, %output_max)
|
||||
return (%r) )";
|
||||
|
||||
rewriter.RegisterRewritePattern(
|
||||
conv2d_prepack_run_hardtanh_inplace, conv2d_prepack_run_hardtanh_fused);
|
||||
|
||||
rewriter.runOnGraph(graph, torch::jit::graph_rewrite_helper::isClampFusable);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void metalInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
|
||||
insertPrePackedConv2dOp(graph);
|
||||
}
|
||||
|
||||
void metalInsertPrePackedOps(script::Module& module) {
|
||||
for (auto& method : module.get_methods()) {
|
||||
auto graph = method.graph();
|
||||
metalInsertPrePackedOps(graph);
|
||||
}
|
||||
for (script::Module m : module.children()) {
|
||||
metalInsertPrePackedOps(m);
|
||||
}
|
||||
}
|
||||
|
||||
void metalFoldPrePackingOps(script::Module& m) {
|
||||
PrePackingOpsFilterFn filter_fn = [](const Node* n) -> bool {
|
||||
return (
|
||||
n->kind() == Symbol::fromQualString("metal_prepack::conv2d_prepack"));
|
||||
};
|
||||
PrePackingOpsFolder(m, filter_fn, "prepack_folding");
|
||||
}
|
||||
|
||||
void metalFusePrePackedConvWithClamp(script::Module& module) {
|
||||
auto graph = module.get_method("forward").graph();
|
||||
fuseReluWithPackedOps(graph);
|
||||
fuseHardtanhWithPackedOps(graph);
|
||||
}
|
||||
|
||||
void metalInsertCopyOps(script::Module& module) {
|
||||
auto graph = module.get_method("forward").graph();
|
||||
auto&& outputs = graph->outputs();
|
||||
for (int i = 0; i < outputs.size(); ++i) {
|
||||
Value* output = outputs[i];
|
||||
std::cout << "find output: " << *output->node() << std::endl;
|
||||
auto namedValue = NamedValue("", output);
|
||||
if (namedValue.type()->kind() == TypeKind::TensorType) {
|
||||
// find the insertion point
|
||||
WithInsertPoint ip(output->node()->next());
|
||||
Value* replaced_output = graph->insert(
|
||||
Symbol::fromQualString("metal::copy_to_host"), {namedValue});
|
||||
std::cout << "insert: " << *replaced_output->node() << std::endl;
|
||||
// replaced the output
|
||||
graph->block()->replaceOutput(i, replaced_output);
|
||||
}
|
||||
}
|
||||
SubgraphRewriter rewriter;
|
||||
rewriter.runOnGraph(graph);
|
||||
}
|
||||
|
||||
script::Module metalOptimizeForMobile(
|
||||
const script::Module& m,
|
||||
const std::vector<std::string>& preserved_methods) {
|
||||
auto cloned_module = m.clone();
|
||||
cloned_module.eval();
|
||||
cloned_module = FoldConvBatchNorm(cloned_module);
|
||||
metalInsertPrePackedOps(cloned_module);
|
||||
cloned_module = freeze_module(cloned_module, preserved_methods);
|
||||
metalFusePrePackedConvWithClamp(cloned_module);
|
||||
metalFoldPrePackingOps(cloned_module);
|
||||
metalInsertCopyOps(cloned_module);
|
||||
removeDropout(cloned_module);
|
||||
return cloned_module;
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
void metalInsertPrePackedOps(std::shared_ptr<Graph>& graph) {
|
||||
TORCH_INTERNAL_ASSERT("metal is not enabled. Please build with USE_METAL=1");
|
||||
}
|
||||
|
||||
void metalInsertPrePackedOps(script::Module& module) {
|
||||
TORCH_INTERNAL_ASSERT("metal is not enabled. Please build with USE_METAL=1");
|
||||
}
|
||||
|
||||
TORCH_API void metalFusePrePackedConvWithClamp(script::Module& module) {
|
||||
TORCH_INTERNAL_ASSERT("metal is not enabled. Please build with USE_METAL=1");
|
||||
}
|
||||
|
||||
TORCH_API void metalFoldPrePackingOps(script::Module& module) {
|
||||
TORCH_INTERNAL_ASSERT("metal is not enabled. Please build with USE_METAL=1");
|
||||
}
|
||||
|
||||
script::Module metalOptimizeForMobile(
|
||||
const script::Module& m,
|
||||
const std::vector<std::string>& preserved_methods) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
"Mobile optimizaiton only available with metal at the moment. "
|
||||
"metal is not enabled. Please build with USE_METAL=1");
|
||||
return m;
|
||||
}
|
||||
|
||||
#endif
|
||||
} // namespace jit
|
||||
} // namespace torch
|
17
torch/csrc/jit/passes/metal_rewrite.h
Normal file
17
torch/csrc/jit/passes/metal_rewrite.h
Normal file
@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
TORCH_API void metalInsertPrePackedOps(std::shared_ptr<Graph>& graph);
|
||||
TORCH_API void metalInsertPrePackedOps(script::Module& module);
|
||||
TORCH_API void metalFusePrePackedConvWithClamp(script::Module& module);
|
||||
TORCH_API void metalFoldPrePackingOps(script::Module& module);
|
||||
TORCH_API script::Module metalOptimizeForMobile(
|
||||
const script::Module& module,
|
||||
const std::vector<std::string>& preserved_methods);
|
||||
} // namespace jit
|
||||
} // namespace torch
|
Reference in New Issue
Block a user