fx quant: enable linear-bn1d fusion for PTQ (#66484)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66484

https://github.com/pytorch/pytorch/pull/50748 added linear-bn1d fusion
in Eager mode, for PTQ only. This PR enables the same fusion in FX graph mode.

We reuse the existing conv-bn-relu fusion handler, renaming `conv` to
`conv_or_linear` for readability.

The QAT version of this fusion, in both Eager and FX graph mode, is left for a future PR.
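For context, eval-mode linear-bn1d fusion folds the BatchNorm1d running statistics and affine parameters into the preceding Linear, so the fused module is a single Linear that computes the same output. The sketch below is only a reference for what the folding computes, not the fuser implementation itself; the helper name `fold_linear_bn_eval` is made up for the example.

```python
import torch
import torch.nn as nn

def fold_linear_bn_eval(linear: nn.Linear, bn: nn.BatchNorm1d) -> nn.Linear:
    # Illustrative only; the real fuser lives in torch.ao.quantization.
    # y = gamma * (Wx + b - mean) / sqrt(var + eps) + beta
    #   = (scale * W) x + (b - mean) * scale + beta,  scale = gamma / sqrt(var + eps)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Linear(linear.in_features, linear.out_features, bias=True)
    bias = linear.bias if linear.bias is not None else torch.zeros(linear.out_features)
    with torch.no_grad():
        fused.weight.copy_(linear.weight * scale.unsqueeze(1))
        fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)
    return fused

# sanity check against Linear -> BatchNorm1d in eval mode
linear, bn = nn.Linear(4, 8), nn.BatchNorm1d(8).eval()
with torch.no_grad():
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
x = torch.randn(2, 4)
assert torch.allclose(fold_linear_bn_eval(linear, bn)(x), bn(linear(x)), atol=1e-5)
```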

Test Plan:
```
python test/test_quantization.py TestFuseFx.test_fuse_linear_bn_eval
```

Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D31575392

fbshipit-source-id: f69d80ef37c98cbc070099170e335e250bcdf913
Vasiliy Kuznetsov authored on 2021-10-18 10:04:47 -07:00, committed by Facebook GitHub Bot
parent 9d287d0b63
commit d549c8de78
5 changed files with 42 additions and 13 deletions


@@ -164,7 +164,7 @@ class TestAOMigrationQuantizationFx(AOMigrationTestCase):
     def test_function_import_fx_fusion_patterns(self):
         function_list = [
             'FuseHandler',
-            'ConvBNReLUFusion',
+            'ConvOrLinearBNReLUFusion',
             'ModuleReLUFusion'
         ]
         self._test_function_import('fx.fusion_patterns', function_list)


@@ -250,6 +250,34 @@ class TestFuseFx(QuantizationTestCase):
             expected_node_list=expected_nodes,
             expected_node_occurrence=expected_occurrence)
 
+    def test_fuse_linear_bn_eval(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(1, 1)
+                self.bn1d = nn.BatchNorm1d(1)
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.bn1d(x)
+                return x
+
+        # test eval mode
+        m = M().eval()
+        from torch.ao.quantization.quantize_fx import fuse_fx
+        # fuse_fx is a top level api and only supports eval mode
+        m = fuse_fx(m)
+        expected_nodes = [
+            ns.call_module(nn.Linear),
+        ]
+        expected_occurrence = {
+            ns.call_module(nn.BatchNorm1d): 0,
+        }
+        self.checkGraphModuleNodes(
+            m,
+            expected_node_list=expected_nodes,
+            expected_node_occurrence=expected_occurrence)
+
     def test_fuse_module_relu(self):
         class M(torch.nn.Module):
             def __init__(self):

@@ -40,7 +40,8 @@ class FuseHandler(ABC):
 @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm1d, torch.nn.Conv1d)))
 @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
 @register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm3d, torch.nn.Conv3d)))
-class ConvBNReLUFusion(FuseHandler):
+@register_fusion_pattern((torch.nn.BatchNorm1d, torch.nn.Linear))
+class ConvOrLinearBNReLUFusion(FuseHandler):
     def __init__(self, quantizer: QuantizerCls, node: Node):
         super().__init__(quantizer, node)
         self.relu_node = None
@@ -57,8 +58,8 @@ class ConvBNReLUFusion(FuseHandler):
             assert isinstance(node.args[0], Node)
             node = node.args[0]
         assert node.op == 'call_module'
-        self.conv_node = node
-        self.conv = quantizer.modules[self.conv_node.target]
+        self.conv_or_linear_node = node
+        self.conv_or_linear = quantizer.modules[self.conv_or_linear_node.target]
 
     def fuse(self, quantizer: QuantizerCls, load_arg: Callable,
              fuse_custom_config_dict: Dict[str, Any] = None) -> Node:
@@ -74,32 +75,32 @@
                 # TODO: get inplace argument from functional
                 relu = torch.nn.ReLU()
             op_list.append(relu)
-            relu.training = self.conv.training
+            relu.training = self.conv_or_linear.training
             if self.bn_node is not None:
                 op_list.append(self.bn)
-            op_list.append(self.conv)
+            op_list.append(self.conv_or_linear)
         else:
             assert self.bn_node is not None
             op_list.append(self.bn)
-            op_list.append(self.conv)
+            op_list.append(self.conv_or_linear)
 
-        # the modules are added in order of relu - bn - conv
+        # the modules are added in order of relu - bn - conv_or_linear
        # so we need to correct it
         op_list.reverse()
         op_type_list = tuple(type(m) for m in op_list)
-        conv_parent_name, conv_name = _parent_name(self.conv_node.target)
+        conv_or_linear_parent_name, conv_or_linear_name = _parent_name(self.conv_or_linear_node.target)
         fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
         if fuser_method is None:
             raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
         fused = fuser_method(*op_list)
-        setattr(quantizer.modules[conv_parent_name], conv_name, fused)
+        setattr(quantizer.modules[conv_or_linear_parent_name], conv_or_linear_name, fused)
 
         # TODO: do we need to make sure bn is only used once?
         if self.bn_node is not None:
             parent_name, name = _parent_name(self.bn_node.target)
             setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
         # relu may be used multiple times, so we don't set relu to identity
-        return quantizer.fused_graph.node_copy(self.conv_node, load_arg)
+        return quantizer.fused_graph.node_copy(self.conv_or_linear_node, load_arg)
 
 @register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
 @register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))


@@ -53,7 +53,7 @@ def get_default_output_activation_post_process_map() -> Dict[Pattern, ObserverBa
 
 # Example use of register pattern function:
 # @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
-# class ConvBNReLUFusion():
+# class ConvOrLinearBNReLUFusion():
 #     def __init__(...):
 #         ...
 #


@@ -8,6 +8,6 @@ here.
 """
 from torch.ao.quantization.fx.fusion_patterns import (
     FuseHandler,
-    ConvBNReLUFusion,
+    ConvOrLinearBNReLUFusion,
     ModuleReLUFusion
 )