[Vulkan] Add support for Optimization Blocklist to Vulkan Rewrite (#87431)

Optimization Blocklist will be used in a future diff (D40315730) to make the rewrite to transfer input/output backends optional

Differential Revision: [D40315729](https://our.internmc.facebook.com/intern/diff/D40315729/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87431
Approved by: https://github.com/mcr229, https://github.com/digantdesai
This commit is contained in:
Salil Desai
2022-10-30 20:30:55 -07:00
committed by PyTorch MergeBot
parent f717986f93
commit bc68625151
9 changed files with 35 additions and 19 deletions

View File

@ -16,13 +16,13 @@
#include <string>
#include <sstream>
#include "torch/script.h"
#include "torch/csrc/jit/api/module.h"
#include <torch/script.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/passes/metal_rewrite.h>
#include "torch/csrc/jit/passes/vulkan_rewrite.h"
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
#include "torch/csrc/jit/serialization/import.h"
#include "torch/csrc/jit/serialization/export.h"
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
#include <torch/csrc/jit/passes/xnnpack_rewrite.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/csrc/jit/serialization/export.h>
C10_DEFINE_string(model, "", "The torch script model to optimize.");
C10_DEFINE_string(
@ -86,7 +86,8 @@ int main(int argc, char** argv) {
if (FLAGS_backend == "" || FLAGS_backend == "cpu") {
optimized_module = torch::jit::optimizeForMobile(module);
} else if (FLAGS_backend == "vulkan") {
optimized_module = torch::jit::vulkanOptimizeForMobile(module, preserved_methods);
optimized_module = torch::jit::vulkanOptimizeForMobile(
module, std::set<MobileOptimizerType>(), preserved_methods);
} else if (FLAGS_backend == "metal"){
optimized_module = torch::jit::metalOptimizeForMobile(module, preserved_methods);
}else{