From bc6862515164a31d3a62e46a49977d54a618323c Mon Sep 17 00:00:00 2001 From: Salil Desai Date: Sun, 30 Oct 2022 20:30:55 -0700 Subject: [PATCH] [Vulkan] Add support for Optimization Blocklist to Vulkan Rewrite (#87431) Optimization Blocklist will be used in a future diff (D40315730) to make the rewrite to transfer input/output backends optional Differential Revision: [D40315729](https://our.internmc.facebook.com/intern/diff/D40315729/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87431 Approved by: https://github.com/mcr229, https://github.com/digantdesai --- binaries/optimize_for_mobile.cc | 15 ++++++++------- torch/_C/__init__.pyi.in | 3 ++- torch/csrc/jit/passes/mobile_optimizer_type.h | 12 ++++++++++++ torch/csrc/jit/passes/vulkan_rewrite.cpp | 1 + torch/csrc/jit/passes/vulkan_rewrite.h | 2 ++ torch/csrc/jit/passes/xnnpack_rewrite.cpp | 1 + torch/csrc/jit/passes/xnnpack_rewrite.h | 10 +--------- torch/csrc/jit/python/init.cpp | 5 ++++- torch/utils/mobile_optimizer.py | 5 ++++- 9 files changed, 35 insertions(+), 19 deletions(-) create mode 100644 torch/csrc/jit/passes/mobile_optimizer_type.h diff --git a/binaries/optimize_for_mobile.cc b/binaries/optimize_for_mobile.cc index 991bca7e5587..005b19ce888a 100644 --- a/binaries/optimize_for_mobile.cc +++ b/binaries/optimize_for_mobile.cc @@ -16,13 +16,13 @@ #include #include -#include "torch/script.h" -#include "torch/csrc/jit/api/module.h" +#include +#include #include -#include "torch/csrc/jit/passes/vulkan_rewrite.h" -#include "torch/csrc/jit/passes/xnnpack_rewrite.h" -#include "torch/csrc/jit/serialization/import.h" -#include "torch/csrc/jit/serialization/export.h" +#include +#include +#include +#include C10_DEFINE_string(model, "", "The torch script model to optimize."); C10_DEFINE_string( @@ -86,7 +86,8 @@ int main(int argc, char** argv) { if (FLAGS_backend == "" || FLAGS_backend == "cpu") { optimized_module = torch::jit::optimizeForMobile(module); } else if (FLAGS_backend == "vulkan") { - optimized_module = torch::jit::vulkanOptimizeForMobile(module, preserved_methods); + optimized_module = torch::jit::vulkanOptimizeForMobile( + module, std::set(), preserved_methods); } else if (FLAGS_backend == "metal"){ optimized_module = torch::jit::metalOptimizeForMobile(module, preserved_methods); }else{ diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index af6734b059f4..8b936be23122 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -169,7 +169,7 @@ class Future(object): def _jit_set_num_profiled_runs(num: _size) -> _size: ... -# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h +# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h class MobileOptimizerType: ... @@ -215,6 +215,7 @@ def _clone_module_with_class(module: 'torch.jit.ScriptModule', ignored_methods: List[AnyStr], ignored_attributes: List[AnyStr]) -> 'torch.jit.ScriptModule': ... def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule', + optimization_blocklist: Set[MobileOptimizerType], preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule', preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... diff --git a/torch/csrc/jit/passes/mobile_optimizer_type.h b/torch/csrc/jit/passes/mobile_optimizer_type.h new file mode 100644 index 000000000000..fe3fffe16c22 --- /dev/null +++ b/torch/csrc/jit/passes/mobile_optimizer_type.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +enum class MobileOptimizerType : int8_t { + CONV_BN_FUSION, + INSERT_FOLD_PREPACK_OPS, + REMOVE_DROPOUT, + FUSE_ADD_RELU, + HOIST_CONV_PACKED_PARAMS, + CONV_1D_TO_2D, +}; diff --git a/torch/csrc/jit/passes/vulkan_rewrite.cpp b/torch/csrc/jit/passes/vulkan_rewrite.cpp index 9a0f45ff8402..7618aa50d686 100644 --- a/torch/csrc/jit/passes/vulkan_rewrite.cpp +++ b/torch/csrc/jit/passes/vulkan_rewrite.cpp @@ -269,6 +269,7 @@ void vulkanRunCanonicalOptimizations(script::Module& module) { script::Module vulkanOptimizeForMobile( const script::Module& m, + const std::set& optimization_blocklist, const std::vector& preserved_methods) { auto cloned_module = m.clone(); cloned_module.eval(); diff --git a/torch/csrc/jit/passes/vulkan_rewrite.h b/torch/csrc/jit/passes/vulkan_rewrite.h index 8e67dce70f54..395d885e8e2c 100644 --- a/torch/csrc/jit/passes/vulkan_rewrite.h +++ b/torch/csrc/jit/passes/vulkan_rewrite.h @@ -2,6 +2,7 @@ #include #include +#include namespace torch { namespace jit { @@ -11,6 +12,7 @@ TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module); TORCH_API void vulkanFoldPrePackingOps(script::Module& module); TORCH_API script::Module vulkanOptimizeForMobile( const script::Module& module, + const std::set& optimization_blocklist, const std::vector& preserved_methods); } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp index 2476d1be4df6..0e2163f7a19f 100644 --- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp +++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.h b/torch/csrc/jit/passes/xnnpack_rewrite.h index 498dcd006fe3..d1a64c52c923 100644 --- a/torch/csrc/jit/passes/xnnpack_rewrite.h +++ b/torch/csrc/jit/passes/xnnpack_rewrite.h @@ -2,19 +2,11 @@ #include #include +#include namespace torch { namespace jit { -enum class MobileOptimizerType : int8_t { - CONV_BN_FUSION, - INSERT_FOLD_PREPACK_OPS, - REMOVE_DROPOUT, - FUSE_ADD_RELU, - HOIST_CONV_PACKED_PARAMS, - CONV_1D_TO_2D, -}; - TORCH_API void transformConv1dToConv2d(std::shared_ptr& graph); TORCH_API void transformConv1dToConv2d(script::Module& module); TORCH_API void insertPrePackedOps(std::shared_ptr& graph); diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 91eecfa4596e..d576c1ff3d74 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -52,6 +52,7 @@ #include #include #include +#include #include #include #include @@ -1081,8 +1082,10 @@ void initJITBindings(PyObject* module) { .def( "_jit_pass_vulkan_optimize_for_mobile", [](script::Module& module, + std::set& optimization_blocklist, std::vector& preserved_methods) { - return vulkanOptimizeForMobile(module, preserved_methods); + return vulkanOptimizeForMobile( + module, optimization_blocklist, preserved_methods); }) .def( "_jit_pass_metal_insert_prepacked_ops", diff --git a/torch/utils/mobile_optimizer.py b/torch/utils/mobile_optimizer.py index bda5defab29e..2b59ac41809f 100644 --- a/torch/utils/mobile_optimizer.py +++ b/torch/utils/mobile_optimizer.py @@ -64,7 +64,10 @@ def optimize_for_mobile( optimization_blocklist, preserved_methods_str) elif backend == 'vulkan': - optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str) + optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile( + script_module._c, + optimization_blocklist, + preserved_methods_str) elif backend == 'metal': optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str) else: