mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[PyTorch Mobile] Preserve bundled input related methods when calling optimize_for_mobile (#49170)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/49170 Added an extra step to **always** preserve the bundled inputs methods if they are present in the input module. Also added a check to see if all the methods in the `preseved_methods` exist. If not, we will now throw an exception. This can hopefully stop hard-to-debug inputs from getting into downstream functions. ~~Add an optional argument `preserve_bundled_inputs_methods=False` to the `optimize_for_mobile` function. If set to be True, the function will now add three additional functions related with bundled inputs to be preserved: `get_all_bundled_inputs`, `get_num_bundled_inputs` and `run_on_bundled_input`.~~ Test Plan: `buck test mode/dev //caffe2/test:mobile -- 'test_preserve_bundled_inputs_methods \(test_mobile_optimizer\.TestOptimizer\)'` or `buck test caffe2/test:mobile` to run some other related tests as well. Reviewed By: dhruvbird Differential Revision: D25463719 fbshipit-source-id: 6670dfd59bcaf54b56019c1a43db04b288481b6a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ad9923e5d5
commit
e2d2d9bb0c
@ -39,13 +39,34 @@ def optimize_for_mobile(
|
||||
if preserved_methods is None:
|
||||
preserved_methods = []
|
||||
|
||||
# Convert potential byte arrays into strings (if there is any) to pass type checking
|
||||
# Here we use a new name as assigning it back to preserved_methods will invoke
|
||||
# mypy errors (i.e. List[AnyStr] = List[str])
|
||||
preserved_methods_str: List[str] = [str(method) for method in preserved_methods]
|
||||
|
||||
bundled_inputs_methods = ['get_all_bundled_inputs', 'get_num_bundled_inputs', 'run_on_bundled_input']
|
||||
if all([hasattr(script_module, method) for method in bundled_inputs_methods]):
|
||||
preserved_methods_str = list(set(preserved_methods_str + bundled_inputs_methods))
|
||||
|
||||
non_exist_methods = []
|
||||
for method in preserved_methods_str:
|
||||
if not hasattr(script_module, method):
|
||||
non_exist_methods.append(method)
|
||||
if non_exist_methods:
|
||||
raise AttributeError(
|
||||
'The following methods to preserve do not exist in script_module: {}'
|
||||
.format(', '.join(non_exist_methods)))
|
||||
|
||||
backend = backend.lower()
|
||||
if backend == 'cpu':
|
||||
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blocklist, preserved_methods)
|
||||
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(
|
||||
script_module._c,
|
||||
optimization_blocklist,
|
||||
preserved_methods_str)
|
||||
elif backend == 'vulkan':
|
||||
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods)
|
||||
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str)
|
||||
elif backend == 'metal':
|
||||
optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods)
|
||||
optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str)
|
||||
else:
|
||||
raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'")
|
||||
|
||||
|
Reference in New Issue
Block a user