Fuse module enhancements (#26457)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26457

Enhancement to fuse module to support sequentials, fuse list can now be just like the state dict.
Also add support for Conv-Relu and linear-relu fusion
Also support inplace and out of place fusion of models.
ghstack-source-id: 91076386

Test Plan:
buck test caffe2/test:quantization -- 'test_fusion_sequential_model_train \(test_quantization\.FusionTest\)' --print-passing-details
buck test caffe2/test:quantization -- 'test_fusion_sequential_model_eval \(test_quantization\.FusionTest\)' --print-passing-details

Differential Revision: D17466382

fbshipit-source-id: 0a548f8f4c366f3ecc59db693bac725ccd62328e
This commit is contained in:
Raghuraman Krishnamoorthi
2019-09-30 21:58:29 -07:00
committed by Facebook Github Bot
parent 6a4ca9abec
commit dddae3f854
9 changed files with 239 additions and 55 deletions

View File

@ -1,6 +1,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import copy
import torch.nn._intrinsic.modules.fused as torch_fused
@ -52,44 +53,78 @@ def fuse_conv_bn_relu(conv, bn, relu):
return torch_fused.ConvReLU2d(
torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
# Generalization of getattr
def _get_module(model, submodule_key):
tokens = submodule_key.split('.')
cur_mod = model
for s in tokens:
cur_mod = getattr(cur_mod, s)
return cur_mod
def _fuse_modules(model, named_module_dict, modules_to_fuse, fuser_func=None):
assert(len(modules_to_fuse) == 2 or len(modules_to_fuse) == 3),\
"Can fuse only 2 or 3 modules."
# Generalization of setattr
def _set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
sub_tokens = tokens[:-1]
cur_mod = model
for s in sub_tokens:
cur_mod = getattr(cur_mod, s)
OP_LIST_TO_FUSER_FUNC = {
setattr(cur_mod, tokens[-1], module)
def fuse_known_modules(mod_list):
r"""Returns a list of modules that fuses the operations specified
in the input module list.
Fuses only the following sequence of modules:
conv, bn
conv, bn, relu
conv, relu
linear, relu
For these sequences, the first element in the output module list performs
the fused operation. The rest of the elements are set to nn.Identity()
"""
OP_LIST_TO_FUSER_METHOD = {
(torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
(torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu
(torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu,
(torch.nn.Conv2d, torch.nn.ReLU): torch.nn._intrinsic.ConvReLU2d,
(torch.nn.Linear, torch.nn.ReLU): torch.nn._intrinsic.LinearReLU
}
mod = []
parent_mod = []
for i in range(len(modules_to_fuse)):
parent_module_name = '.'.join(modules_to_fuse[i].split('.')[:-1])
mod.append(named_module_dict[modules_to_fuse[i]])
parent_mod.append(named_module_dict.get(parent_module_name, model))
types = tuple(type(m) for m in mod_list)
fuser_method = OP_LIST_TO_FUSER_METHOD.get(types, None)
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(types))
new_mod = [None] * len(mod_list)
new_mod[0] = fuser_method(*mod_list)
new_mod = mod[0]
if fuser_func is None:
types = tuple(type(m) for m in mod)
fuser_func = OP_LIST_TO_FUSER_FUNC.get(types, None)
if fuser_func is None:
raise NotImplementedError("Cannot fuse modules: {}".format(types))
new_mod = fuser_func(*mod)
for i in range(1, len(mod_list)):
new_mod[i] = torch.nn.Identity()
new_mod[i].training = mod_list[0].training
# Assign new_mod to module and set remaining modules to identity
if new_mod is not mod[0]:
setattr(parent_mod[0], modules_to_fuse[0].split('.')[-1], new_mod)
for i in range(1, len(modules_to_fuse)):
setattr(parent_mod[i], modules_to_fuse[i].split('.')[-1], torch.nn.Identity())
return new_mod
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules):
def fuse_modules(model, modules_to_fuse):
mod_list = []
for item in modules_to_fuse:
mod_list.append(_get_module(model, item))
# Fuse list of modules
new_mod_list = fuser_func(mod_list)
# Replace original module list with fused module list
for i, item in enumerate(modules_to_fuse):
_set_module(model, item, new_mod_list[i])
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules):
r"""Fuses a list of modules into a single module
Fuses only the following sequence of modules:
conv, bn
conv, bn, relu
conv, relu
linear, relu
All other sequences are left unchanged.
For these sequences, replaces the first item in the list
with the fused module, replacing the rest of the modules
@ -97,20 +132,40 @@ def fuse_modules(model, modules_to_fuse):
Arguments:
model: Model containing the modules to be fused
modules_to_fuse: list of list of module names to fuse.
modules_to_fuse: list of list of module names to fuse. Can also be a list
of strings if there is only a single list of modules to fuse.
inplace: bool specifying if fusion happens in place on the model, by default
a new model is returned
fuser_func: Function that takes in a list of modules and outputs a list of fused modules
of the same length. For example,
fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
Defaults to torch.quantization.fuse_known_modules
Returns:
Modifies the model in place.
model with fused modules. A new copy is created if inplace=True.
Examples::
>>> m = myModel()
>>> # m is a module containing the sub-modules below
>>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
>>> nn.quantization.fuse_module.fuse_module(m, modules_to_fuse)
>>> output = m(input)
>>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)
>>> m = myModel()
>>> # Alternately provide a single list of modules to fuse
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
>>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
>>> output = fused_m(input)
"""
named_module_dict = {name: mod for name, mod in model.named_modules()}
for module_list in modules_to_fuse:
_fuse_modules(model, named_module_dict, module_list)
if not inplace:
model = copy.deepcopy(model)
if all(isinstance(module_element, str) for module_element in modules_to_fuse):
# Handle case of modules_to_fuse being a list
_fuse_modules(model, modules_to_fuse, fuser_func)
else:
# Handle case of modules_to_fuse being a list of lists
for module_list in modules_to_fuse:
_fuse_modules(model, module_list, fuser_func)
return model