mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[Pytorch Mobile] Optimize Non Forward for Mobile (#53314)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53314 Introduction of api for optimizing non forward functions for mobile. As of this diff, all functions that you say to optimize will be preserved, and those functions will be run through canonical optimization. The intention is to stack each further optimization onto separate diffs since they touch multiple files, and it seems like it'd be a nightmare to review. ghstack-source-id: 123909414 Test Plan: torch.utils.mobile_optimizer.optimize_for_mobile(net, methods_to_optimize=["forward", "foo"]) runs fine torch.utils.mobile_optimizer.optimize_for_mobile(net, methods_to_optimize={"foo"}) optimizes just foo if the model doesnt define forward otherwise optimizes foo and forward torch.utils.mobile_optimizer.optimize_for_mobile(net, methods_to_optimize=["forward"]) runs fine torch.utils.mobile_optimizer.optimize_for_mobile(net) runs fine if the model defines forward, Throws otherwise Reviewed By: kimishpatel Differential Revision: D26618689 fbshipit-source-id: 5bff1fb3f3f6085c4a649a8128af9c10f0fa9400
This commit is contained in:
committed by
Facebook GitHub Bot
parent
407d60ee91
commit
8f61b13e80
@ -17,7 +17,8 @@ def optimize_for_mobile(
|
||||
script_module,
|
||||
optimization_blocklist: Set[MobileOptimizerType] = None,
|
||||
preserved_methods: List[AnyStr] = None,
|
||||
backend: str = 'CPU'):
|
||||
backend: str = 'CPU',
|
||||
methods_to_optimize: List[AnyStr] = None):
|
||||
"""
|
||||
Args:
|
||||
script_module: An instance of torch script module with type of ScriptModule.
|
||||
@ -26,6 +27,7 @@ def optimize_for_mobile(
|
||||
method will run the optimization pass that is not included inside optimization_blocklist.
|
||||
perserved_methods: A list of methods that needed to be preserved when freeze_module pass is invoked
|
||||
backend: Device type to use for running the result model ('CPU'(default), 'Vulkan' or 'Metal').
|
||||
methods_to_optimize: List of functions to optimize, CPU only, forward is optimized if it exists
|
||||
Returns:
|
||||
A new optimized torch script module
|
||||
"""
|
||||
@ -39,10 +41,14 @@ def optimize_for_mobile(
|
||||
if preserved_methods is None:
|
||||
preserved_methods = []
|
||||
|
||||
if methods_to_optimize is None:
|
||||
methods_to_optimize = []
|
||||
|
||||
# Convert potential byte arrays into strings (if there is any) to pass type checking
|
||||
# Here we use a new name as assigning it back to preserved_methods will invoke
|
||||
# mypy errors (i.e. List[AnyStr] = List[str])
|
||||
preserved_methods_str: List[str] = [str(method) for method in preserved_methods]
|
||||
methods_to_optimize_str: List[str] = [str(method) for method in methods_to_optimize]
|
||||
|
||||
bundled_inputs_attributes = _get_bundled_inputs_preserved_attributes(script_module, preserved_methods_str)
|
||||
if all([hasattr(script_module, method) for method in bundled_inputs_attributes]):
|
||||
@ -62,7 +68,8 @@ def optimize_for_mobile(
|
||||
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(
|
||||
script_module._c,
|
||||
optimization_blocklist,
|
||||
preserved_methods_str)
|
||||
preserved_methods_str,
|
||||
methods_to_optimize_str)
|
||||
elif backend == 'vulkan':
|
||||
optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(script_module._c, preserved_methods_str)
|
||||
elif backend == 'metal':
|
||||
|
Reference in New Issue
Block a user