mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Mobile GPU] Ban mutations in JIT passes (#56070)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56070 **Summary** Currently, we're returning copies instead of alias on mobile GPU (Metal/Vulkan). As suggested by ailzhang , we could use the JIT pass - `RemoveTensorMutation` to ban mutations ahead of time. I've tested two scenarios as shown below. They both work fine on mobile. - view ``` class Model (torch.nn.Module): def forward(self, x): y = x.view(-1) z = torch.tensor(2.0).float() y.add_(z) return x m = Model() x = torch.rand(2, 3) y = m(x) ``` - transpose ``` class Model (torch.nn.Module): def forward(self, x): y = x.transpose(1, 2) z = torch.tensor(2.0).float() x.add_(z) return y m = Model() x = torch.rand(1, 2, 3) y = m(x) ``` As we're adding more ops, we should add more tests to cover all the alias ops - https://github.com/pytorch/pytorch/blob/master/tools/autograd/gen_inplace_or_view_type.py#L31-L80 **Next step** Synced offline with eellison. Since mutation removal is also being used in ONNX, Static runtime, some jit optimizations, Torch -> TVM, etc, instead of inventing something new, we would continue to make it better in cases where it fails. Although this JIT pass could work for most of the mobile models, there are cases that it can't cover. What we're going to do next is to implement stub ops for GPU models to let them run on server side, such that users can compare results to see if there is any discrepancy. ghstack-source-id: 126802123 Test Plan: - Sandcastle - CircleCI Reviewed By: raziel Differential Revision: D27692683 fbshipit-source-id: 9d1be8a6c0a276032b1907807a54fbe2afd882f9
This commit is contained in:
committed by
Facebook GitHub Bot
parent
98162cb0bb
commit
5748cc0d11
@ -10,6 +10,7 @@
|
||||
#include <torch/csrc/jit/passes/metal_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/prepack_folding.h>
|
||||
#include <torch/csrc/jit/passes/remove_dropout.h>
|
||||
#include <torch/csrc/jit/passes/remove_mutation.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
#include <torch/csrc/jit/runtime/graph_executor_impl.h>
|
||||
|
||||
@ -175,7 +176,12 @@ void metalInsertCopyOps(script::Module& module) {
|
||||
rewriter.runOnGraph(graph);
|
||||
}
|
||||
|
||||
void runCanonicalOptimizations(script::Module& module) {
|
||||
void metalRemoveMutation(script::Module& module) {
|
||||
auto graph = module.get_method("forward").graph();
|
||||
RemoveTensorMutation(graph);
|
||||
}
|
||||
|
||||
void metalRunCanonicalOptimizations(script::Module& module) {
|
||||
auto graph = module.get_method("forward").graph();
|
||||
runOptimization(graph, false /* no loop unrolling */);
|
||||
}
|
||||
@ -192,8 +198,9 @@ script::Module metalOptimizeForMobile(
|
||||
metalFoldPrePackingOps(cloned_module);
|
||||
metalInsertCopyOps(cloned_module);
|
||||
removeDropout(cloned_module);
|
||||
metalRemoveMutation(cloned_module);
|
||||
// remove duplicated constants
|
||||
runCanonicalOptimizations(cloned_module);
|
||||
metalRunCanonicalOptimizations(cloned_module);
|
||||
cloned_module.register_attribute(
|
||||
"optimized_for_metal", BoolType::get(), true);
|
||||
return cloned_module;
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
|
||||
#include <torch/csrc/jit/passes/prepack_folding.h>
|
||||
#include <torch/csrc/jit/passes/remove_dropout.h>
|
||||
#include <torch/csrc/jit/passes/remove_mutation.h>
|
||||
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
|
||||
#include <torch/csrc/jit/passes/vulkan_rewrite.h>
|
||||
|
||||
@ -207,6 +208,11 @@ void vulkanFoldPrePackingOps(script::Module& m) {
|
||||
PrePackingOpsFolder(m, filter_fn, "prepack_folding");
|
||||
}
|
||||
|
||||
void vulkanRemoveMutation(script::Module& module) {
|
||||
auto graph = module.get_method("forward").graph();
|
||||
RemoveTensorMutation(graph);
|
||||
}
|
||||
|
||||
script::Module vulkanOptimizeForMobile(
|
||||
const script::Module& m,
|
||||
const std::vector<std::string>& preserved_methods) {
|
||||
@ -218,6 +224,7 @@ script::Module vulkanOptimizeForMobile(
|
||||
vulkanFusePrePackedConvWithClamp(cloned_module);
|
||||
vulkanFoldPrePackingOps(cloned_module);
|
||||
removeDropout(cloned_module);
|
||||
vulkanRemoveMutation(cloned_module);
|
||||
return cloned_module;
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user