[Vulkan] Add support for Optimization Blocklist to Vulkan Rewrite (#87431)

Optimization Blocklist will be used in a future diff (D40315730) to make the rewrite to transfer input/output backends optional

Differential Revision: [D40315729](https://our.internmc.facebook.com/intern/diff/D40315729/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87431
Approved by: https://github.com/mcr229, https://github.com/digantdesai
This commit is contained in:
Salil Desai
2022-10-30 20:30:55 -07:00
committed by PyTorch MergeBot
parent f717986f93
commit bc68625151
9 changed files with 35 additions and 19 deletions

View File

@ -16,13 +16,13 @@
#include <string>
#include <sstream>
#include "torch/script.h"
#include "torch/csrc/jit/api/module.h"
#include <torch/script.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include "torch/csrc/jit/passes/vulkan_rewrite.h"
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
#include "torch/csrc/jit/serialization/import.h"
#include "torch/csrc/jit/serialization/export.h"
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/export.h>
C10_DEFINE_string(model, "", "The torch script model to optimize.");
C10_DEFINE_string(
@ -86,7 +86,8 @@ int main(int argc, char** argv) {
if (FLAGS_backend == "" || FLAGS_backend == "cpu") {
optimized_module = torch::jit::optimizeForMobile(module);
} else if (FLAGS_backend == "vulkan") {
optimized_module = torch::jit::vulkanOptimizeForMobile(module, preserved_methods);
optimized_module = torch::jit::vulkanOptimizeForMobile(
module, std::set<MobileOptimizerType>(), preserved_methods);
} else if (FLAGS_backend == "metal"){
optimized_module = torch::jit::metalOptimizeForMobile(module, preserved_methods);
}else{

View File

@ -169,7 +169,7 @@ class Future(object):
def _jit_set_num_profiled_runs(num: _size) -> _size: ...
# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
class MobileOptimizerType:
...
@ -215,6 +215,7 @@ def _clone_module_with_class(module: 'torch.jit.ScriptModule',
ignored_methods: List[AnyStr],
ignored_attributes: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
optimization_blocklist: Set[MobileOptimizerType],
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...

View File

@ -0,0 +1,12 @@
#pragma once
#include <cstdint>
// Identifiers for the individual mobile optimization passes.
// Elsewhere in this change these values are collected into a
// std::set<MobileOptimizerType> named `optimization_blocklist` and passed to
// vulkanOptimizeForMobile / optimizeForMobile, i.e. a pass listed in the set
// is *skipped* rather than enabled.
// int8_t keeps the enum's storage footprint minimal.
enum class MobileOptimizerType : int8_t {
  // Presumably folds BatchNorm into the preceding Conv — TODO confirm pass semantics.
  CONV_BN_FUSION,
  // Presumably inserts and folds op-prepacking for mobile backends — TODO confirm.
  INSERT_FOLD_PREPACK_OPS,
  // Presumably strips dropout ops (inference-only graphs) — TODO confirm.
  REMOVE_DROPOUT,
  // Presumably fuses add + relu into a single op — TODO confirm.
  FUSE_ADD_RELU,
  // Presumably hoists packed conv params out of the forward path — TODO confirm.
  HOIST_CONV_PACKED_PARAMS,
  // Presumably rewrites 1-d convolutions as 2-d equivalents — TODO confirm.
  CONV_1D_TO_2D,
};

View File

@ -269,6 +269,7 @@ void vulkanRunCanonicalOptimizations(script::Module& module) {
script::Module vulkanOptimizeForMobile(
const script::Module& m,
const std::set<MobileOptimizerType>& optimization_blocklist,
const std::vector<std::string>& preserved_methods) {
auto cloned_module = m.clone();
cloned_module.eval();

View File

@ -2,6 +2,7 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
namespace torch {
namespace jit {
@ -11,6 +12,7 @@ TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module);
TORCH_API void vulkanFoldPrePackingOps(script::Module& module);
TORCH_API script::Module vulkanOptimizeForMobile(
const script::Module& module,
const std::set<MobileOptimizerType>& optimization_blocklist,
const std::vector<std::string>& preserved_methods);
} // namespace jit
} // namespace torch

View File

@ -12,6 +12,7 @@
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
#include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

View File

@ -2,19 +2,11 @@
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
namespace torch {
namespace jit {
enum class MobileOptimizerType : int8_t {
CONV_BN_FUSION,
INSERT_FOLD_PREPACK_OPS,
REMOVE_DROPOUT,
FUSE_ADD_RELU,
HOIST_CONV_PACKED_PARAMS,
CONV_1D_TO_2D,
};
TORCH_API void transformConv1dToConv2d(std::shared_ptr<Graph>& graph);
TORCH_API void transformConv1dToConv2d(script::Module& module);
TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);

View File

@ -52,6 +52,7 @@
#include <torch/csrc/jit/passes/lower_graph.h>
#include <torch/csrc/jit/passes/lower_tuples.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
@ -1081,8 +1082,10 @@ void initJITBindings(PyObject* module) {
.def(
"_jit_pass_vulkan_optimize_for_mobile",
[](script::Module& module,
std::set<MobileOptimizerType>& optimization_blocklist,
std::vector<std::string>& preserved_methods) {
return vulkanOptimizeForMobile(module, preserved_methods);
return vulkanOptimizeForMobile(
module, optimization_blocklist, preserved_methods);
})
.def(
"_jit_pass_metal_insert_prepacked_ops",

View File

@ -64,7 +64,10 @@ def optimize_for_mobile(
optimization_blocklist,
preserved_methods_str)
elif backend == 'vulkan':
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str)
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(
script_module._c,
optimization_blocklist,
preserved_methods_str)
elif backend == 'metal':
optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str)
else: