mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Vulkan] Add support for Optimization Blocklist to Vulkan Rewrite (#87431)
Optimization Blocklist will be used in a future diff (D40315730) to make the rewrite to transfer input/output backends optional Differential Revision: [D40315729](https://our.internmc.facebook.com/intern/diff/D40315729/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/87431 Approved by: https://github.com/mcr229, https://github.com/digantdesai
This commit is contained in:
committed by
PyTorch MergeBot
parent
f717986f93
commit
bc68625151
@ -16,13 +16,13 @@
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include "torch/script.h"
|
||||
#include "torch/csrc/jit/api/module.h"
|
||||
#include <torch/script.h>
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/passes/metal_rewrite.h>
|
||||
#include "torch/csrc/jit/passes/vulkan_rewrite.h"
|
||||
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
|
||||
#include "torch/csrc/jit/serialization/import.h"
|
||||
#include "torch/csrc/jit/serialization/export.h"
|
||||
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
|
||||
#include <torch/csrc/jit/serialization/import.h>
|
||||
#include <torch/csrc/jit/serialization/export.h>
|
||||
|
||||
C10_DEFINE_string(model, "", "The torch script model to optimize.");
|
||||
C10_DEFINE_string(
|
||||
@ -86,7 +86,8 @@ int main(int argc, char** argv) {
|
||||
if (FLAGS_backend == "" || FLAGS_backend == "cpu") {
|
||||
optimized_module = torch::jit::optimizeForMobile(module);
|
||||
} else if (FLAGS_backend == "vulkan") {
|
||||
optimized_module = torch::jit::vulkanOptimizeForMobile(module, preserved_methods);
|
||||
optimized_module = torch::jit::vulkanOptimizeForMobile(
|
||||
module, std::set<MobileOptimizerType>(), preserved_methods);
|
||||
} else if (FLAGS_backend == "metal"){
|
||||
optimized_module = torch::jit::metalOptimizeForMobile(module, preserved_methods);
|
||||
}else{
|
||||
|
@ -169,7 +169,7 @@ class Future(object):
|
||||
|
||||
def _jit_set_num_profiled_runs(num: _size) -> _size: ...
|
||||
|
||||
# Defined in torch/csrc/jit/passes/xnnpack_rewrite.h
|
||||
# Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
|
||||
class MobileOptimizerType:
|
||||
...
|
||||
|
||||
@ -215,6 +215,7 @@ def _clone_module_with_class(module: 'torch.jit.ScriptModule',
|
||||
ignored_methods: List[AnyStr],
|
||||
ignored_attributes: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
|
||||
def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
|
||||
optimization_blocklist: Set[MobileOptimizerType],
|
||||
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
|
||||
def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
|
||||
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
|
||||
|
12
torch/csrc/jit/passes/mobile_optimizer_type.h
Normal file
12
torch/csrc/jit/passes/mobile_optimizer_type.h
Normal file
@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
enum class MobileOptimizerType : int8_t {
|
||||
CONV_BN_FUSION,
|
||||
INSERT_FOLD_PREPACK_OPS,
|
||||
REMOVE_DROPOUT,
|
||||
FUSE_ADD_RELU,
|
||||
HOIST_CONV_PACKED_PARAMS,
|
||||
CONV_1D_TO_2D,
|
||||
};
|
@ -269,6 +269,7 @@ void vulkanRunCanonicalOptimizations(script::Module& module) {
|
||||
|
||||
script::Module vulkanOptimizeForMobile(
|
||||
const script::Module& m,
|
||||
const std::set<MobileOptimizerType>& optimization_blocklist,
|
||||
const std::vector<std::string>& preserved_methods) {
|
||||
auto cloned_module = m.clone();
|
||||
cloned_module.eval();
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
@ -11,6 +12,7 @@ TORCH_API void vulkanFusePrePackedConvWithClamp(script::Module& module);
|
||||
TORCH_API void vulkanFoldPrePackingOps(script::Module& module);
|
||||
TORCH_API script::Module vulkanOptimizeForMobile(
|
||||
const script::Module& module,
|
||||
const std::set<MobileOptimizerType>& optimization_blocklist,
|
||||
const std::vector<std::string>& preserved_methods);
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
|
||||
#include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
|
||||
#include <torch/csrc/jit/passes/inliner.h>
|
||||
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
|
||||
#include <torch/csrc/jit/passes/prepack_folding.h>
|
||||
#include <torch/csrc/jit/passes/remove_dropout.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
|
@ -2,19 +2,11 @@
|
||||
|
||||
#include <torch/csrc/jit/api/module.h>
|
||||
#include <torch/csrc/jit/ir/ir.h>
|
||||
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
|
||||
|
||||
namespace torch {
|
||||
namespace jit {
|
||||
|
||||
enum class MobileOptimizerType : int8_t {
|
||||
CONV_BN_FUSION,
|
||||
INSERT_FOLD_PREPACK_OPS,
|
||||
REMOVE_DROPOUT,
|
||||
FUSE_ADD_RELU,
|
||||
HOIST_CONV_PACKED_PARAMS,
|
||||
CONV_1D_TO_2D,
|
||||
};
|
||||
|
||||
TORCH_API void transformConv1dToConv2d(std::shared_ptr<Graph>& graph);
|
||||
TORCH_API void transformConv1dToConv2d(script::Module& module);
|
||||
TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);
|
||||
|
@ -52,6 +52,7 @@
|
||||
#include <torch/csrc/jit/passes/lower_graph.h>
|
||||
#include <torch/csrc/jit/passes/lower_tuples.h>
|
||||
#include <torch/csrc/jit/passes/metal_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/mobile_optimizer_type.h>
|
||||
#include <torch/csrc/jit/passes/normalize_ops.h>
|
||||
#include <torch/csrc/jit/passes/peephole.h>
|
||||
#include <torch/csrc/jit/passes/peephole_list_idioms.h>
|
||||
@ -1081,8 +1082,10 @@ void initJITBindings(PyObject* module) {
|
||||
.def(
|
||||
"_jit_pass_vulkan_optimize_for_mobile",
|
||||
[](script::Module& module,
|
||||
std::set<MobileOptimizerType>& optimization_blocklist,
|
||||
std::vector<std::string>& preserved_methods) {
|
||||
return vulkanOptimizeForMobile(module, preserved_methods);
|
||||
return vulkanOptimizeForMobile(
|
||||
module, optimization_blocklist, preserved_methods);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_metal_insert_prepacked_ops",
|
||||
|
@ -64,7 +64,10 @@ def optimize_for_mobile(
|
||||
optimization_blocklist,
|
||||
preserved_methods_str)
|
||||
elif backend == 'vulkan':
|
||||
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str)
|
||||
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(
|
||||
script_module._c,
|
||||
optimization_blocklist,
|
||||
preserved_methods_str)
|
||||
elif backend == 'metal':
|
||||
optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user