[quant][eagermode] Add additional_fuser_method_mapping to config (#46355)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46355

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D24319562

fbshipit-source-id: be9800723c0b3e36f26e73c25c0c6ae1d4344f45
This commit is contained in:
Jerry Zhang
2020-10-24 02:15:45 -07:00
committed by Facebook GitHub Bot
parent 13b7855f33
commit 37dbc6117f
3 changed files with 26 additions and 10 deletions

View File

@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module):
setattr(cur_mod, tokens[-1], module)
def fuse_known_modules(mod_list):
def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
r"""Returns a list of modules that fuses the operations specified
in the input module list.
@ -41,7 +41,7 @@ def fuse_known_modules(mod_list):
the fused operation. The rest of the elements are set to nn.Identity()
"""
types = tuple(type(m) for m in mod_list)
fuser_method = get_fuser_method(types)
fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(types))
new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
@ -64,20 +64,22 @@ def fuse_known_modules(mod_list):
return new_mod
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules):
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
mod_list = []
for item in modules_to_fuse:
mod_list.append(_get_module(model, item))
# Fuse list of modules
new_mod_list = fuser_func(mod_list)
new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
# Replace original module list with fused module list
for i, item in enumerate(modules_to_fuse):
_set_module(model, item, new_mod_list[i])
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules):
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
r"""Fuses a list of modules into a single module
Fuses only the following sequence of modules:
@ -101,6 +103,18 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo
of the same length. For example,
fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
Defaults to torch.quantization.fuse_known_modules
`fuse_custom_config_dict`: custom configuration for fusion
.. code-block:: python
# Example of fuse_custom_config_dict
fuse_custom_config_dict = {
# Additional fuser_method mapping
"additional_fuser_method_mapping": {
(torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
},
}
Returns:
model with fused modules. A new copy is created if inplace=True.
@ -124,9 +138,9 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo
if all(isinstance(module_element, str) for module_element in modules_to_fuse):
# Handle case of modules_to_fuse being a list
_fuse_modules(model, modules_to_fuse, fuser_func)
_fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
else:
# Handle case of modules_to_fuse being a list of lists
for module_list in modules_to_fuse:
_fuse_modules(model, module_list, fuser_func)
_fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
return model