[Pytorch Mobile] optimize_for_mobile: Fuse Add Relu on any function (#54441)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54441

Similar to previous dropout one
ghstack-source-id: 124544176

Test Plan: Printed graphs before and after fusion. verified input outputs stayed the same {P299343882}

Reviewed By: kimishpatel

Differential Revision: D27014352

fbshipit-source-id: d0a9548f8743472bdd7e194efd8e8d5fe53b95b6
This commit is contained in:
Jacob Szwejbka
2021-03-23 12:05:48 -07:00
committed by Facebook GitHub Bot
parent acffa604cc
commit 583c4bf7d3
2 changed files with 10 additions and 5 deletions

View File

@ -189,7 +189,9 @@ class TestOptimizer(TestCase):
@torch.jit.export @torch.jit.export
def foo(self, x): def foo(self, x):
x = self.d(F.relu(self.l(x))) x = self.d(F.relu(self.l(x)))
return self.l2(x) x = self.l2(x)
x = x + torch.ones(1, 100)
return F.relu(x)
input_data = torch.ones(1, 10) input_data = torch.ones(1, 10)
m = torch.jit.script(OptimizeNoForwardTest()) m = torch.jit.script(OptimizeNoForwardTest())
m.eval() m.eval()
@ -200,7 +202,8 @@ class TestOptimizer(TestCase):
FileCheck().check_not("dropout.__") \ FileCheck().check_not("dropout.__") \
.run(optimized_scripted_model.foo.graph) .check_count("aten::_add_relu(", 1, exactly=True) \
.run(optimized_scripted_model.foo.graph)
torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3) torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

View File

@ -433,9 +433,11 @@ script::Module optimizeForMobile(
} }
} }
if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU) && if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
optimize_forward) { for (const std::string& method : methods_to_optimize) {
FuseAddRelu(cloned_module); auto graph = cloned_module.get_method(method).graph();
FuseAddRelu(graph);
}
} }
cloned_module.register_attribute("mobile_optimized", BoolType::get(), true); cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
return cloned_module; return cloned_module;