From 495070b388478532f0149a37d75a38c7bd23ad96 Mon Sep 17 00:00:00 2001
From: Tao Xu
Date: Sat, 17 Oct 2020 10:22:55 -0700
Subject: [PATCH] [Metal] Add the Python binding for optimize_for_mobile
 (#46456)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46456

Add the Python binding and wire it up in CMake. The general workflow is:

- Build PyTorch
    - `USE_PYTORCH_METAL=ON python setup.py install --cmake`
- Run optimize_for_mobile

```
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

scripted_model = torch.jit.load('./mobilenetv2.pt')
optimized_model = optimize_for_mobile(scripted_model, backend='metal')
torch.jit.export_opnames(optimized_model)
torch.jit.save(optimized_model, './mobilenetv2_metal.bc')
```

The exported ops are

```
['aten::adaptive_avg_pool2d', 'aten::add.Tensor', 'aten::addmm', 'aten::reshape', 'aten::size.int', 'metal::copy_to_host', 'metal_prepack::conv2d_run']
```

ghstack-source-id: 114559878

Test Plan:
- Sandcastle CI
- Circle CI

Reviewed By: kimishpatel

Differential Revision: D24356768

fbshipit-source-id: fb5c4c4b6316347b67edb4132da044a81470ddfd
---
 CMakeLists.txt                  |   4 +
 aten/src/ATen/CMakeLists.txt    |  17 +++-
 test/test_metal.py              | 159 ++++++++++++++++++++++++++++++++
 torch/_C/__init__.pyi.in        |   2 +
 torch/csrc/jit/python/init.cpp  |  25 +++++
 torch/utils/mobile_optimizer.py |  11 ++-
 6 files changed, 212 insertions(+), 6 deletions(-)
 create mode 100644 test/test_metal.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8c8ff698d133..3dc01b5925ee 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -548,6 +548,10 @@ if(USE_VULKAN_RELAXED_PRECISION)
   string(APPEND CMAKE_CXX_FLAGS " -DUSE_VULKAN_RELAXED_PRECISION")
 endif()
 
+if(USE_PYTORCH_METAL)
+  string(APPEND CMAKE_CXX_FLAGS " -DUSE_PYTORCH_METAL")
+endif()
+
 # ---[ Allowlist file if allowlist is specified
 include(cmake/Allowlist.cmake)
 
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index bf553642311d..31b479de1ae7 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -67,12 +67,15 @@ file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
 file(GLOB vulkan_cpp "vulkan/*.cpp")
 file(GLOB native_vulkan_cpp "native/vulkan/api/*.cpp" "native/vulkan/*.cpp")
 
+# Metal
 file(GLOB metal_h "metal/*.h")
 file(GLOB metal_cpp "metal/*.cpp")
 file(GLOB_RECURSE native_metal_h "native/metal/*.h")
 file(GLOB metal_test_srcs "native/metal/mpscnn/tests/*.mm")
 file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm" "native/metal/*.cpp")
 EXCLUDE(native_metal_srcs "${native_metal_srcs}" ${metal_test_srcs})
+file(GLOB metal_prepack_h "native/metal/MetalPrepackOpContext.h")
+file(GLOB metal_prepack_cpp "native/metal/MetalPrepackOpRegister.cpp")
 
 file(GLOB native_sparse_cpp "native/sparse/*.cpp")
 file(GLOB native_quantized_cpp
@@ -125,8 +128,14 @@ else()
   set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
 endif()
 
+# Metal
 if(USE_PYTORCH_METAL)
-  set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${native_metal_srcs})
+  if(IOS)
+    set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${native_metal_srcs})
+  else()
+    # Add the files needed for optimize_for_mobile
+    set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${metal_prepack_cpp})
+  endif()
 else()
   set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp})
 endif()
@@ -391,7 +400,11 @@ if(NOT INTERN_BUILD_MOBILE)
   list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${miopen_h})
 else()
   if(USE_PYTORCH_METAL)
-    list(APPEND INSTALL_HEADERS ${metal_h} ${native_metal_h})
+    if(IOS)
+      list(APPEND INSTALL_HEADERS ${metal_h} ${native_metal_h})
+    else()
+      list(APPEND INSTALL_HEADERS ${metal_h} ${metal_prepack_h})
+    endif()
   endif()
 endif()
 
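With the CMake wiring above, a desktop build configured with `USE_PYTORCH_METAL=ON` compiles the prepack ops and exposes the JIT passes bound further down in init.cpp. A minimal sanity-check sketch (not part of this patch; it simply probes for the pass names this patch registers, assuming such a build):

```
import torch

# A quick sanity check, assuming PyTorch was built with USE_PYTORCH_METAL=ON.
# The names below are the passes registered in torch/csrc/jit/python/init.cpp.
for name in ("_jit_pass_metal_insert_prepacked_ops",
             "_jit_pass_metal_fuse_clamp_w_prepacked_conv",
             "_jit_pass_metal_fold_prepacking_ops",
             "_jit_pass_metal_optimize_for_mobile"):
    print(name, hasattr(torch._C, name))
```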
diff --git a/test/test_metal.py b/test/test_metal.py
new file mode 100644
index 000000000000..f5a77d0b06d6
--- /dev/null
+++ b/test/test_metal.py
@@ -0,0 +1,159 @@
+import torch
+from torch.nn import functional as F
+
+from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.testing import FileCheck
+import io
+
+class TestMetalRewritePass(TestCase):
+    @staticmethod
+    def validate_transformed_module(
+            # To please flake
+            self,
+            pattern_count_map,
+            data_shape,
+            prepack_removal=False,
+            fuse_clamping_ops=False):
+        module_instance = self
+        scripted_model = torch.jit.script(module_instance)
+        scripted_model.eval()
+        input_data = torch.normal(1, 20, size=data_shape)
+        ref_result = scripted_model(input_data)
+        torch._C._jit_pass_metal_insert_prepacked_ops(scripted_model._c)
+        if fuse_clamping_ops or prepack_removal:
+            scripted_model._c = torch._C._freeze_module(scripted_model._c)
+        if fuse_clamping_ops:
+            torch._C._jit_pass_metal_fuse_clamp_w_prepacked_conv(scripted_model._c)
+        if prepack_removal:
+            torch._C._jit_pass_metal_fold_prepacking_ops(scripted_model._c)
+
+        buffer = io.BytesIO()
+        torch.jit.save(scripted_model, buffer)
+        buffer.seek(0)
+        deserialized_scripted_model = torch.jit.load(buffer)
+        for pattern, v in pattern_count_map.items():
+            if (v == 0):
+                FileCheck().check(pattern).run(deserialized_scripted_model.graph)
+            elif (v == -1):
+                FileCheck().check_not(pattern).run(deserialized_scripted_model.graph)
+            else:
+                FileCheck().check_count(pattern, v, exactly=True).run(deserialized_scripted_model.graph)
+
+    def test_conv(self):
+        # Conv params
+        batch_size = 2
+        input_channels_per_group = 6
+        height = 16
+        width = 16
+        output_channels_per_group = 6
+        groups = 4
+        kernel_h = kernel_w = 3
+        stride_h = stride_w = 1
+        pad_h = pad_w = 1
+        dilation = 1
+        input_channels = input_channels_per_group * groups
+        output_channels = output_channels_per_group * groups
+        kernels = (kernel_h, kernel_w)
+        strides = (stride_h, stride_w)
+        paddings = (pad_h, pad_w)
+        dilations = (dilation, dilation)
+        conv_weight_shape = (output_channels, input_channels_per_group, kernel_h, kernel_w)
+        conv_bias_shape = (output_channels)
+
+        class Conv2D(torch.nn.Module):
+            def __init__(self):
+                super(Conv2D, self).__init__()
+                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                return F.conv2d(x, self.weight, self.bias,
+                                self.strides, self.paddings, self.dilations, self.groups)
+
+        data_shape = (batch_size, input_channels, height, width)
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "metal_prepack::conv2d_prepack": 1,
+                             "metal_prepack::conv2d_run": 1}
+        TestMetalRewritePass.validate_transformed_module(Conv2D(), pattern_count_map, data_shape)
+
+        class Conv2DRelu(torch.nn.Module):
+            def __init__(self):
+                super(Conv2DRelu, self).__init__()
+                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                o = F.conv2d(x, self.weight, self.bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
+                o = F.relu(o)
+                return o
+
+        data_shape = (batch_size, input_channels, height, width)
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "metal_prepack::conv2d_prepack": 1,
+                             "metal_prepack::conv2d_run": 1}
+        TestMetalRewritePass.validate_transformed_module(
+            Conv2DRelu(), pattern_count_map, data_shape)
+
+        pattern_count_map["aten::relu"] = 1
+        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
+        TestMetalRewritePass.validate_transformed_module(
+            Conv2DRelu(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True)
+        pattern_count_map["aten::relu"] = -1
+        TestMetalRewritePass.validate_transformed_module(
+            Conv2DRelu(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True,
+            fuse_clamping_ops=True)
+
+        class Conv2DHardtanh(torch.nn.Module):
+            def __init__(self):
+                super(Conv2DHardtanh, self).__init__()
+                self.weight = torch.nn.Parameter(torch.Tensor(torch.rand(conv_weight_shape)), requires_grad=False)
+                self.bias = torch.nn.Parameter(torch.Tensor(torch.rand(conv_bias_shape)), requires_grad=False)
+                self.strides = strides
+                self.paddings = paddings
+                self.dilations = dilations
+                self.groups = groups
+
+            def forward(self, x):
+                o = F.conv2d(x, self.weight, self.bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
+                o = F.hardtanh(o)
+                return o
+
+        data_shape = (batch_size, input_channels, height, width)
+        pattern_count_map = {"Tensor = aten::conv2d": -1,
+                             "metal_prepack::conv2d_prepack": 1,
+                             "metal_prepack::conv2d_run": 1}
+        TestMetalRewritePass.validate_transformed_module(Conv2DHardtanh(), pattern_count_map, data_shape)
+        pattern_count_map["aten::hardtanh"] = 1
+        pattern_count_map["metal_prepack::conv2d_prepack"] = -1
+        TestMetalRewritePass.validate_transformed_module(
+            Conv2DHardtanh(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True)
+        pattern_count_map["aten::hardtanh"] = -1
+        TestMetalRewritePass.validate_transformed_module(
+            Conv2DHardtanh(),
+            pattern_count_map,
+            data_shape,
+            prepack_removal=True,
+            fuse_clamping_ops=True)
+
+if __name__ == "__main__":
+    run_tests()
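The test above drives the passes through `validate_transformed_module`. Stripped to its essentials, the same flow looks like this sketch (assuming a build with `USE_PYTORCH_METAL=ON`; `TinyConv` is a made-up toy module, not part of the test suite):

```
import torch
from torch.nn import functional as F

class TinyConv(torch.nn.Module):  # hypothetical toy module
    def __init__(self):
        super(TinyConv, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(4, 3, 3, 3), requires_grad=False)
        self.bias = torch.nn.Parameter(torch.rand(4), requires_grad=False)

    def forward(self, x):
        return F.conv2d(x, self.weight, self.bias)

m = torch.jit.script(TinyConv())
m.eval()
# Rewrites aten::conv2d into metal_prepack::conv2d_prepack / conv2d_run.
torch._C._jit_pass_metal_insert_prepacked_ops(m._c)
print(m.graph)
```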
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index e5679c091d55..c81666156189 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -176,6 +176,8 @@ def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
                                   preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
 def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
                                          preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
+def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
+                                        preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
 def _jit_pass_inline(Graph) -> None: ...
 def _jit_get_schemas_for_operator(name :str) -> List[FunctionSchema]: ...
 def _jit_can_fuse_on_cpu() -> _bool: ...
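The new stub mirrors `_jit_pass_vulkan_optimize_for_mobile`: the binding takes the C++ module plus the preserved method names and returns a new C++ module, which `optimize_for_mobile` then wraps, as the diff below shows. A sketch of calling it directly (the model path is a placeholder):

```
import torch

scripted = torch.jit.load('./mobilenetv2.pt')  # placeholder path
cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(scripted._c, ['forward'])
optimized = torch.jit._recursive.wrap_cpp_module(cpp_module)
```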
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 1649d6b175cf..d505000e6031 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -29,6 +29,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/passes/metal_rewrite.h>
 #include
 #include
 #include
@@ -679,6 +680,30 @@ void initJITBindings(PyObject* module) {
           std::vector<std::string>& preserved_methods) {
             return vulkanOptimizeForMobile(module, preserved_methods);
           })
+      .def(
+          "_jit_pass_metal_insert_prepacked_ops",
+          [](std::shared_ptr<Graph>& graph) {
+            return metalInsertPrePackedOps(graph);
+          })
+      .def(
+          "_jit_pass_metal_insert_prepacked_ops",
+          [](script::Module& module) {
+            return metalInsertPrePackedOps(module);
+          })
+      .def(
+          "_jit_pass_metal_fuse_clamp_w_prepacked_conv",
+          [](script::Module& module) {
+            return metalFusePrePackedConvWithClamp(module);
+          })
+      .def(
+          "_jit_pass_metal_fold_prepacking_ops",
+          [](script::Module& module) { return metalFoldPrePackingOps(module); })
+      .def(
+          "_jit_pass_metal_optimize_for_mobile",
+          [](script::Module& module,
+             std::vector<std::string>& preserved_methods) {
+            return metalOptimizeForMobile(module, preserved_methods);
+          })
       .def(
           "_jit_pass_onnx_unpack_quantized_weights",
           [](std::shared_ptr<Graph>& graph,
diff --git a/torch/utils/mobile_optimizer.py b/torch/utils/mobile_optimizer.py
index 16d49195ce16..a9bbbfb9e6ac 100644
--- a/torch/utils/mobile_optimizer.py
+++ b/torch/utils/mobile_optimizer.py
@@ -25,7 +25,7 @@ def optimize_for_mobile(
         optimization method will run all the optimizer passes; otherwise, it will run only the optimization
         passes that are not included inside optimization_blocklist.
         preserved_methods: A list of methods that need to be preserved when the freeze_module pass is invoked
-        backend: Device type to use for running the result model ('CPU'(default) or 'Vulkan').
+        backend: Device type to use for running the result model ('CPU'(default), 'Vulkan' or 'Metal').
     Returns:
         A new optimized torch script module
     """
@@ -39,12 +39,15 @@ def optimize_for_mobile(
     if preserved_methods is None:
         preserved_methods = []
 
-    if backend == 'CPU':
+    backend = backend.lower()
+    if backend == 'cpu':
         optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blocklist, preserved_methods)
-    elif backend == 'Vulkan':
+    elif backend == 'vulkan':
         optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods)
+    elif backend == 'metal':
+        optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods)
     else:
-        raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan'")
+        raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'")
 
     return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
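Since the backend string is lowercased before comparison, the `backend` argument is now case-insensitive. A short usage sketch (the model path is a placeholder):

```
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torch.jit.load('./mobilenetv2.pt')  # placeholder path
# 'metal', 'Metal', and 'METAL' all select the Metal rewrite after lower().
optimized = optimize_for_mobile(model, backend='Metal')
```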