Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[Pytorch Mobile] optimize_for_mobile: Fuse Add Relu on any function (#54441)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54441

Similar to the previous dropout change, the add+relu fusion now runs on every method selected for optimization, not only on forward.

ghstack-source-id: 124544176

Test Plan: Printed graphs before and after fusion; verified that inputs and outputs stayed the same. {P299343882}

Reviewed By: kimishpatel

Differential Revision: D27014352

fbshipit-source-id: d0a9548f8743472bdd7e194efd8e8d5fe53b95b6
Committed by: Facebook GitHub Bot
Parent: acffa604cc
Commit: 583c4bf7d3
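For context, here is a minimal usage sketch (not part of the commit) of the behavior this change enables: an add-followed-by-relu pattern in an exported method other than forward gets rewritten to aten::_add_relu by optimize_for_mobile. It assumes the public torch.utils.mobile_optimizer.optimize_for_mobile API with its preserved_methods argument and torch.testing.FileCheck; the module name AddReluDemo and the method foo are illustrative only.

    import torch
    import torch.nn.functional as F
    from torch.testing import FileCheck
    from torch.utils.mobile_optimizer import optimize_for_mobile

    class AddReluDemo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(10, 100)

        @torch.jit.export
        def foo(self, x):
            # add followed by relu: the pattern the FUSE_ADD_RELU pass rewrites
            x = self.l(x) + torch.ones(1, 100)
            return F.relu(x)

        def forward(self, x):
            return self.foo(x)

    m = torch.jit.script(AddReluDemo())
    m.eval()

    # Keep the exported method and optimize it along with forward().
    opt = optimize_for_mobile(m, preserved_methods=["foo"])

    # With this change, the fused op should show up in foo's graph, not only in forward's.
    FileCheck().check_count("aten::_add_relu(", 1, exactly=True) \
               .run(opt.foo.graph)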
@@ -189,7 +189,9 @@ class TestOptimizer(TestCase):
             @torch.jit.export
             def foo(self, x):
                 x = self.d(F.relu(self.l(x)))
-                return self.l2(x)
+                x = self.l2(x)
+                x = x + torch.ones(1, 100)
+                return F.relu(x)
         input_data = torch.ones(1, 10)
         m = torch.jit.script(OptimizeNoForwardTest())
         m.eval()
@@ -200,7 +202,8 @@ class TestOptimizer(TestCase):


         FileCheck().check_not("dropout.__") \
-                   .run(optimized_scripted_model.foo.graph)
+                   .check_count("aten::_add_relu(", 1, exactly=True) \
+                   .run(optimized_scripted_model.foo.graph)
         torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)


@@ -433,9 +433,11 @@ script::Module optimizeForMobile(
     }
   }

-  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU) &&
-      optimize_forward) {
-    FuseAddRelu(cloned_module);
+  if (!optimization_blocklist.count(MobileOptimizerType::FUSE_ADD_RELU)) {
+    for (const std::string& method : methods_to_optimize) {
+      auto graph = cloned_module.get_method(method).graph();
+      FuseAddRelu(graph);
+    }
   }
   cloned_module.register_attribute("mobile_optimized", BoolType::get(), true);
   return cloned_module;
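Side note (not from the commit): because the new condition only consults optimization_blocklist, the fusion can still be skipped per call. A hedged sketch, assuming MobileOptimizerType is re-exported by torch.utils.mobile_optimizer as in the tests and that the Python wrapper's optimization_blocklist argument matches the C++ parameter above:

    from torch.utils.mobile_optimizer import MobileOptimizerType, optimize_for_mobile

    # Reusing the scripted, eval-mode module `m` from the sketch above.
    opt_no_fusion = optimize_for_mobile(
        m,
        optimization_blocklist={MobileOptimizerType.FUSE_ADD_RELU},
        preserved_methods=["foo"],
    )
    # foo's graph keeps separate aten::add / aten::relu nodes in this case.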