mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Pytorch Mobile] optimize_for_mobile: Fuse Add Relu on any function (#54441)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54441 Similar to previous dropout one ghstack-source-id: 124544176 Test Plan: Printed graphs before and after fusion. verified input outputs stayed the same {P299343882} Reviewed By: kimishpatel Differential Revision: D27014352 fbshipit-source-id: d0a9548f8743472bdd7e194efd8e8d5fe53b95b6
This commit is contained in:
committed by
Facebook GitHub Bot
parent
acffa604cc
commit
583c4bf7d3
@ -189,7 +189,9 @@ class TestOptimizer(TestCase):
|
||||
@torch.jit.export
|
||||
def foo(self, x):
|
||||
x = self.d(F.relu(self.l(x)))
|
||||
return self.l2(x)
|
||||
x = self.l2(x)
|
||||
x = x + torch.ones(1, 100)
|
||||
return F.relu(x)
|
||||
input_data = torch.ones(1, 10)
|
||||
m = torch.jit.script(OptimizeNoForwardTest())
|
||||
m.eval()
|
||||
@ -200,7 +202,8 @@ class TestOptimizer(TestCase):
|
||||
|
||||
|
||||
FileCheck().check_not("dropout.__") \
|
||||
.run(optimized_scripted_model.foo.graph)
|
||||
.check_count("aten::_add_relu(", 1, exactly=True) \
|
||||
.run(optimized_scripted_model.foo.graph)
|
||||
torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)
|
||||
|
||||
|
||||
|
@ -433,9 +433,11 @@ script::Module optimizeForMobile(
|
||||
}
|
||||
}
|
||||
|
||||
if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU) &&
|
||||
optimize_forward) {
|
||||
FuseAddRelu(cloned_module);
|
||||
if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
|
||||
for (const std::string& method : methods_to_optimize) {
|
||||
auto graph = cloned_module.get_method(method).graph();
|
||||
FuseAddRelu(graph);
|
||||
}
|
||||
}
|
||||
cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
|
||||
return cloned_module;
|
||||
|
Reference in New Issue
Block a user