Add option to preserve certain methods during optimize_for_mobile. (#40629)

Summary:
By default, the freeze_module pass, invoked from optimize_for_mobile,
preserves only the forward method.
methods that can be preserved during freeze_module. This PR exposes that
to the optimize_for_mobile pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40629

Test Plan: python test/test_mobile_optimizer.py

Reviewed By: dreiss

Differential Revision: D22260972

Pulled By: kimishpatel

fbshipit-source-id: 452c653269da8bb865acfb58da2d28c23c66e326
This commit is contained in:
Kimish Patel
2020-06-29 09:31:00 -07:00
committed by Facebook GitHub Bot
parent 4121d34036
commit 4a174c83ca
5 changed files with 45 additions and 9 deletions

View File

@ -127,6 +127,31 @@ class TestOptimizer(unittest.TestCase):
bn_input = torch.rand(1, 1, 6, 6)
torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
class MyPreserveMethodsTest(torch.nn.Module):
    """Tiny linear+ReLU module carrying an exported extra method.

    Used by the test to verify that optimize_for_mobile drops
    non-forward methods unless they are listed in ``preserved_methods``.
    """

    def __init__(self):
        super(MyPreserveMethodsTest, self).__init__()
        # Randomly initialized linear parameters; the shapes
        # (linear_weight_shape / weight_output_dim) come from the
        # enclosing test scope.
        self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
        self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))

    def forward(self, x):
        # Linear projection followed by a ReLU activation.
        return F.relu(F.linear(x, self.linear_weight, self.linear_bias))

    @torch.jit.export
    def preserveThis(self):
        # Intentionally a no-op: it exists only so the test can check
        # whether this exported method survives mobile optimization.
        pass
preserve_method_module = MyPreserveMethodsTest()
m = torch.jit.script(preserve_method_module)
m.eval()
opt_m = optimize_for_mobile(m)
no_preserveThis = getattr(opt_m, "preserveThis", None)
self.assertEqual(no_preserveThis, None)
opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
preserveThis = getattr(opt_m, "preserveThis", None)
self.assertNotEqual(preserveThis, None)
def test_generate_mobile_module_lints(self):
class MyTestModule(torch.nn.Module):
def __init__(self):

View File

@ -276,7 +276,8 @@ void FoldPrePackingOps(script::Module& m) {
script::Module optimizeForMobile(
const script::Module& m,
const std::set<MobileOptimizerType>& optimization_blacklist) {
const std::set<MobileOptimizerType>& optimization_blacklist,
const std::vector<std::string>& preserved_methods) {
auto cloned_module = m.clone();
cloned_module.eval();
@ -287,7 +288,7 @@ script::Module optimizeForMobile(
if (!optimization_blacklist.count(
MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)) {
insertPrePackedOps(cloned_module);
cloned_module = freeze_module(cloned_module);
cloned_module = freeze_module(cloned_module, preserved_methods);
fusePrePackedLinearConvWithClamp(cloned_module);
FoldPrePackingOps(cloned_module);
}
@ -328,7 +329,8 @@ void FoldPrePackingOps(script::Module& m) {
script::Module optimizeForMobile(
const script::Module& module,
const std::set<MobileOptimizerType>& blacklist) {
const std::set<MobileOptimizerType>& blacklist,
const std::vector<std::string>& preserved_methods) {
TORCH_INTERNAL_ASSERT(
"Mobile optimizaiton only available with XNNPACK at the moment. "
"XNNPACK is not enabled. Please build with USE_XNNPACK=1");

View File

@ -18,6 +18,7 @@ TORCH_API void fusePrePackedLinearConvWithClamp(script::Module& module);
TORCH_API void FoldPrePackingOps(script::Module& module);
TORCH_API script::Module optimizeForMobile(
const script::Module& module,
const std::set<MobileOptimizerType>& optimization_blacklist = {});
const std::set<MobileOptimizerType>& optimization_blacklist = {},
const std::vector<std::string>& preserved_methods = {});
} // namespace jit
} // namespace torch

View File

@ -550,8 +550,10 @@ void initJITBindings(PyObject* module) {
.def(
"_jit_pass_optimize_for_mobile",
[](script::Module& module,
std::set<MobileOptimizerType>& optimization_blacklist) {
return optimizeForMobile(module, optimization_blacklist);
std::set<MobileOptimizerType>& optimization_blacklist,
std::vector<std::string>& preserved_methods) {
return optimizeForMobile(
module, optimization_blacklist, preserved_methods);
})
.def(
"_jit_pass_vulkan_insert_prepacked_ops",

View File

@ -5,7 +5,7 @@ This module contains utility method for mobile model optimization and lint.
import torch
from enum import Enum
from torch._C import MobileOptimizerType
from typing import Set
from typing import Set, List, AnyStr
class LintCode(Enum):
BUNDLED_INPUT = 1
@ -13,7 +13,10 @@ class LintCode(Enum):
DROPOUT = 3
BATCHNORM = 4
def optimize_for_mobile(script_module, optimization_blacklist: Set[MobileOptimizerType] = None):
def optimize_for_mobile(
script_module,
optimization_blacklist: Set[MobileOptimizerType] = None,
preserved_methods: List[AnyStr] = None):
"""
Args:
script_module: An instance of torch script module with type of ScriptModule
@ -30,7 +33,10 @@ def optimize_for_mobile(script_module, optimization_blacklist: Set[MobileOptimiz
if optimization_blacklist is None:
optimization_blacklist = set()
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blacklist)
if preserved_methods is None:
preserved_methods = []
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blacklist, preserved_methods)
return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)