[Pytorch Edge] Prepack folding for functions besides forward (#56081)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56081

ghstack-source-id: 127205799

Test Plan: unit test. Since I'm prepacking the weights of the same operators multiple times, I wonder whether it "just works".

Reviewed By: kimishpatel

Differential Revision: D27777337

fbshipit-source-id: 909d2a667d9eb51e205536b478a6668c33b3fb15
Commit 7e9f7fb980 (parent 7ff1990caf), committed by Facebook GitHub Bot.
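For illustration, here is a minimal, self-contained sketch of what this change enables: prepack folding applied to an exported method besides forward. The module name, weight shapes, and usage below are assumptions for the sketch; the methods_to_optimize keyword is the one exercised by the test diff below, and the prepacked rewrite assumes an XNNPACK-enabled build.

    import torch
    import torch.nn.functional as F
    from torch.utils.mobile_optimizer import optimize_for_mobile

    class TinyModule(torch.nn.Module):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.linear_weight = torch.nn.Parameter(torch.rand(4, 4))
            self.linear_bias = torch.nn.Parameter(torch.rand(4))

        def forward(self, x):
            return F.relu(F.linear(x, self.linear_weight, self.linear_bias))

        @torch.jit.export
        def foo(self, x):
            # Uses the same weights outside forward; with this change the
            # optimizer can fold a prepack for this method as well.
            return F.relu(F.linear(x, self.linear_weight, self.linear_bias))

    scripted = torch.jit.script(TinyModule())
    scripted.eval()
    optimized = optimize_for_mobile(scripted, methods_to_optimize=['foo'])
    print(optimized.foo.graph)  # expect prepacked::linear_clamp_run instead of aten::linear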
@@ -74,6 +74,17 @@ class TestOptimizer(TestCase):
                 o = o + x
                 return F.relu(o)

+            @torch.jit.export
+            def foo(self, x):
+                o = F.conv2d(x, self.conv_weight, self.conv_bias,
+                             self.strides, self.paddings, self.dilations, self.groups)
+                o = F.relu(o)
+                x = o.permute([0, 2, 3, 1])
+                o = F.linear(x, self.linear_weight, self.linear_bias)
+                o = o + x
+                return F.relu(o)
+
+
         class BNTestModule(torch.nn.Module):
             def __init__(self):
                 super(BNTestModule, self).__init__()
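Note that foo reads attributes (self.conv_weight, self.strides, and so on) defined earlier in MyTestModule, outside this hunk. For reference, a self-contained sketch of that module with assumed shapes and conv parameters (the real test's values may differ):

    import torch
    import torch.nn.functional as F

    class MyTestModule(torch.nn.Module):
        # Standalone sketch; the shapes and conv parameters below are
        # assumptions chosen so the permuted conv output feeds the linear layer.
        def __init__(self):
            super().__init__()
            self.conv_weight = torch.nn.Parameter(torch.rand(4, 3, 2, 2))
            self.conv_bias = torch.nn.Parameter(torch.rand(4))
            self.linear_weight = torch.nn.Parameter(torch.rand(4, 4))
            self.linear_bias = torch.nn.Parameter(torch.rand(4))
            self.strides = [1, 1]
            self.paddings = [0, 0]
            self.dilations = [1, 1]
            self.groups = 1

        @torch.jit.export
        def foo(self, x):
            o = F.conv2d(x, self.conv_weight, self.conv_bias,
                         self.strides, self.paddings, self.dilations, self.groups)
            o = F.relu(o)
            x = o.permute([0, 2, 3, 1])  # NCHW -> NHWC so the last dim matches linear's in_features
            o = F.linear(x, self.linear_weight, self.linear_bias)
            o = o + x
            return F.relu(o)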
@@ -92,9 +103,11 @@ class TestOptimizer(TestCase):
         scripted_model = torch.jit.script(MyTestModule())
         scripted_model.eval()
         initial_result = scripted_model(input_data)
+        initial_foo_result = scripted_model.foo(input_data)

-        optimized_scripted_model = optimize_for_mobile(scripted_model)
+        optimized_scripted_model = optimize_for_mobile(scripted_model, methods_to_optimize=['foo'])
         optimized_result = optimized_scripted_model(input_data)
+        optimized_foo_result = optimized_scripted_model.foo(input_data)

         FileCheck().check_not("Tensor = aten::conv2d") \
                    .check_not("Tensor = prim::CallFunction") \
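The FileCheck chain that opens at the end of this hunk asserts on the optimized forward graph. As a minimal illustration of the pattern (assuming optimized_scripted_model from the lines above):

    from torch.testing import FileCheck

    # check_not asserts a pattern is absent; check_count(..., exactly=True)
    # asserts it appears exactly that many times in the graph dump.
    FileCheck().check_not("Tensor = aten::conv2d") \
               .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
               .run(optimized_scripted_model.graph)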
@@ -108,6 +121,18 @@ class TestOptimizer(TestCase):
                    .run(optimized_scripted_model.graph)
         torch.testing.assert_allclose(initial_result, optimized_result, rtol=1e-2, atol=1e-3)

+        FileCheck().check_not("Tensor = aten::conv2d") \
+                   .check_not("Tensor = prim::CallFunction") \
+                   .check_not("prepacked::conv2d_clamp_prepack") \
+                   .check_count("prepacked::conv2d_clamp_run", 1, exactly=True) \
+                   .check_not("prepacked::linear_clamp_prepack") \
+                   .check_count("prepacked::linear_clamp_run", 1, exactly=True) \
+                   .check_not("aten::add(") \
+                   .check_not("aten::relu(") \
+                   .check_count("aten::_add_relu(", 1, exactly=True) \
+                   .run(optimized_scripted_model.foo.graph)
+        torch.testing.assert_allclose(initial_foo_result, optimized_foo_result, rtol=1e-2, atol=1e-3)
+
         optimization_blocklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
         optimized_scripted_model_no_prepack = optimize_for_mobile(scripted_model, optimization_blocklist_no_prepack)
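Finally, the blocklist lines at the end of this hunk set up the negative case: with prepack folding disabled, none of the prepacked::* ops should appear in either method. A short sketch of that path, assuming scripted_model from the test:

    from torch._C import MobileOptimizerType
    from torch.utils.mobile_optimizer import optimize_for_mobile

    # Blocklisting INSERT_FOLD_PREPACK_OPS leaves conv2d/linear as aten ops
    # in every method's graph.
    blocklist = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
    no_prepack_model = optimize_for_mobile(scripted_model, blocklist)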