[reland][quant][graphmode][fx] Enable fuse handler for sequence of 3 ops (#70006)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70006

reland: fixing some mypy errors that were missed before

This PR enables the fuse handler for a sequence of three ops and merges all fuse handlers into one.

TODO: we can also move this to the backend_config_dict folder
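
For context, a minimal sketch (not taken from this PR; the module structure is illustrative) of exercising the three-op Conv -> BatchNorm -> ReLU path through the public FX fuse API:
```
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import fuse_fx

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)
        self.bn = nn.BatchNorm2d(3)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = M().eval()      # fuse_fx expects the model in eval mode
fused = fuse_fx(m)  # conv + bn + relu should fold into a single fused conv-relu module
print(fused.code)
```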

Test Plan:
regression fusion test
```
python test/test_quantization.py TestFuseFx
```

Imported from OSS

Reviewed By: supriyar

Differential Revision: D33144606

fbshipit-source-id: ca34f282018a0fb4d04c7e35119eaf2d64258e78
Jerry Zhang authored 2021-12-16 15:00:48 -08:00, committed by Facebook GitHub Bot
parent fa582045fc
commit a73c6a45b6
7 changed files with 58 additions and 100 deletions

View File

@@ -166,8 +166,7 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
def test_function_import_fx_fusion_patterns(self):
function_list = [
'FuseHandler',
'ConvOrLinearBNReLUFusion',
'ModuleReLUFusion'
'DefaultFuseHandler'
]
self._test_function_import('fx.fusion_patterns', function_list)

View File

@@ -148,7 +148,7 @@ DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] =
(nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu),
(nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn),
(nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
(nn.BatchNorm3d, nn.Conv2d): reverse2(fuse_conv_bn),
(nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
(nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
(nn.ReLU, nn.Conv1d): reverse2(nni.ConvReLU1d),
(nn.ReLU, nn.Conv2d): reverse2(nni.ConvReLU2d),
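
The patterns above are keyed with the outermost op first (e.g. (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))), while the fuser methods take their arguments conv-first, so reverse2/reverse3 adapt the argument order. A hedged sketch of what such adapters might look like (the real helpers live in fuser_method_mappings.py and may differ in detail):
```
# Hypothetical sketch of the argument-reordering adapters used in the mapping
# above; not copied from the PR, shown only to illustrate the call order.
def reverse2(f):
    # pattern (bn, conv) -> call f(conv, bn)
    return lambda x, y: f(y, x)

def reverse3(f):
    # pattern (relu, (bn, conv)) -> call f(conv, bn, relu)
    def reversed_fn(x, w):
        y, z = w
        return f(z, y, x)
    return reversed_fn
```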

View File

@@ -1,5 +1,5 @@
from ..fusion_patterns import ModuleReLUFusion
from ..fusion_patterns import DefaultFuseHandler
# TODO: move ModuleReLUFusion here
def get_fuse_handler_cls():
return ModuleReLUFusion
return DefaultFuseHandler

View File

@@ -56,15 +56,20 @@ class Fuser:
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
def get_root_node(node_pattern):
while not isinstance(node_pattern[-1], Node):
node_pattern = node_pattern[-1]
return node_pattern[-1]
for node in input_graph.nodes:
maybe_last_node, pattern, matched_node_pattern, obj = \
fusion_pairs.get(node.name, (None, None, None, None))
if maybe_last_node is node:
assert obj is not None
# TODO: currently we hard code the root node, which only works for
# a tuple of two nodes, we want to make this more general to
# support more complex patterns
root_node = matched_node_pattern[-1] # type: ignore[index]
# a sequence of ops and assume the root node is the last node,
# we want to make this more general to support more complex patterns
root_node = get_root_node(matched_node_pattern) # type: ignore[index]
env[node.name] = obj.fuse(
self, load_arg, root_node, matched_node_pattern, # type: ignore[arg-type]
fuse_custom_config_dict, fuser_method_mapping)
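
For illustration, get_root_node walks into the nested matched pattern (outermost op first) and returns the innermost last Node, which for conv-bn-relu is the conv node. A standalone sketch with a stand-in node type (names here are made up):
```
# Node is a stand-in for torch.fx.Node; the pattern mirrors a matched
# (relu, (bn, conv)) sequence as produced by the matcher.
class Node:
    def __init__(self, name):
        self.name = name

def get_root_node(node_pattern):
    while not isinstance(node_pattern[-1], Node):
        node_pattern = node_pattern[-1]
    return node_pattern[-1]

relu, bn, conv = Node("relu"), Node("bn"), Node("conv")
assert get_root_node((relu, (bn, conv))).name == "conv"  # conv is the root
```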

View File

@@ -5,10 +5,9 @@ from .pattern_utils import (
)
from .utils import _parent_name
from .quantization_types import QuantizerCls, NodePattern, Pattern
from ..fuser_method_mappings import get_fuser_method
from ..fuser_method_mappings import get_fuser_method_new
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union, List
from .match_utils import MatchAllNode
# ----------------------------
@@ -32,80 +31,6 @@ class FuseHandler(ABC):
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
pass
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
class ConvOrLinearBNReLUFusion(FuseHandler):
def __init__(self, quantizer: QuantizerCls, node: Node):
super().__init__(quantizer, node)
self.relu_node = None
self.bn_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
self.relu_node = node
assert isinstance(node.args[0], Node)
node = node.args[0]
assert node.op == 'call_module'
if type(quantizer.modules[node.target]) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]:
self.bn_node = node
self.bn = quantizer.modules[self.bn_node.target]
assert isinstance(node.args[0], Node)
node = node.args[0]
assert node.op == 'call_module'
self.conv_or_linear_node = node
self.conv_or_linear = quantizer.modules[self.conv_or_linear_node.target]
def fuse(self,
quantizer: QuantizerCls,
load_arg: Callable,
root_node: Node,
matched_node_pattern: NodePattern,
fuse_custom_config_dict: Dict[str, Any],
fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
op_list = []
if self.relu_node is not None:
# since relu can be used multiple times, we'll need to create a relu module for each match
if self.relu_node.op == 'call_module':
relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
else:
# TODO: get inplace argument from functional
relu = torch.nn.ReLU()
op_list.append(relu)
relu.training = self.conv_or_linear.training
if self.bn_node is not None:
op_list.append(self.bn)
op_list.append(self.conv_or_linear)
else:
assert self.bn_node is not None
op_list.append(self.bn)
op_list.append(self.conv_or_linear)
# the modules are added in order of relu - bn - conv_or_linear
# so we need to correct it
op_list.reverse()
op_type_list = tuple(type(m) for m in op_list)
conv_or_linear_parent_name, conv_or_linear_name = _parent_name(self.conv_or_linear_node.target)
fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
if fuser_method is None:
raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
fused = fuser_method(*op_list)
setattr(quantizer.modules[conv_or_linear_parent_name], conv_or_linear_name, fused)
# TODO: do we need to make sure bn is only used once?
if self.bn_node is not None:
parent_name, name = _parent_name(self.bn_node.target)
setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
# relu may be used multiple times, so we don't set relu to identity
return quantizer.fused_graph.node_copy(self.conv_or_linear_node, load_arg)
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@@ -118,14 +43,25 @@ class ConvOrLinearBNReLUFusion(FuseHandler):
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d))
class ModuleReLUFusion(FuseHandler):
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
class DefaultFuseHandler(FuseHandler):
def __init__(
self,
quantizer: QuantizerCls,
node: Node):
super().__init__(quantizer, node)
def fuse(self, quantizer: QuantizerCls,
def fuse(self,
quantizer: QuantizerCls,
load_arg: Callable,
root_node: Node,
matched_node_pattern: NodePattern,
@@ -136,26 +72,45 @@ class ModuleReLUFusion(FuseHandler):
root_module = quantizer.modules[root_node.target]
assert len(additional_fuser_method_mapping) == 0, "Fusion implementation is "
"undergoing changes, additoinal_fuser_method_mapping is not supported currently."
def get_module(n):
if n.op == "call_module":
return quantizer.modules[n.target]
elif n.op == "call_function" and n.target == torch.nn.functional.relu:
relu = torch.nn.ReLU()
relu.training = root_module.training
return relu
return MatchAllNode
def get_modules(pattern, modules):
""" Given a node pattern, extract the corresponding modules
e.g. input: (relu_node, (bn_node, conv_node))
output: (relu_module, (bn_module, conv_module))
"""
if isinstance(pattern, (tuple, list)):
n, *args = pattern
get_modules(n, modules)
arg_modules: List[torch.nn.Module] = []
for a in args:
get_modules(a, arg_modules)
arg_modules = tuple(arg_modules) if len(arg_modules) > 1 else arg_modules[0] # type: ignore[assignment]
modules.append(arg_modules)
else:
n = pattern
if n.op == "call_module":
modules.append(quantizer.modules[n.target])
elif n.op == "call_function" and n.target == torch.nn.functional.relu:
relu = torch.nn.ReLU()
relu.training = root_module.training
modules.append(relu)
else:
modules.append(MatchAllNode)
return tuple(modules)
matched_modules = tuple(map(get_module, matched_node_pattern))
# since relu can be used multiple times, we'll need to create a relu module for each match
matched_modules = get_modules(matched_node_pattern, [])
def get_type(m):
def get_matched_types(m):
if isinstance(m, tuple):
return tuple(map(get_matched_types, m))
return type(m)
matched_module_types = tuple(map(get_type, matched_modules))
matched_module_types = get_matched_types(matched_modules)
module_parent_name, module_name = _parent_name(root_node.target)
fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
# TODO: change the signature for fuser_method to take matched module patterns
# as input
fused_module = fuser_method(*matched_modules)
# TODO: maybe add a pass to cleanup bn modules?
setattr(quantizer.modules[module_parent_name], module_name, fused_module)
return quantizer.fused_graph.node_copy(root_node, load_arg)
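
To make the nested type lookup concrete, here is a standalone sketch (not part of the diff) of how a matched module pattern maps to the type key that get_fuser_method_new receives:
```
import torch.nn as nn

def get_matched_types(m):
    # Mirrors the helper above: recurse into tuples, take type() of the leaves.
    if isinstance(m, tuple):
        return tuple(map(get_matched_types, m))
    return type(m)

matched_modules = (nn.ReLU(), (nn.BatchNorm2d(3), nn.Conv2d(3, 3, 3)))
matched_module_types = get_matched_types(matched_modules)
# -> (ReLU, (BatchNorm2d, Conv2d)); this nested type tuple is the key that is
# looked up in the fuser method mapping to pick the right fuser method.
print(matched_module_types)
```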

View File

@@ -1322,7 +1322,7 @@ def prepare(
# 'linear': Linear(...),
# 'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
# }
modules = dict(model.named_modules())
modules = dict(model.named_modules(remove_duplicate=False))
# fill qconfig_map, a map from node name to qconfig, used in find_matches
equalization_qconfig_map = generate_qconfig_map(model, modules, model.graph, equalization_qconfig_dict, node_name_to_scope)
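
The remove_duplicate=False change matters when one module instance is reachable under more than one attribute name; a minimal sketch of the difference (illustrative module only):
```
import torch.nn as nn

class Shared(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.alias = self.linear  # same instance registered under a second name

m = Shared()
print(list(dict(m.named_modules()).keys()))
# ['', 'linear'] - the alias is deduplicated by default
print(list(dict(m.named_modules(remove_duplicate=False)).keys()))
# ['', 'linear', 'alias'] - both attribute paths are kept
```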

View File

@@ -8,6 +8,5 @@ here.
"""
from torch.ao.quantization.fx.fusion_patterns import (
FuseHandler,
ConvOrLinearBNReLUFusion,
ModuleReLUFusion
DefaultFuseHandler,
)