[reland][quant][graphmode][fx] Enable fuse handler for sequence of 3 ops (#70006)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70006

reland: fixing some mypy errors that were missed before

This PR enables the fuse handler for a sequence of three ops, and merges all fuse handlers into one.

TODO: we can also move this to the backend_config_dict folder

Test Plan:
regression fusion test
```
python test/test_quantization.py TestFuseFx
```

Imported from OSS

Reviewed By: supriyar

Differential Revision: D33144606

fbshipit-source-id: ca34f282018a0fb4d04c7e35119eaf2d64258e78
committed by Facebook GitHub Bot
parent fa582045fc
commit a73c6a45b6
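For context, here is a minimal sketch (not part of this commit) of the three-op sequence the merged handler targets: a Conv2d -> BatchNorm2d -> ReLU chain fused in one step through FX graph mode fusion, which drives the Fuser touched in the diff below. The module names and the printed result are illustrative assumptions.

```python
# Minimal sketch (not from this commit): a Conv2d -> BatchNorm2d -> ReLU chain,
# i.e. the "sequence of 3 ops" that the merged fuse handler is expected to fuse.
import torch
import torch.nn as nn
from torch.ao.quantization.quantize_fx import fuse_fx

class ConvBnReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

fused = fuse_fx(ConvBnReLU().eval())  # runs the FX fusion pass modified in this diff
print(fused.conv)                     # expected: a fused ConvReLU2d module with BN folded in
```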
```diff
@@ -166,8 +166,7 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
     def test_function_import_fx_fusion_patterns(self):
         function_list = [
             'FuseHandler',
-            'ConvOrLinearBNReLUFusion',
-            'ModuleReLUFusion'
+            'DefaultFuseHandler'
         ]
         self._test_function_import('fx.fusion_patterns', function_list)

```
```diff
@@ -148,7 +148,7 @@ DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] =
     (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu),
     (nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn),
     (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
-    (nn.BatchNorm3d, nn.Conv2d): reverse2(fuse_conv_bn),
+    (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
     (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
     (nn.ReLU, nn.Conv1d): reverse2(nni.ConvReLU1d),
     (nn.ReLU, nn.Conv2d): reverse2(nni.ConvReLU2d),
```
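The corrected entry maps a matched (BatchNorm3d, Conv3d) pair to the standard conv-bn fuser. As a hedged aside (not part of the diff), the same 3d conv-bn fusion is reachable through the eager-mode fuse_modules API, which is one way to sanity-check what the mapping produces; the module names and printed result below are assumptions.

```python
# Eager-mode illustration (not from this diff): Conv3d + BatchNorm3d fusion,
# the case covered by the (nn.BatchNorm3d, nn.Conv3d) entry above.
import torch
import torch.nn as nn
from torch.ao.quantization import fuse_modules

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(3, 8, 3)
        self.bn = nn.BatchNorm3d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

m = fuse_modules(M().eval(), [["conv", "bn"]])
print(m.conv)  # expected: Conv3d with BN folded into its weights; m.bn becomes Identity
```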
```diff
@@ -1,5 +1,5 @@
-from ..fusion_patterns import ModuleReLUFusion
+from ..fusion_patterns import DefaultFuseHandler

 # TODO: move ModuleReLUFusion here
 def get_fuse_handler_cls():
-    return ModuleReLUFusion
+    return DefaultFuseHandler
```
```diff
@@ -56,15 +56,20 @@ class Fuser:
         def load_arg(a):
             return map_arg(a, lambda node: env[node.name])

+        def get_root_node(node_pattern):
+            while not isinstance(node_pattern[-1], Node):
+                node_pattern = node_pattern[-1]
+            return node_pattern[-1]
+
         for node in input_graph.nodes:
             maybe_last_node, pattern, matched_node_pattern, obj = \
                 fusion_pairs.get(node.name, (None, None, None, None))
             if maybe_last_node is node:
                 assert obj is not None
                 # TODO: currently we hard code the root node, which only works for
-                # a tuple of two nodes, we want to make this more general to
-                # support more complex patterns
-                root_node = matched_node_pattern[-1]  # type: ignore[index]
+                # a sequence of ops and assume the root node is the last node,
+                # we want to make this more general to support more complex patterns
+                root_node = get_root_node(matched_node_pattern)  # type: ignore[index]
                 env[node.name] = obj.fuse(
                     self, load_arg, root_node, matched_node_pattern,  # type: ignore[arg-type]
                     fuse_custom_config_dict, fuser_method_mapping)
```
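The new get_root_node helper replaces the hard-coded matched_node_pattern[-1] lookup, so nested patterns of any depth resolve to the conv/linear node at the start of the sequence. A standalone illustration of the traversal, using a dummy stand-in for torch.fx.Node:

```python
# Standalone illustration of the get_root_node traversal above.
# "Node" here is a dummy stand-in for torch.fx.Node, only for demonstration.
class Node:
    def __init__(self, name):
        self.name = name

def get_root_node(node_pattern):
    # descend into the last element until it is an actual Node
    while not isinstance(node_pattern[-1], Node):
        node_pattern = node_pattern[-1]
    return node_pattern[-1]

relu, bn, conv = Node("relu"), Node("bn"), Node("conv")
assert get_root_node((relu, (bn, conv))).name == "conv"  # 3-op sequence: conv -> bn -> relu
assert get_root_node((relu, conv)).name == "conv"        # 2-op sequence: conv -> relu
```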
```diff
@@ -5,10 +5,9 @@ from .pattern_utils import (
 )
 from .utils import _parent_name
 from .quantization_types import QuantizerCls, NodePattern, Pattern
-from ..fuser_method_mappings import get_fuser_method
 from ..fuser_method_mappings import get_fuser_method_new
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union, List
 from .match_utils import MatchAllNode

 # ----------------------------
```
```diff
@@ -32,80 +31,6 @@ class FuseHandler(ABC):
              fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
         pass

-@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
-@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
-@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
-@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
-@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
-@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
-@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
-@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
-@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
-@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
-class ConvOrLinearBNReLUFusion(FuseHandler):
-    def __init__(self, quantizer: QuantizerCls, node: Node):
-        super().__init__(quantizer, node)
-        self.relu_node = None
-        self.bn_node = None
-        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
-           (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
-            self.relu_node = node
-            assert isinstance(node.args[0], Node)
-            node = node.args[0]
-        assert node.op == 'call_module'
-        if type(quantizer.modules[node.target]) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]:
-            self.bn_node = node
-            self.bn = quantizer.modules[self.bn_node.target]
-            assert isinstance(node.args[0], Node)
-            node = node.args[0]
-        assert node.op == 'call_module'
-        self.conv_or_linear_node = node
-        self.conv_or_linear = quantizer.modules[self.conv_or_linear_node.target]
-
-    def fuse(self,
-             quantizer: QuantizerCls,
-             load_arg: Callable,
-             root_node: Node,
-             matched_node_pattern: NodePattern,
-             fuse_custom_config_dict: Dict[str, Any],
-             fuser_method_mapping: Optional[Dict[Pattern, Union[torch.nn.Sequential, Callable]]]) -> Node:
-        additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
-        op_list = []
-        if self.relu_node is not None:
-            # since relu can be used multiple times, we'll need to create a relu module for each match
-            if self.relu_node.op == 'call_module':
-                relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
-            else:
-                # TODO: get inplace argument from functional
-                relu = torch.nn.ReLU()
-            op_list.append(relu)
-            relu.training = self.conv_or_linear.training
-            if self.bn_node is not None:
-                op_list.append(self.bn)
-            op_list.append(self.conv_or_linear)
-        else:
-            assert self.bn_node is not None
-            op_list.append(self.bn)
-            op_list.append(self.conv_or_linear)
-
-        # the modules are added in order of relu - bn - conv_or_linear
-        # so we need to correct it
-        op_list.reverse()
-        op_type_list = tuple(type(m) for m in op_list)
-        conv_or_linear_parent_name, conv_or_linear_name = _parent_name(self.conv_or_linear_node.target)
-        fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
-        if fuser_method is None:
-            raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
-        fused = fuser_method(*op_list)
-        setattr(quantizer.modules[conv_or_linear_parent_name], conv_or_linear_name, fused)
-
-        # TODO: do we need to make sure bn is only used once?
-        if self.bn_node is not None:
-            parent_name, name = _parent_name(self.bn_node.target)
-            setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
-        # relu may be used multiple times, so we don't set relu to identity
-        return quantizer.fused_graph.node_copy(self.conv_or_linear_node, load_arg)
-
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
```
```diff
@@ -118,14 +43,25 @@ class ConvOrLinearBNReLUFusion(FuseHandler):
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d))
 @register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d))
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d))
-class ModuleReLUFusion(FuseHandler):
+@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Conv1d))
+@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
+@register_fusion_pattern((torch.nn.BatchNorm3d, torch.nn.Conv3d))
+@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
+@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
+@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
+@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
+@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
+@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
+@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
+class DefaultFuseHandler(FuseHandler):
     def __init__(
             self,
             quantizer: QuantizerCls,
             node: Node):
         super().__init__(quantizer, node)

-    def fuse(self, quantizer: QuantizerCls,
+    def fuse(self,
+             quantizer: QuantizerCls,
              load_arg: Callable,
              root_node: Node,
              matched_node_pattern: NodePattern,
```
```diff
@@ -136,26 +72,45 @@ class ModuleReLUFusion(FuseHandler):
         root_module = quantizer.modules[root_node.target]
         assert len(additional_fuser_method_mapping) == 0, "Fusion implementation is "
         "undergoing changes, additoinal_fuser_method_mapping is not supported currently."
-        def get_module(n):
-            if n.op == "call_module":
-                return quantizer.modules[n.target]
-            elif n.op == "call_function" and n.target == torch.nn.functional.relu:
-                relu = torch.nn.ReLU()
-                relu.training = root_module.training
-                return relu
-            return MatchAllNode
+        def get_modules(pattern, modules):
+            """ Given a node pattern, extract the corresponding modules
+            e.g. input: (relu_node, (bn_node, conv_node))
+                 output: (relu_module, (bn_module, conv_module))
+            """
+            if isinstance(pattern, (tuple, list)):
+                n, *args = pattern
+                get_modules(n, modules)
+                arg_modules: List[torch.nn.Module] = []
+                for a in args:
+                    get_modules(a, arg_modules)
+                arg_modules = tuple(arg_modules) if len(arg_modules) > 1 else arg_modules[0]  # type: ignore[assignment]
+                modules.append(arg_modules)
+            else:
+                n = pattern
+                if n.op == "call_module":
+                    modules.append(quantizer.modules[n.target])
+                elif n.op == "call_function" and n.target == torch.nn.functional.relu:
+                    relu = torch.nn.ReLU()
+                    relu.training = root_module.training
+                    modules.append(relu)
+                else:
+                    modules.append(MatchAllNode)
+            return tuple(modules)

-        matched_modules = tuple(map(get_module, matched_node_pattern))
+        # since relu can be used multiple times, we'll need to create a relu module for each match
+        matched_modules = get_modules(matched_node_pattern, [])

-        def get_type(m):
+        def get_matched_types(m):
+            if isinstance(m, tuple):
+                return tuple(map(get_matched_types, m))
             return type(m)

-        matched_module_types = tuple(map(get_type, matched_modules))
+        matched_module_types = get_matched_types(matched_modules)
         module_parent_name, module_name = _parent_name(root_node.target)
         fuser_method = get_fuser_method_new(matched_module_types, fuser_method_mapping)
         # TODO: change the signature for fuser_method to take matched module patterns
         # as input
         fused_module = fuser_method(*matched_modules)
         # TODO: maybe add a pass to cleanup bn modules?
         setattr(quantizer.modules[module_parent_name], module_name, fused_module)
         return quantizer.fused_graph.node_copy(root_node, load_arg)
```
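Since get_modules mirrors the nesting of the matched node pattern, get_matched_types produces a nested tuple of types that can be used to look up the fuser method. A standalone sketch of that type extraction on a conv-bn-relu match; the concrete module arguments are illustrative:

```python
# Standalone sketch of the nested type extraction used by DefaultFuseHandler.fuse.
import torch.nn as nn

def get_matched_types(m):
    # mirror the tuple structure, replacing each module with its class
    if isinstance(m, tuple):
        return tuple(map(get_matched_types, m))
    return type(m)

# shape of matched_modules for a conv -> bn -> relu match: (relu, (bn, conv))
matched_modules = (nn.ReLU(), (nn.BatchNorm2d(8), nn.Conv2d(3, 8, 3)))
print(get_matched_types(matched_modules))
# expected: (ReLU, (BatchNorm2d, Conv2d)) as a nested tuple of classes
```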
```diff
@@ -1322,7 +1322,7 @@ def prepare(
     # 'linear': Linear(...),
     # 'linear.weight_fake_quant': PerChannelMinMaxObserver(...),
     # }
-    modules = dict(model.named_modules())
+    modules = dict(model.named_modules(remove_duplicate=False))

     # fill qconfig_map, a map from node name to qconfig, used in find_matches
     equalization_qconfig_map = generate_qconfig_map(model, modules, model.graph, equalization_qconfig_dict, node_name_to_scope)
```
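The prepare change switches to named_modules(remove_duplicate=False), presumably so that a module instance reachable under several attribute names keeps an entry for every name that can appear as a node target. A small illustration of the difference (not from this diff; the module names are made up):

```python
# Illustration (not from this diff) of named_modules() vs remove_duplicate=False
# when one module instance is registered under two names.
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.conv_alias = self.conv  # same instance, second name

m = M()
print([name for name, _ in m.named_modules()])
# expected: ['', 'conv']  (the alias is deduplicated)
print([name for name, _ in m.named_modules(remove_duplicate=False)])
# expected: ['', 'conv', 'conv_alias']
```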
```diff
@@ -8,6 +8,5 @@ here.
 """
 from torch.ao.quantization.fx.fusion_patterns import (
     FuseHandler,
-    ConvOrLinearBNReLUFusion,
-    ModuleReLUFusion
+    DefaultFuseHandler,
 )
```