mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add mobile_optimized tag to optimized model. (#45479)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45479 Add a top level boolean attribute to the model called mobile_optimized that is set to true if it is optimized. Test Plan: buck test //caffe2/test:mobile passes Reviewed By: kimishpatel Differential Revision: D23956728 fbshipit-source-id: 79c5931702208b871454319ca2ab8633596b1eb8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
17be7c6e5c
commit
5f49d14be2
@ -131,6 +131,23 @@ class TestOptimizer(unittest.TestCase):
|
||||
bn_input = torch.rand(1, 1, 6, 6)
|
||||
torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
|
||||
|
||||
class MyMobileOptimizedTagTest(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyMobileOptimizedTagTest, self).__init__()
|
||||
self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
|
||||
self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
|
||||
|
||||
def forward(self, x):
|
||||
o = F.linear(x, self.linear_weight, self.linear_bias)
|
||||
return F.relu(o)
|
||||
|
||||
mobile_optimized_tag_module = MyMobileOptimizedTagTest()
|
||||
m = torch.jit.script(mobile_optimized_tag_module)
|
||||
m.eval()
|
||||
opt_m = optimize_for_mobile(m)
|
||||
tag = getattr(opt_m, "mobile_optimized", None)
|
||||
self.assertTrue(tag)
|
||||
|
||||
class MyPreserveMethodsTest(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyPreserveMethodsTest, self).__init__()
|
||||
|
Reference in New Issue
Block a user