diff --git a/test/test_mobile_optimizer.py b/test/test_mobile_optimizer.py
index e9ba061db44f..63fb0f7623a8 100644
--- a/test/test_mobile_optimizer.py
+++ b/test/test_mobile_optimizer.py
@@ -189,7 +189,9 @@ class TestOptimizer(TestCase):
             @torch.jit.export
             def foo(self, x):
                 x = self.d(F.relu(self.l(x)))
-                return self.l2(x)
+                x = self.l2(x)
+                x = x + torch.ones(1, 100)
+                return F.relu(x)
         input_data = torch.ones(1, 10)
         m = torch.jit.script(OptimizeNoForwardTest())
         m.eval()
@@ -200,7 +202,8 @@ class TestOptimizer(TestCase):
         FileCheck().check_not("dropout.__") \
-                   .run(optimized_scripted_model.foo.graph)
+                   .check_count("aten::_add_relu(", 1, exactly=True) \
+                   .run(optimized_scripted_model.foo.graph)
         torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
diff --git a/torch/csrc/jit/passes/xnnpack_rewrite.cpp b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
index d8b1f71a95f2..d2f7348dee6f 100644
--- a/torch/csrc/jit/passes/xnnpack_rewrite.cpp
+++ b/torch/csrc/jit/passes/xnnpack_rewrite.cpp
@@ -433,9 +433,11 @@ script::Module optimizeForMobile(
     }
   }
 
-  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU) &&
-      optimize_forward) {
-    FuseAddRelu(cloned_module);
+  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
+    for (const std::string& method : methods_to_optimize) {
+      auto graph = cloned_module.get_method(method).graph();
+      FuseAddRelu(graph);
+    }
   }
   cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
   return cloned_module;
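
The net effect of the C++ change is that the add+relu fusion pass now runs over every method in `methods_to_optimize`, not only `forward`. Below is a minimal standalone sketch (not part of the patch) of how a user would hit the new code path: a module whose add-followed-by-relu pattern lives in an exported method other than `forward`. The module name `NoForwardAddRelu` is made up for illustration, and it assumes (as the test diff above suggests) that listing the method in `preserved_methods` is what routes it into `methods_to_optimize`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.mobile_optimizer import optimize_for_mobile

class NoForwardAddRelu(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.l = nn.Linear(10, 100)

    @torch.jit.export
    def foo(self, x):
        x = self.l(x) + torch.ones(1, 100)  # add ...
        return F.relu(x)                    # ... immediately followed by relu

m = torch.jit.script(NoForwardAddRelu())
m.eval()
# Assumption: preserved_methods is what makes `foo` visible to the pass.
opt = optimize_for_mobile(m, preserved_methods=['foo'])
# With this patch, foo's graph should contain the fused aten::_add_relu
# in place of the separate aten::add / aten::relu pair.
print(opt.foo.graph)
```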