[PyTorch Mobile] Preserve bundled input related methods when calling optimize_for_mobile

Summary:
Added an extra step to **always** preserve the bundled inputs methods if they are present in the input module.

Also added a check to see if all the methods in the `preseved_methods` exist. If not, we will now throw an exception. This can hopefully stop hard-to-debug inputs from getting into downstream functions.

~~Add an optional argument `preserve_bundled_inputs_methods=False` to the `optimize_for_mobile` function. If set to be True, the function will now add three additional functions related with bundled inputs to be preserved: `get_all_bundled_inputs`, `get_num_bundled_inputs` and `run_on_bundled_input`.~~

Test Plan:
`buck test mode/dev //caffe2/test:mobile -- 'test_preserve_bundled_inputs_methods \(test_mobile_optimizer\.TestOptimizer\)'`

or

`buck test caffe2/test:mobile` to run some other related tests as well.

Reviewed By: dhruvbird

Differential Revision: D25433268

fbshipit-source-id: 0bf9b4afe64b79ed1684a3db4c0baea40ed3cdd5
This commit is contained in:
Xiong Zhang
2020-12-09 22:50:03 -08:00
committed by Facebook GitHub Bot
parent 9417e92722
commit 95233870f2
2 changed files with 76 additions and 0 deletions

View File

@ -39,6 +39,18 @@ def optimize_for_mobile(
if preserved_methods is None:
preserved_methods = []
bundled_inputs_methods = ['get_all_bundled_inputs', 'get_num_bundled_inputs', 'run_on_bundled_input']
if all([hasattr(script_module, method) for method in bundled_inputs_methods]):
preserved_methods = list(set(preserved_methods + bundled_inputs_methods))
non_exist_methods = []
for method in preserved_methods:
if not hasattr(script_module, method):
non_exist_methods.append(method)
if non_exist_methods:
raise AttributeError(
'The following methods to preserve do not exist in script_module: {}'.format(', '.join(non_exist_methods)))
backend = backend.lower()
if backend == 'cpu':
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blocklist, preserved_methods)