mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[PyTorch Mobile] Preserve bundled input related methods when calling optimize_for_mobile
Summary: Added an extra step to **always** preserve the bundled inputs methods if they are present in the input module. Also added a check to see if all the methods in the `preseved_methods` exist. If not, we will now throw an exception. This can hopefully stop hard-to-debug inputs from getting into downstream functions. ~~Add an optional argument `preserve_bundled_inputs_methods=False` to the `optimize_for_mobile` function. If set to be True, the function will now add three additional functions related with bundled inputs to be preserved: `get_all_bundled_inputs`, `get_num_bundled_inputs` and `run_on_bundled_input`.~~ Test Plan: `buck test mode/dev //caffe2/test:mobile -- 'test_preserve_bundled_inputs_methods \(test_mobile_optimizer\.TestOptimizer\)'` or `buck test caffe2/test:mobile` to run some other related tests as well. Reviewed By: dhruvbird Differential Revision: D25433268 fbshipit-source-id: 0bf9b4afe64b79ed1684a3db4c0baea40ed3cdd5
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9417e92722
commit
95233870f2
@ -8,6 +8,7 @@ from torch.utils.mobile_optimizer import *
|
||||
from torch.nn import functional as F
|
||||
from torch._C import MobileOptimizerType
|
||||
from torch.testing._internal.common_quantized import override_quantized_engine
|
||||
from torch.nn.modules.module import ModuleAttributeError
|
||||
|
||||
FileCheck = torch._C.FileCheck
|
||||
|
||||
@ -268,6 +269,69 @@ class TestOptimizer(unittest.TestCase):
|
||||
bi_module_lint_list = generate_mobile_module_lints(bi_module)
|
||||
self.assertEqual(len(bi_module_lint_list), 0)
|
||||
|
||||
def test_preserve_bundled_inputs_methods(self):
|
||||
class MyBundledInputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyBundledInputModule, self).__init__()
|
||||
|
||||
def forward(self, inputs):
|
||||
return inputs
|
||||
|
||||
class MyIncompleteBundledInputModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyIncompleteBundledInputModule, self).__init__()
|
||||
|
||||
def forward(self, inputs):
|
||||
return inputs
|
||||
|
||||
@torch.jit.export
|
||||
def get_all_bundled_inputs(self):
|
||||
pass
|
||||
|
||||
bi_module = torch.jit.script(MyBundledInputModule())
|
||||
module_optim_bi_not_preserved = optimize_for_mobile(bi_module)
|
||||
|
||||
# Expected to be False since no bundled inputs methods were added
|
||||
self.assertFalse(
|
||||
hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or
|
||||
hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs') or
|
||||
hasattr(module_optim_bi_not_preserved, 'run_on_bundled_input')
|
||||
)
|
||||
|
||||
# We expect an exception here
|
||||
with self.assertRaises(ModuleAttributeError):
|
||||
module_optim_bi_not_preserved.run_on_bundled_input(0)
|
||||
|
||||
# Add bundled inputs methods to the module
|
||||
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
|
||||
bi_module, [(torch.tensor([1]),)], [])
|
||||
# Now they should be preserved
|
||||
module_optim_bi_preserved = optimize_for_mobile(bi_module)
|
||||
|
||||
# All of the bundled inputs methods were preserved
|
||||
self.assertTrue(
|
||||
hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and
|
||||
hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs') and
|
||||
hasattr(module_optim_bi_preserved, 'run_on_bundled_input')
|
||||
)
|
||||
|
||||
# We do not expect an exception here
|
||||
module_optim_bi_preserved.run_on_bundled_input(0)
|
||||
|
||||
bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
|
||||
module_optim_bi_preserved(*bundled_input)
|
||||
|
||||
# If not all 3 bundled inputs methods are present in the module,
|
||||
# we will not try to preserve them unless specified by the user.
|
||||
incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
|
||||
incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
|
||||
self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))
|
||||
|
||||
# Specifically preserve get_all_bundled_inputs even if it's the only one
|
||||
# bundled inputs method available.
|
||||
incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
|
||||
self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))
|
||||
|
||||
@unittest.skipUnless(torch.backends.xnnpack.enabled,
|
||||
" XNNPACK must be enabled for these tests."
|
||||
" Please build with USE_XNNPACK=1.")
|
||||
|
Reference in New Issue
Block a user