fx quant: enable linear-bn1d fusion for PTQ (#66484)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66484

https://github.com/pytorch/pytorch/pull/50748 added linear-bn1d fusion in eager mode, for PTQ only. This PR enables the same fusion in FX graph mode. We reuse the existing conv-bn-relu fusion handler, renaming `conv` to `conv_or_linear` for readability. The QAT version is left for a future PR, for both eager and FX graph mode.

Test Plan:
```
python test/test_quantization.py TestFuseFx.test_fuse_linear_bn_eval
```

Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D31575392

fbshipit-source-id: f69d80ef37c98cbc070099170e335e250bcdf913
Committed by: Facebook GitHub Bot
Parent: 9d287d0b63
Commit: d549c8de78
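For context on the two entry points mentioned in the summary, here is a minimal sketch of how a user would request linear-bn1d fusion for PTQ. The model, module names, and layer sizes are illustrative only; the eager-mode call assumes the `torch.quantization.fuse_modules` API referenced via #50748, and the FX graph mode call uses the `fuse_fx` API exercised by the new test in this diff.

```
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, x):
        return self.bn(self.linear(x))

# Eager mode (per #50748): fusion is requested explicitly by module name, eval mode only.
m_eager = torch.quantization.fuse_modules(M().eval(), [['linear', 'bn']])

# FX graph mode (this PR): fuse_fx traces the model and matches the
# (BatchNorm1d, Linear) pattern automatically; PTQ / eval mode only.
from torch.ao.quantization.quantize_fx import fuse_fx
m_fx = fuse_fx(M().eval())
```

In both cases the result should contain a single fused `nn.Linear` and no `BatchNorm1d`, which is what the new `test_fuse_linear_bn_eval` below asserts for the FX path.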
@@ -164,7 +164,7 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
     def test_function_import_fx_fusion_patterns(self):
         function_list = [
             'FuseHandler',
-            'ConvBNReLUFusion',
+            'ConvOrLinearBNReLUFusion',
             'ModuleReLUFusion'
         ]
         self._test_function_import('fx.fusion_patterns', function_list)
@@ -250,6 +250,34 @@ class TestFuseFx(QuantizationTestCase):
             expected_node_list=expected_nodes,
             expected_node_occurrence=expected_occurrence)
 
+    def test_fuse_linear_bn_eval(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(1, 1)
+                self.bn1d = nn.BatchNorm1d(1)
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.bn1d(x)
+                return x
+
+        # test eval mode
+        m = M().eval()
+        from torch.ao.quantization.quantize_fx import fuse_fx
+        # fuse_fx is a top level api and only supports eval mode
+        m = fuse_fx(m)
+        expected_nodes = [
+            ns.call_module(nn.Linear),
+        ]
+        expected_occurrence = {
+            ns.call_module(nn.BatchNorm1d): 0,
+        }
+        self.checkGraphModuleNodes(
+            m,
+            expected_node_list=expected_nodes,
+            expected_node_occurrence=expected_occurrence)
+
     def test_fuse_module_relu(self):
         class M(torch.nn.Module):
             def __init__(self):
@@ -40,7 +40,8 @@ class FuseHandler(ABC):
 @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
 @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
 @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
-class ConvBNReLUFusion(FuseHandler):
+@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
+class ConvOrLinearBNReLUFusion(FuseHandler):
     def __init__(self, quantizer: QuantizerCls, node: Node):
         super().__init__(quantizer, node)
         self.relu_node = None
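A note on reading the new registration (my understanding of the FX fusion pattern convention, not something stated in this diff): pattern tuples are written from the output backwards, with the outermost op first, so `(torch.nn.BatchNorm1d, torch.nn.Linear)` matches `bn(linear(x))`, i.e. a `Linear` whose output feeds a `BatchNorm1d`, just as `(torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d))` matches `relu(bn(conv(x)))`.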
@@ -57,8 +58,8 @@ class ConvBNReLUFusion(FuseHandler):
             assert isinstance(node.args[0], Node)
             node = node.args[0]
         assert node.op == 'call_module'
-        self.conv_node = node
-        self.conv = quantizer.modules[self.conv_node.target]
+        self.conv_or_linear_node = node
+        self.conv_or_linear = quantizer.modules[self.conv_or_linear_node.target]
 
     def fuse(self, quantizer: QuantizerCls, load_arg: Callable,
              fuse_custom_config_dict: Dict[str, Any] = None) -> Node:
@@ -74,32 +75,32 @@ class ConvBNReLUFusion(FuseHandler):
                 # TODO: get inplace argument from functional
                 relu = torch.nn.ReLU()
             op_list.append(relu)
-            relu.training = self.conv.training
+            relu.training = self.conv_or_linear.training
             if self.bn_node is not None:
                 op_list.append(self.bn)
-            op_list.append(self.conv)
+            op_list.append(self.conv_or_linear)
         else:
             assert self.bn_node is not None
             op_list.append(self.bn)
-            op_list.append(self.conv)
+            op_list.append(self.conv_or_linear)
 
-        # the modules are added in order of relu - bn - conv
+        # the modules are added in order of relu - bn - conv_or_linear
        # so we need to correct it
         op_list.reverse()
         op_type_list = tuple(type(m) for m in op_list)
-        conv_parent_name, conv_name = _parent_name(self.conv_node.target)
+        conv_or_linear_parent_name, conv_or_linear_name = _parent_name(self.conv_or_linear_node.target)
         fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
         if fuser_method is None:
             raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
         fused = fuser_method(*op_list)
-        setattr(quantizer.modules[conv_parent_name], conv_name, fused)
+        setattr(quantizer.modules[conv_or_linear_parent_name], conv_or_linear_name, fused)
 
         # TODO: do we need to make sure bn is only used once?
         if self.bn_node is not None:
             parent_name, name = _parent_name(self.bn_node.target)
             setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
         # relu may be used multiple times, so we don't set relu to identity
-        return quantizer.fused_graph.node_copy(self.conv_node, load_arg)
+        return quantizer.fused_graph.node_copy(self.conv_or_linear_node, load_arg)
 
 @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
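The arithmetic itself is delegated to whatever fuser method `get_fuser_method` returns for the collected op types; the handler above only gathers the modules and swaps in the fused result. As a rough illustration of what an eval-mode fuser method for Linear + BatchNorm1d computes, here is a sketch of the standard BN-folding algebra. It is not the code this PR dispatches to, and `fold_bn_into_linear` is a hypothetical helper name.

```
import torch
import torch.nn as nn

def fold_bn_into_linear(linear: nn.Linear, bn: nn.BatchNorm1d) -> nn.Linear:
    # bn(linear(x)) = scale * (W x + b - running_mean) + beta, where
    # scale = gamma / sqrt(running_var + eps). The fused layer therefore has
    #   W' = scale[:, None] * W    and    b' = scale * (b - running_mean) + beta
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Linear(linear.in_features, linear.out_features)
    b = linear.bias if linear.bias is not None else torch.zeros(linear.out_features)
    with torch.no_grad():
        fused.weight.copy_(linear.weight * scale.unsqueeze(1))
        fused.bias.copy_((b - bn.running_mean) * scale + bn.bias)
    return fused
```

Folding like this relies on frozen running statistics, which is presumably one reason the summary defers the QAT variant to a future PR.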
@@ -53,7 +53,7 @@ def get_default_output_activation_post_process_map() -> Dict[Pattern, ObserverBa
 
 # Example use of register pattern function:
 # @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
-# class ConvBNReLUFusion():
+# class ConvOrLinearBNReLUFusion():
 #     def __init__(...):
 #         ...
 #
@@ -8,6 +8,6 @@ here.
 """
 from torch.ao.quantization.fx.fusion_patterns import (
     FuseHandler,
-    ConvBNReLUFusion,
+    ConvOrLinearBNReLUFusion,
     ModuleReLUFusion
 )