[Mobile GPU] Ban mutations in JIT passes (#56070)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56070

**Summary**

Currently, we return copies instead of aliases on mobile GPU (Metal/Vulkan). As suggested by ailzhang, we can use the JIT pass `RemoveTensorMutation` to ban mutations ahead of time. I've tested the two scenarios shown below; both work fine on mobile.

- view

```
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        y = x.view(-1)
        z = torch.tensor(2.0).float()
        y.add_(z)
        return x

m = Model()
x = torch.rand(2, 3)
y = m(x)
```
- transpose

```
class Model(torch.nn.Module):
    def forward(self, x):
        y = x.transpose(1, 2)
        z = torch.tensor(2.0).float()
        x.add_(z)
        return y

m = Model()
x = torch.rand(1, 2, 3)
y = m(x)
```
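
For reference, the pass can also be exercised directly from Python on a scripted graph. This is a minimal sketch (not part of this diff), assuming the internal binding `torch._C._jit_pass_remove_mutation`, which wraps `RemoveTensorMutation`; internal bindings like this can change between releases.

```
import torch

@torch.jit.script
def f(x):
    y = x + 1
    y.add_(2)  # in-place op on a value with no other aliases
    return y

print(f.graph)  # contains aten::add_
torch._C._jit_pass_remove_mutation(f.graph)
print(f.graph)  # the removable in-place add_ should now be a functional add
```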

As we add more ops, we should add tests covering all of the alias ops listed in https://github.com/pytorch/pytorch/blob/master/tools/autograd/gen_inplace_or_view_type.py#L31-L80
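
A rough sketch of what such per-op coverage could look like (illustrative only, not part of this diff; it uses the internal `torch._C._jit_pass_remove_mutation` binding and only checks that the pass preserves eager CPU semantics):

```
import torch

class ViewThenMutate(torch.nn.Module):
    def forward(self, x):
        y = x.view(-1)
        y.add_(1.0)
        return x

class TransposeThenMutate(torch.nn.Module):
    def forward(self, x):
        y = x.transpose(0, 1)
        y.add_(1.0)
        return x

for module in (ViewThenMutate(), TransposeThenMutate()):
    x = torch.rand(2, 3)
    expected = module(x.clone())  # eager reference
    scripted = torch.jit.script(module)
    torch._C._jit_pass_remove_mutation(scripted.graph)
    # The pass must never change results, whether or not it rewrites the graph.
    assert torch.allclose(scripted(x.clone()), expected)
```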

**Next step**

Synced offline with eellison. Since mutation removal is also used by ONNX export, the static runtime, some JIT optimizations, the Torch -> TVM path, etc., we'll continue improving the existing pass in the cases where it fails rather than inventing something new.

Although this JIT pass should work for most mobile models, there are cases it can't cover. The next step is to implement stub ops for GPU models so they can also run on the server side, letting users compare results and check for discrepancies.

ghstack-source-id: 126802123

Test Plan:
- Sandcastle
- CircleCI

Reviewed By: raziel

Differential Revision: D27692683

fbshipit-source-id: 9d1be8a6c0a276032b1907807a54fbe2afd882f9
2 changed files with 16 additions and 2 deletions

torch/csrc/jit/passes/metal_rewrite.cpp

@@ -10,6 +10,7 @@
 #include <torch/csrc/jit/passes/metal_rewrite.h>
 #include <torch/csrc/jit/passes/prepack_folding.h>
 #include <torch/csrc/jit/passes/remove_dropout.h>
+#include <torch/csrc/jit/passes/remove_mutation.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
 #include <torch/csrc/jit/runtime/graph_executor_impl.h>
@@ -175,7 +176,12 @@ void metalInsertCopyOps(script::Module& module) {
   rewriter.runOnGraph(graph);
 }
 
-void runCanonicalOptimizations(script::Module& module) {
+void metalRemoveMutation(script::Module& module) {
+  auto graph = module.get_method("forward").graph();
+  RemoveTensorMutation(graph);
+}
+
+void metalRunCanonicalOptimizations(script::Module& module) {
   auto graph = module.get_method("forward").graph();
   runOptimization(graph, false /* no loop unrolling */);
 }
@@ -192,8 +198,9 @@ script::Module metalOptimizeForMobile(
   metalFoldPrePackingOps(cloned_module);
   metalInsertCopyOps(cloned_module);
   removeDropout(cloned_module);
+  metalRemoveMutation(cloned_module);
   // remove duplicated constants
-  runCanonicalOptimizations(cloned_module);
+  metalRunCanonicalOptimizations(cloned_module);
   cloned_module.register_attribute(
       "optimized_for_metal", BoolType::get(), true);
   return cloned_module;

torch/csrc/jit/passes/vulkan_rewrite.cpp

@@ -12,6 +12,7 @@
 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
 #include <torch/csrc/jit/passes/prepack_folding.h>
 #include <torch/csrc/jit/passes/remove_dropout.h>
+#include <torch/csrc/jit/passes/remove_mutation.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
 #include <torch/csrc/jit/passes/vulkan_rewrite.h>
@@ -207,6 +208,11 @@ void vulkanFoldPrePackingOps(script::Module& m) {
   PrePackingOpsFolder(m, filter_fn, "prepack_folding");
 }
 
+void vulkanRemoveMutation(script::Module& module) {
+  auto graph = module.get_method("forward").graph();
+  RemoveTensorMutation(graph);
+}
+
 script::Module vulkanOptimizeForMobile(
     const script::Module& m,
     const std::vector<std::string>& preserved_methods) {
@@ -218,6 +224,7 @@ script::Module vulkanOptimizeForMobile(
   vulkanFusePrePackedConvWithClamp(cloned_module);
   vulkanFoldPrePackingOps(cloned_module);
   removeDropout(cloned_module);
+  vulkanRemoveMutation(cloned_module);
   return cloned_module;
 }
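
For context, these pipelines are what `torch.utils.mobile_optimizer.optimize_for_mobile` runs when a GPU backend is requested. A rough usage sketch follows (it assumes a PyTorch build with the Metal/Vulkan backend enabled, and the optimized model still needs a device with that backend to execute):

```
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

class Model(torch.nn.Module):
    def forward(self, x):
        y = x.view(-1)
        y.add_(2.0)
        return x

scripted = torch.jit.script(Model())
# Runs the Metal pipeline above, including the new metalRemoveMutation pass;
# backend="Vulkan" goes through the analogous Vulkan pipeline.
optimized = optimize_for_mobile(scripted, backend="Metal")
optimized._save_for_lite_interpreter("model_metal.ptl")
```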