Files
pytorch/torch/utils/mobile_optimizer.py
Xingying Cheng 5c9d1e4824 Propagate module lints for mobile scripted module. (#37046)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37046
ghstack-source-id: 102669259

Create a Python API entry point that generates mobile model lints: it takes a scripted module as its argument and returns a list of module lints.

The initial version creates a placeholder that includes the module's bundled-input check as the first lint instance. More lints will be added in the future.

Test Plan: python test/test_optimizer.py

Reviewed By: dreiss

Differential Revision: D21164648

fbshipit-source-id: 9e8f4e19d74b5464a55cc73b9dc18f358c5947d6
2020-04-27 10:20:12 -07:00

64 lines
2.4 KiB
Python

"""
This module contains utility methods for mobile model optimization and lint.
"""
import torch
from enum import Enum
class LintCode(Enum):
    """Identifiers for the lints reported by ``generate_mobile_module_lints``.

    The member ``name`` is what appears under the ``"name"`` key of each
    reported lint entry.
    """

    # The module has no ``_generate_bundled_inputs`` attribute.
    BUNDLED_INPUT = 1
    # A parameter of the module still has ``requires_grad`` set.
    REQUIRES_GRAD = 2
    # An exported operator name contains "dropout".
    DROPOUT = 3
    # An exported operator name contains "batch_norm".
    BATCHNORM = 4
def optimize_for_mobile(script_module):
    """
    Run the TorchScript mobile optimization pass over a scripted module.

    Args:
        script_module: An instance of torch script module with type of ScriptModule

    Returns:
        script_module: A new optimized torch script module

    Raises:
        TypeError: if ``script_module`` is not a ``torch.jit.ScriptModule``.
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            'Got {}, but ScriptModule is expected.'.format(type(script_module)))

    # The pass operates on the underlying C++ module held in ``_c``; its
    # result is a C++ module that must be wrapped back into a Python
    # ScriptModule before being handed to the caller.
    source_cpp_module = script_module._c
    result_cpp_module = torch._C._jit_pass_optimize_for_mobile(source_cpp_module)
    wrapped_module = torch.jit._recursive.wrap_cpp_module(result_cpp_module)
    return wrapped_module
def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
    """
    Collect mobile-deployment lints for a scripted module.

    Args:
        script_module: An instance of torch script module with type of ScriptModule

    Returns:
        lint_list: A list of dictionaries, one per lint. Each has a ``"name"``
            key (a ``LintCode`` member name) and a ``"message"`` key (a
            human-readable hint). Entries are ordered as follows: the
            bundled-input lint first, then one entry per parameter that
            requires grad, then operator lints in the order returned by
            ``torch.jit.export_opnames``.

    Raises:
        TypeError: if ``script_module`` is not a ``torch.jit.ScriptModule``.
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            'Got {}, but ScriptModule is expected.'.format(type(script_module)))

    def make_lint(code, message):
        # Build one lint entry; key order ("name", then "message") is kept stable.
        return {"name": code.name, "message": message}

    lints = []

    # The presence of ``_generate_bundled_inputs`` is taken as the signal that
    # bundled inputs were attached (the message directs users to
    # augment_model_with_bundled_inputs).
    if not hasattr(script_module, "_generate_bundled_inputs"):
        lints.append(make_lint(
            LintCode.BUNDLED_INPUT,
            "No bundled input, please add bundled inputs before "
            "saving the module using torch.utils.bundled_inputs.augment_model_with_bundled_inputs."))

    # NOTE(review): torch.no_grad() does not clear ``param.requires_grad``, so
    # following this advice alone will not silence this lint; consider
    # suggesting ``param.requires_grad_(False)`` instead.
    lints.extend(
        make_lint(
            LintCode.REQUIRES_GRAD,
            "Param {} requires grad, "
            "please set torch.no_grad() to reduce memory usage and improve computation speed during "
            "inference phase.".format(param_name))
        for param_name, param in script_module.named_parameters()
        if param.requires_grad)

    # Both operator lints carry the same advice. Every op name is tested
    # against each substring in this fixed order, so one op can in principle
    # produce more than one entry.
    op_patterns = (
        ("dropout", LintCode.DROPOUT),
        ("batch_norm", LintCode.BATCHNORM),
    )
    for op_name in torch.jit.export_opnames(script_module):
        for substring, code in op_patterns:
            if substring in op_name:
                lints.append(make_lint(
                    code,
                    "Operator {} exists, remember to call eval() before "
                    "saving the module.".format(op_name)))

    return lints