Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50748

Adds support for Linear + BatchNorm1d fusion to quantization. This is a redo of
dreiss's https://github.com/pytorch/pytorch/pull/37467; it was faster to
copy-paste it than to rebase and deal with conflicts.

Test Plan:
```
python test/test_quantization.py TestFusion.test_fusion_linear_bn_eval
```

Imported from OSS

Reviewed By: supriyar

Differential Revision: D25957432

fbshipit-source-id: 24e5b760f70186aa953ef65ab0182770e89495e4
import copy

import torch.nn as nn

from .fuser_method_mappings import get_fuser_method
# for backward compatibility
from .fuser_method_mappings import fuse_conv_bn  # noqa: F401
from .fuser_method_mappings import fuse_conv_bn_relu  # noqa: F401

from typing import List, Optional


# Generalization of getattr: looks up a submodule by its dotted path,
# e.g. 'submodule.conv'
def _get_module(model, submodule_key):
    tokens = submodule_key.split('.')
    cur_mod = model
    for s in tokens:
        cur_mod = getattr(cur_mod, s)
    return cur_mod


# Generalization of setattr: replaces the submodule at a dotted path
def _set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)

    setattr(cur_mod, tokens[-1], module)
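

# A minimal sketch of how these helpers behave; the model and submodule
# names are hypothetical, for illustration only:
#
#   >>> model = nn.Sequential()
#   >>> model.add_module('submodule', nn.Sequential())
#   >>> model.submodule.add_module('conv', nn.Conv2d(3, 8, 3))
#   >>> _get_module(model, 'submodule.conv') is model.submodule.conv
#   True
#   >>> _set_module(model, 'submodule.conv', nn.Identity())
#   >>> type(model.submodule.conv)
#   <class 'torch.nn.modules.linear.Identity'>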


def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
    r"""Returns a list of modules that fuses the operations specified
    in the input module list.

    Fuses only the following sequences of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu
    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity()
    """
    types = tuple(type(m) for m in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError("Cannot fuse modules: {}".format(types))
    new_mod: List[Optional[nn.Module]] = [None] * len(mod_list)
    fused = fuser_method(*mod_list)
    # NOTE: forward hooks not processed in the two following loops will be lost after the fusion
    # Move pre forward hooks of the base module to the resulting fused module.
    # Register the hooks first and clear afterwards: deleting entries from the
    # hook dict while iterating over it would raise a RuntimeError.
    for pre_hook_fn in mod_list[0]._forward_pre_hooks.values():
        fused.register_forward_pre_hook(pre_hook_fn)
    mod_list[0]._forward_pre_hooks.clear()
    # Move post forward hooks of the last module to the resulting fused module
    for hook_fn in mod_list[-1]._forward_hooks.values():
        fused.register_forward_hook(hook_fn)
    mod_list[-1]._forward_hooks.clear()
    new_mod[0] = fused

    # Fill the remaining slots with nn.Identity(), matching the training
    # flag of the original modules
    for i in range(1, len(mod_list)):
        identity = nn.Identity()
        identity.training = mod_list[0].training
        new_mod[i] = identity

    return new_mod
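

# A minimal sketch of what fuse_known_modules returns for an eval-mode
# conv + bn pair; the shapes are hypothetical, for illustration only:
#
#   >>> conv = nn.Conv2d(3, 8, 3).eval()
#   >>> bn = nn.BatchNorm2d(8).eval()
#   >>> fused_list = fuse_known_modules([conv, bn])
#   >>> type(fused_list[0])  # conv with the bn statistics folded in
#   <class 'torch.nn.modules.conv.Conv2d'>
#   >>> type(fused_list[1])  # bn's slot becomes a no-op
#   <class 'torch.nn.modules.linear.Identity'>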


def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    if fuse_custom_config_dict is None:
        fuse_custom_config_dict = {}
    additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
    mod_list = []
    for item in modules_to_fuse:
        mod_list.append(_get_module(model, item))

    # Fuse list of modules
    new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)

    # Replace original module list with fused module list
    for i, item in enumerate(modules_to_fuse):
        _set_module(model, item, new_mod_list[i])
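
# For modules_to_fuse = ['conv1', 'bn1'] (hypothetical names), the flow above
# pulls model.conv1 and model.bn1 via _get_module, fuses them via fuser_func,
# and writes back [fused, nn.Identity()] via _set_module, so model.conv1
# becomes the fused module and model.bn1 becomes a no-op.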


def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    r"""Fuses a list of modules into a single module

    Fuses only the following sequences of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu
    bn, relu
    All other sequences are left unchanged.
    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

    Args:
        model: Model containing the modules to be fused
        modules_to_fuse: list of lists of module names to fuse. Can also be a list
                         of strings if there is only a single list of modules to fuse.
        inplace: bool specifying if fusion happens in place on the model, by default
                 a new model is returned
        fuser_func: Function that takes in a list of modules and outputs a list of fused modules
                    of the same length. For example,
                    fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
                    Defaults to torch.quantization.fuse_known_modules
        fuse_custom_config_dict: custom configuration for fusion

    .. code-block:: python

       # Example of fuse_custom_config_dict
       fuse_custom_config_dict = {
           # Additional fuser_method mapping
           "additional_fuser_method_mapping": {
               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
           },
       }

    Returns:
        model with fused modules. A new copy is created if inplace=False.

    Examples::

        >>> m = myModel()
        >>> # m is a module containing the sub-modules below
        >>> modules_to_fuse = [['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
        >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
        >>> output = fused_m(input)

        >>> m = myModel()
        >>> # Alternately provide a single list of modules to fuse
        >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
        >>> fused_m = torch.quantization.fuse_modules(m, modules_to_fuse)
        >>> output = fused_m(input)

    """
    if not inplace:
        model = copy.deepcopy(model)

    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
        # Handle case of modules_to_fuse being a single list of module names
        _fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
    else:
        # Handle case of modules_to_fuse being a list of lists of module names
        for module_list in modules_to_fuse:
            _fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
    return model
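

# A minimal sketch of fusion driven by fuse_custom_config_dict. The model,
# module names, and the assumption that nn.utils.fusion.fuse_linear_bn_eval
# is available are for illustration only:
#
#   >>> def fuse_linear_bn(linear, bn):
#   ...     # fold eval-mode bn statistics into the linear weights
#   ...     return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
#   >>> m = myModel().eval()  # hypothetical model with m.fc1 and m.bn1
#   >>> fuse_custom_config_dict = {
#   ...     "additional_fuser_method_mapping": {
#   ...         (nn.Linear, nn.BatchNorm1d): fuse_linear_bn,
#   ...     },
#   ... }
#   >>> fused_m = fuse_modules(m, ['fc1', 'bn1'],
#   ...                        fuse_custom_config_dict=fuse_custom_config_dict)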