From 583c4bf7d36eddacd0c7b20ce06de333cac1951d Mon Sep 17 00:00:00 2001
From: Jacob Szwejbka
Date: Tue, 23 Mar 2021 12:05:48 -0700
Subject: [PATCH] [Pytorch Mobile] optimize_for_mobile: Fuse Add Relu on any
 function (#54441)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54441

Similar to previous dropout one

ghstack-source-id: 124544176

Test Plan: Printed graphs before and after fusion. Verified inputs/outputs stayed the same {P299343882}

Reviewed By: kimishpatel

Differential Revision: D27014352

fbshipit-source-id: d0a9548f8743472bdd7e194efd8e8d5fe53b95b6
---
 test/test_mobile_optimizer.py             | 7 +++++--
 torch/csrc/jit/passes/xnnpack_rewrite.cpp | 8 +++++---
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py
index e9ba061db44f..63fb0f7623a8 100644
--- a/test/test_mobile_optimizer.py
+++ b/test/test_mobile_optimizer.py
@@ -189,7 +189,9 @@ class TestOptimizer(TestCase):
         @torch.jit.export
         def foo(self, x):
             x = self.d(F.relu(self.l(x)))
-            return self.l2(x)
+            x = self.l2(x)
+            x = x + torch.ones(1, 100)
+            return F.relu(x)
         input_data = torch.ones(1, 10)
         m = torch.jit.script(OptimizeNoForwardTest())
         m.eval()
@@ -200,7 +202,8 @@ class TestOptimizer(TestCase):
 
 
         FileCheck().check_not("dropout.__") \
-                   .run(optimized_scripted_model.foo.graph)
+                   .check_count("aten::_add_relu(", 1, exactly=True) \
+                   .run(optimized_scripted_model.foo.graph)
         torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
 
 
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index d8b1f71a95f2..d2f7348dee6f 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -433,9 +433,11 @@ script::Module optimizeForMobile(
     }
   }
 
-  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU) &&
-      optimize_forward) {
-    FuseAddRelu(cloned_module);
+  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
+    for (const std::string& method : methods_to_optimize) {
+      auto graph = cloned_module.get_method(method).graph();
+      FuseAddRelu(graph);
+    }
   }
   cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
   return cloned_module;