mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-22 14:15:01 +08:00
[quant][eagermode] Add additional_fuser_method_mapping to config (#46355)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46355 Test Plan: Imported from OSS Reviewed By: vkuzo Differential Revision: D24319562 fbshipit-source-id: be9800723c0b3e36f26e73c25c0c6ae1d4344f45
This commit is contained in:
committed by
Facebook GitHub Bot
parent
13b7855f33
commit
37dbc6117f
@ -28,7 +28,7 @@ def _set_module(model, submodule_key, module):
|
||||
|
||||
setattr(cur_mod, tokens[-1], module)
|
||||
|
||||
def fuse_known_modules(mod_list):
|
||||
def fuse_known_modules(mod_list, additional_fuser_method_mapping=None):
|
||||
r"""Returns a list of modules that fuses the operations specified
|
||||
in the input module list.
|
||||
|
||||
@ -41,7 +41,7 @@ def fuse_known_modules(mod_list):
|
||||
the fused operation. The rest of the elements are set to nn.Identity()
|
||||
"""
|
||||
types = tuple(type(m) for m in mod_list)
|
||||
fuser_method = get_fuser_method(types)
|
||||
fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
|
||||
if fuser_method is None:
|
||||
raise NotImplementedError("Cannot fuse modules: {}".format(types))
|
||||
new_mod : List[Optional[nn.Module]] = [None] * len(mod_list)
|
||||
@ -64,20 +64,22 @@ def fuse_known_modules(mod_list):
|
||||
|
||||
return new_mod
|
||||
|
||||
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules):
|
||||
|
||||
def _fuse_modules(model, modules_to_fuse, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
|
||||
if fuse_custom_config_dict is None:
|
||||
fuse_custom_config_dict = {}
|
||||
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
|
||||
mod_list = []
|
||||
for item in modules_to_fuse:
|
||||
mod_list.append(_get_module(model, item))
|
||||
|
||||
# Fuse list of modules
|
||||
new_mod_list = fuser_func(mod_list)
|
||||
new_mod_list = fuser_func(mod_list, additional_fuser_method_mapping)
|
||||
|
||||
# Replace original module list with fused module list
|
||||
for i, item in enumerate(modules_to_fuse):
|
||||
_set_module(model, item, new_mod_list[i])
|
||||
|
||||
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules):
|
||||
def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
|
||||
r"""Fuses a list of modules into a single module
|
||||
|
||||
Fuses only the following sequence of modules:
|
||||
@ -101,6 +103,18 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo
|
||||
of the same length. For example,
|
||||
fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
|
||||
Defaults to torch.quantization.fuse_known_modules
|
||||
`fuse_custom_config_dict`: custom configuration for fusion
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# Example of fuse_custom_config_dict
|
||||
fuse_custom_config_dict = {
|
||||
# Additional fuser_method mapping
|
||||
"additional_fuser_method_mapping": {
|
||||
(torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
|
||||
},
|
||||
}
|
||||
|
||||
Returns:
|
||||
model with fused modules. A new copy is created if inplace=True.
|
||||
|
||||
@ -124,9 +138,9 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo
|
||||
|
||||
if all(isinstance(module_element, str) for module_element in modules_to_fuse):
|
||||
# Handle case of modules_to_fuse being a list
|
||||
_fuse_modules(model, modules_to_fuse, fuser_func)
|
||||
_fuse_modules(model, modules_to_fuse, fuser_func, fuse_custom_config_dict)
|
||||
else:
|
||||
# Handle case of modules_to_fuse being a list of lists
|
||||
for module_list in modules_to_fuse:
|
||||
_fuse_modules(model, module_list, fuser_func)
|
||||
_fuse_modules(model, module_list, fuser_func, fuse_custom_config_dict)
|
||||
return model
|
||||
|
Reference in New Issue
Block a user