mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add option to preserve certain methods during optimize_for_mobile. (#40629)
Summary: By default freeze_module pass, invoked from optimize_for_mobile, preserves only forward method. There is an option to specify a list of methods that can be preserved during freeze_module. This PR exposes that to optimize_for_module pass. Pull Request resolved: https://github.com/pytorch/pytorch/pull/40629 Test Plan: python test/test_mobile_optimizer.py Reviewed By: dreiss Differential Revision: D22260972 Pulled By: kimishpatel fbshipit-source-id: 452c653269da8bb865acfb58da2d28c23c66e326
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4121d34036
commit
4a174c83ca
@ -127,6 +127,31 @@ class TestOptimizer(unittest.TestCase):
|
||||
bn_input = torch.rand(1, 1, 6, 6)
|
||||
torch.testing.assert_allclose(bn_scripted_module(bn_input), no_bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)
|
||||
|
||||
class MyPreserveMethodsTest(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyPreserveMethodsTest, self).__init__()
|
||||
self.linear_weight = torch.nn.Parameter(torch.Tensor(torch.rand(linear_weight_shape)))
|
||||
self.linear_bias = torch.nn.Parameter(torch.Tensor(torch.rand((weight_output_dim))))
|
||||
|
||||
def forward(self, x):
|
||||
o = F.linear(x, self.linear_weight, self.linear_bias)
|
||||
return F.relu(o)
|
||||
|
||||
@torch.jit.export
|
||||
def preserveThis(self):
|
||||
pass
|
||||
|
||||
preserve_method_module = MyPreserveMethodsTest()
|
||||
m = torch.jit.script(preserve_method_module)
|
||||
m.eval()
|
||||
opt_m = optimize_for_mobile(m)
|
||||
no_preserveThis = getattr(opt_m, "preserveThis", None)
|
||||
self.assertEqual(no_preserveThis, None)
|
||||
opt_m = optimize_for_mobile(m, preserved_methods=["preserveThis"])
|
||||
preserveThis = getattr(opt_m, "preserveThis", None)
|
||||
self.assertNotEqual(preserveThis, None)
|
||||
|
||||
|
||||
def test_generate_mobile_module_lints(self):
|
||||
class MyTestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
||||
@ -276,7 +276,8 @@ void FoldPrePackingOps(script::Module& m) {
|
||||
|
||||
script::Module optimizeForMobile(
|
||||
const script::Module& m,
|
||||
const std::set<MobileOptimizerType>& optimization_blacklist) {
|
||||
const std::set<MobileOptimizerType>& optimization_blacklist,
|
||||
const std::vector<std::string>& preserved_methods) {
|
||||
auto cloned_module = m.clone();
|
||||
cloned_module.eval();
|
||||
|
||||
@ -287,7 +288,7 @@ script::Module optimizeForMobile(
|
||||
if (!optimization_blacklist.count(
|
||||
MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)) {
|
||||
insertPrePackedOps(cloned_module);
|
||||
cloned_module = freeze_module(cloned_module);
|
||||
cloned_module = freeze_module(cloned_module, preserved_methods);
|
||||
fusePrePackedLinearConvWithClamp(cloned_module);
|
||||
FoldPrePackingOps(cloned_module);
|
||||
}
|
||||
@ -328,7 +329,8 @@ void FoldPrePackingOps(script::Module& m) {
|
||||
|
||||
script::Module optimizeForMobile(
|
||||
const script::Module& module,
|
||||
const std::set<MobileOptimizerType>& blacklist) {
|
||||
const std::set<MobileOptimizerType>& blacklist,
|
||||
const std::vector<std::string>& preserved_methods) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
"Mobile optimizaiton only available with XNNPACK at the moment. "
|
||||
"XNNPACK is not enabled. Please build with USE_XNNPACK=1");
|
||||
|
||||
@ -18,6 +18,7 @@ TORCH_API void fusePrePackedLinearConvWithClamp(script::Module& module);
|
||||
TORCH_API void FoldPrePackingOps(script::Module& module);
|
||||
TORCH_API script::Module optimizeForMobile(
|
||||
const script::Module& module,
|
||||
const std::set<MobileOptimizerType>& optimization_blacklist = {});
|
||||
const std::set<MobileOptimizerType>& optimization_blacklist = {},
|
||||
const std::vector<std::string>& preserved_methods = {});
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
||||
@ -550,8 +550,10 @@ void initJITBindings(PyObject* module) {
|
||||
.def(
|
||||
"_jit_pass_optimize_for_mobile",
|
||||
[](script::Module& module,
|
||||
std::set<MobileOptimizerType>& optimization_blacklist) {
|
||||
return optimizeForMobile(module, optimization_blacklist);
|
||||
std::set<MobileOptimizerType>& optimization_blacklist,
|
||||
std::vector<std::string>& preserved_methods) {
|
||||
return optimizeForMobile(
|
||||
module, optimization_blacklist, preserved_methods);
|
||||
})
|
||||
.def(
|
||||
"_jit_pass_vulkan_insert_prepacked_ops",
|
||||
|
||||
@ -5,7 +5,7 @@ This module contains utility method for mobile model optimization and lint.
|
||||
import torch
|
||||
from enum import Enum
|
||||
from torch._C import MobileOptimizerType
|
||||
from typing import Set
|
||||
from typing import Set, List, AnyStr
|
||||
|
||||
class LintCode(Enum):
|
||||
BUNDLED_INPUT = 1
|
||||
@ -13,7 +13,10 @@ class LintCode(Enum):
|
||||
DROPOUT = 3
|
||||
BATCHNORM = 4
|
||||
|
||||
def optimize_for_mobile(script_module, optimization_blacklist: Set[MobileOptimizerType] = None):
|
||||
def optimize_for_mobile(
|
||||
script_module,
|
||||
optimization_blacklist: Set[MobileOptimizerType] = None,
|
||||
preserved_methods: List[AnyStr] = None):
|
||||
"""
|
||||
Args:
|
||||
script_module: An instance of torch script module with type of ScriptModule
|
||||
@ -30,7 +33,10 @@ def optimize_for_mobile(script_module, optimization_blacklist: Set[MobileOptimiz
|
||||
if optimization_blacklist is None:
|
||||
optimization_blacklist = set()
|
||||
|
||||
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blacklist)
|
||||
if preserved_methods is None:
|
||||
preserved_methods = []
|
||||
|
||||
optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(script_module._c, optimization_blacklist, preserved_methods)
|
||||
return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user