[PyTorch Mobile] Preserve bundled input related methods when calling optimize_for_mobile (#49170)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49170

Added an extra step to **always** preserve the bundled inputs methods if they are present in the input module.

Also added a check to see if all the methods in the `preseved_methods` exist. If not, we will now throw an exception. This can hopefully stop hard-to-debug inputs from getting into downstream functions.

~~Add an optional argument `preserve_bundled_inputs_methods=False` to the `optimize_for_mobile` function. If set to be True, the function will now add three additional functions related with bundled inputs to be preserved: `get_all_bundled_inputs`, `get_num_bundled_inputs` and `run_on_bundled_input`.~~

Test Plan:
`buck test mode/dev //caffe2/test:mobile -- 'test_preserve_bundled_inputs_methods \(test_mobile_optimizer\.TestOptimizer\)'`

or

`buck test caffe2/test:mobile` to run some other related tests as well.

Reviewed By: dhruvbird

Differential Revision: D25463719

fbshipit-source-id: 6670dfd59bcaf54b56019c1a43db04b288481b6a
This commit is contained in:
Xiong Zhang
2020-12-18 21:59:46 -08:00
committed by Facebook GitHub Bot
parent ad9923e5d5
commit e2d2d9bb0c
2 changed files with 88 additions and 3 deletions

View File

@ -8,6 +8,7 @@ from torch.utils.mobile_optimizer import *
from torch.nn import functional as F
from torch._C import MobileOptimizerType
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.nn.modules.module import ModuleAttributeError
FileCheck = torch._C.FileCheck
@ -268,6 +269,69 @@ class TestOptimizer(unittest.TestCase):
bi_module_lint_list = generate_mobile_module_lints(bi_module)
self.assertEqual(len(bi_module_lint_list), 0)
def test_preserve_bundled_inputs_methods(self):
class MyBundledInputModule(torch.nn.Module):
def __init__(self):
super(MyBundledInputModule, self).__init__()
def forward(self, inputs):
return inputs
class MyIncompleteBundledInputModule(torch.nn.Module):
def __init__(self):
super(MyIncompleteBundledInputModule, self).__init__()
def forward(self, inputs):
return inputs
@torch.jit.export
def get_all_bundled_inputs(self):
pass
bi_module = torch.jit.script(MyBundledInputModule())
module_optim_bi_not_preserved = optimize_for_mobile(bi_module)
# Expected to be False since no bundled inputs methods were added
self.assertFalse(
hasattr(module_optim_bi_not_preserved, 'get_all_bundled_inputs') or
hasattr(module_optim_bi_not_preserved, 'get_num_bundled_inputs') or
hasattr(module_optim_bi_not_preserved, 'run_on_bundled_input')
)
# We expect an exception here
with self.assertRaises(ModuleAttributeError):
module_optim_bi_not_preserved.run_on_bundled_input(0)
# Add bundled inputs methods to the module
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
bi_module, [(torch.tensor([1]),)], [])
# Now they should be preserved
module_optim_bi_preserved = optimize_for_mobile(bi_module)
# All of the bundled inputs methods were preserved
self.assertTrue(
hasattr(module_optim_bi_preserved, 'get_all_bundled_inputs') and
hasattr(module_optim_bi_preserved, 'get_num_bundled_inputs') and
hasattr(module_optim_bi_preserved, 'run_on_bundled_input')
)
# We do not expect an exception here
module_optim_bi_preserved.run_on_bundled_input(0)
bundled_input = module_optim_bi_preserved.get_all_bundled_inputs()[0]
module_optim_bi_preserved(*bundled_input)
# If not all 3 bundled inputs methods are present in the module,
# we will not try to preserve them unless specified by the user.
incomplete_bi_module = torch.jit.script(MyIncompleteBundledInputModule())
incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module)
self.assertFalse(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))
# Specifically preserve get_all_bundled_inputs even if it's the only one
# bundled inputs method available.
incomplete_bi_module_optim = optimize_for_mobile(incomplete_bi_module, preserved_methods=['get_all_bundled_inputs'])
self.assertTrue(hasattr(incomplete_bi_module_optim, 'get_all_bundled_inputs'))
@unittest.skipUnless(torch.backends.xnnpack.enabled,
" XNNPACK must be enabled for these tests."
" Please build with USE_XNNPACK=1.")

View File

@ -39,13 +39,34 @@ def optimize_for_mobile(
if preserved_methods is None:
preserved_methods = []
# Convert potential byte arrays into strings (if there is any) to pass type checking
# Here we use a new name as assigning it back to preserved_methods will invoke
# mypy errors (i.e. List[AnyStr] = List[str])
preserved_methods_str: List[str] = [str(method) for method in preserved_methods]
bundled_inputs_methods = ['get_all_bundled_inputs', 'get_num_bundled_inputs', 'run_on_bundled_input']
if all([hasattr(script_module, method) for method in bundled_inputs_methods]):
preserved_methods_str = list(set(preserved_methods_str + bundled_inputs_methods))
non_exist_methods = []
for method in preserved_methods_str:
if not hasattr(script_module, method):
non_exist_methods.append(method)
if non_exist_methods:
raise AttributeError(
'The following methods to preserve do not exist in script_module: {}'
.format(', '.join(non_exist_methods)))
backend = backend.lower()
if backend == 'cpu':
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blocklist, preserved_methods)
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(
script_module._c,
optimization_blocklist,
preserved_methods_str)
elif backend == 'vulkan':
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods)
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str)
elif backend == 'metal':
optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods)
optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str)
else:
raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'")