[Pytorch Mobile] optimize_for_mobile: Fuse Add Relu on any function (#54441)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54441

Similar to the previous dropout change: the add/relu fusion now runs on every method selected for optimization, not just forward.
ghstack-source-id: 124544176

Test Plan: Printed graphs before and after fusion. Verified that outputs stayed the same for the same inputs. {P299343882}

Reviewed By: kimishpatel

Differential Revision: D27014352

fbshipit-source-id: d0a9548f8743472bdd7e194efd8e8d5fe53b95b6
Author: Jacob Szwejbka
Date: 2021-03-23 12:05:48 -07:00
Committed by: Facebook GitHub Bot
Parent: acffa604cc
Commit: 583c4bf7d3

2 changed files with 10 additions and 5 deletions
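
For context before the diffs: a minimal sketch, not part of this PR, of how the new behavior could be exercised from Python. The module and method names are hypothetical, and it assumes the torch.utils.mobile_optimizer API of this era routes preserved methods into the per-method optimization this stack introduces.

import torch
import torch.nn.functional as F
from torch.utils.mobile_optimizer import optimize_for_mobile

class AddReluModule(torch.nn.Module):   # hypothetical example module
    def forward(self, x):
        return x

    @torch.jit.export
    def foo(self, x):
        y = x + torch.ones(1, 100)      # aten::add
        return F.relu(y)                # aten::relu

m = torch.jit.script(AddReluModule())
m.eval()
# preserved_methods keeps non-forward methods through optimization.
opt = optimize_for_mobile(m, preserved_methods=["foo"])
print(opt.foo.graph)                    # expect a fused aten::_add_relu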


@@ -189,7 +189,9 @@ class TestOptimizer(TestCase):
             @torch.jit.export
             def foo(self, x):
                 x = self.d(F.relu(self.l(x)))
-                return self.l2(x)
+                x = self.l2(x)
+                x = x + torch.ones(1, 100)
+                return F.relu(x)
         input_data = torch.ones(1, 10)
         m = torch.jit.script(OptimizeNoForwardTest())
         m.eval()
@@ -200,6 +202,7 @@ class TestOptimizer(TestCase):
         FileCheck().check_not("dropout.__") \
+                   .check_count("aten::_add_relu(", 1, exactly=True) \
                    .run(optimized_scripted_model.foo.graph)
+        torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
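
The assert_allclose above checks end-to-end numerics through the optimized module. A standalone sanity check of the fused op itself might look like the following sketch, assuming aten::_add_relu is reachable through the torch.ops.aten namespace:

import torch
import torch.nn.functional as F

a = torch.randn(1, 100)
b = torch.randn(1, 100)
fused = torch.ops.aten._add_relu(a, b)   # fused add + relu kernel
reference = F.relu(a + b)                # unfused reference
torch.testing.assert_allclose(fused, reference)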


@@ -433,9 +433,11 @@ script::Module optimizeForMobile(
     }
   }
 
-  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU) &&
-      optimize_forward) {
-    FuseAddRelu(cloned_module);
+  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
+    for (const std::string& method : methods_to_optimize) {
+      auto graph = cloned_module.get_method(method).graph();
+      FuseAddRelu(graph);
+    }
   }
 
   cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
   return cloned_module;
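
The loop above is the heart of the change: FuseAddRelu now runs on the graph of each method in methods_to_optimize instead of once on the module (effectively only forward). A rough Python analogue of that loop, assuming the internal _jit_pass_fuse_add_relu binding for this pass is available (an assumption; optimize_for_mobile is the supported entry point):

import torch
import torch.nn.functional as F

class TwoMethodModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return F.relu(x + torch.ones(1))

    @torch.jit.export
    def foo(self, x):
        return F.relu(x + torch.ones(1))

m = torch.jit.script(TwoMethodModule())
for name in ["forward", "foo"]:          # stand-in for methods_to_optimize
    # Internal binding (assumed): rewrites aten::add followed by
    # aten::relu into aten::_add_relu, in place, on the given graph.
    torch._C._jit_pass_fuse_add_relu(getattr(m, name).graph)
print(m.foo.graph)                       # expect aten::_add_relu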