mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Fuse module enhancements (#26457)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/26457 Enhancement to fuse module to support sequentials; the fuse list can now be just like the state dict. Also add support for Conv-ReLU and Linear-ReLU fusion. Also support in-place and out-of-place fusion of models. ghstack-source-id: 91076386 Test Plan: buck test caffe2/test:quantization -- 'test_fusion_sequential_model_train \(test_quantization\.FusionTest\)' --print-passing-details buck test caffe2/test:quantization -- 'test_fusion_sequential_model_eval \(test_quantization\.FusionTest\)' --print-passing-details Differential Revision: D17466382 fbshipit-source-id: 0a548f8f4c366f3ecc59db693bac725ccd62328e
This commit is contained in:
committed by
Facebook Github Bot
parent
6a4ca9abec
commit
dddae3f854
@ -1,6 +1,7 @@
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import torch
|
||||
import copy
|
||||
|
||||
import torch.nn._intrinsic.modules.fused as torch_fused
|
||||
|
||||
@ -52,44 +53,78 @@ def fuse_conv_bn_relu(conv, bn, relu):
|
||||
return torch_fused.ConvReLU2d(
|
||||
torch.nn.utils.fusion.fuse_conv_bn_eval(conv, bn), relu)
|
||||
|
||||
# Generalization of getattr
|
||||
def _get_module(model, submodule_key):
|
||||
tokens = submodule_key.split('.')
|
||||
cur_mod = model
|
||||
for s in tokens:
|
||||
cur_mod = getattr(cur_mod, s)
|
||||
return cur_mod
|
||||
|
||||
def _fuse_modules(model, named_module_dict, modules_to_fuse, fuser_func=None):
|
||||
assert(len(modules_to_fuse) == 2 or len(modules_to_fuse) == 3),\
|
||||
"Can fuse only 2 or 3 modules."
|
||||
# Generalization of setattr
|
||||
def _set_module(model, submodule_key, module):
|
||||
tokens = submodule_key.split('.')
|
||||
sub_tokens = tokens[:-1]
|
||||
cur_mod = model
|
||||
for s in sub_tokens:
|
||||
cur_mod = getattr(cur_mod, s)
|
||||
|
||||
OP_LIST_TO_FUSER_FUNC = {
|
||||
setattr(cur_mod, tokens[-1], module)
|
||||
|
||||
def fuse_known_modules(mod_list):
    r"""Returns a list of modules that fuses the operations specified
    in the input module list.

    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, relu
    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity()

    Arguments:
        mod_list: list of modules, in the order they are applied.

    Returns:
        A list of the same length as ``mod_list``: the fused module followed
        by ``nn.Identity()`` placeholders.

    Raises:
        NotImplementedError: if the tuple of module types is not one of the
            sequences listed above.
    """
    OP_LIST_TO_FUSER_METHOD = {
        (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn,
        (torch.nn.Conv2d, torch.nn.BatchNorm2d, torch.nn.ReLU): fuse_conv_bn_relu,
        (torch.nn.Conv2d, torch.nn.ReLU): torch.nn._intrinsic.ConvReLU2d,
        (torch.nn.Linear, torch.nn.ReLU): torch.nn._intrinsic.LinearReLU,
    }

    # The lookup key uses type(m), so only these exact classes match;
    # subclasses are reported as unsupported.
    mod_types = tuple(type(m) for m in mod_list)
    fuser_method = OP_LIST_TO_FUSER_METHOD.get(mod_types, None)
    if fuser_method is None:
        raise NotImplementedError("Cannot fuse modules: {}".format(mod_types))
    new_mod = [None] * len(mod_list)
    new_mod[0] = fuser_method(*mod_list)

    for i in range(1, len(mod_list)):
        new_mod[i] = torch.nn.Identity()
        # Keep the placeholders in the same train/eval mode as the first
        # input module.
        new_mod[i].training = mod_list[0].training

    return new_mod
|
||||
|
||||
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules):
    """Fuse one group of named submodules of ``model``, modifying it in place.

    Arguments:
        model: Model containing the modules to be fused.
        modules_to_fuse: list of dotted module names forming one group.
        fuser_func: Function that takes the list of looked-up modules and
            returns a list of replacement modules, one per input name.
    """
    mod_list = [_get_module(model, item) for item in modules_to_fuse]

    # Fuse list of modules
    new_mod_list = fuser_func(mod_list)

    # Replace original module list with fused module list.
    # Indexing (rather than zip) keeps the IndexError if fuser_func returns
    # fewer modules than names were given.
    for i, item in enumerate(modules_to_fuse):
        _set_module(model, item, new_mod_list[i])
|
||||
|
||||
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules):
|
||||
r"""Fuses a list of modules into a single module
|
||||
|
||||
Fuses only the following sequence of modules:
|
||||
conv, bn
|
||||
conv, bn, relu
|
||||
conv, relu
|
||||
linear, relu
|
||||
All other sequences are left unchanged.
|
||||
For these sequences, replaces the first item in the list
|
||||
with the fused module, replacing the rest of the modules
|
||||
@ -97,20 +132,40 @@ def fuse_modules(model, modules_to_fuse):
|
||||
|
||||
Arguments:
|
||||
model: Model containing the modules to be fused
|
||||
modules_to_fuse: list of list of module names to fuse.
|
||||
|
||||
modules_to_fuse: list of list of module names to fuse. Can also be a list
|
||||
of strings if there is only a single list of modules to fuse.
|
||||
inplace: bool specifying if fusion happens in place on the model, by default
|
||||
a new model is returned
|
||||
fuser_func: Function that takes in a list of modules and outputs a list of fused modules
|
||||
of the same length. For example,
|
||||
fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
|
||||
Defaults to torch.quantization.fuse_known_modules
|
||||
Returns:
|
||||
Modifies the model in place.
|
||||
model with fused modules. A new copy is created if inplace=False.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> m = myModel()
|
||||
>>> # m is a module containing the sub-modules below
|
||||
>>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
|
||||
>>> nn.quantization.fuse_module.fuse_module(m, modules_to_fuse)
|
||||
>>> output = m(input)
|
||||
>>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
|
||||
>>> output = fused_m(input)
|
||||
|
||||
>>> m = myModel()
|
||||
>>> # Alternately provide a single list of modules to fuse
|
||||
>>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
|
||||
>>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
|
||||
>>> output = fused_m(input)
|
||||
|
||||
"""
|
||||
named_module_dict = {name: mod for name, mod in model.named_modules()}
|
||||
for module_list in modules_to_fuse:
|
||||
_fuse_modules(model, named_module_dict, module_list)
|
||||
if not inplace:
|
||||
model = copy.deepcopy(model)
|
||||
|
||||
if all(isinstance(module_element, str) for module_element in modules_to_fuse):
|
||||
# Handle case of modules_to_fuse being a list
|
||||
_fuse_modules(model, modules_to_fuse, fuser_func)
|
||||
else:
|
||||
# Handle case of modules_to_fuse being a list of lists
|
||||
for module_list in modules_to_fuse:
|
||||
_fuse_modules(model, module_list, fuser_func)
|
||||
return model
|
||||
|
Reference in New Issue
Block a user