[pt2e][qat] Support conv-transpose-bn[-relu] QAT fusion (#123652)

Summary: This commit adds support for QAT fusion of the
[conv-transpose-bn] and [conv-transpose-bn-relu] patterns.
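
For reference, a minimal sketch of the pt2e QAT flow these patterns participate in
(not part of this commit; the export entry point and config helpers below are
assumptions and may differ across PyTorch versions):

# Sketch only: exercises conv_transpose-bn-relu QAT fusion end to end.
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_qat_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.ConvTranspose2d(3, 5, 3)
        self.bn = torch.nn.BatchNorm2d(5)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

example_inputs = (torch.randn(1, 3, 5, 5),)
m = capture_pre_autograd_graph(M(), example_inputs)
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
m = prepare_qat_pt2e(m, quantizer)  # conv_transpose-bn[-relu] is fused into the QAT pattern here
m(*example_inputs)                  # calibrate / fine-tune
m = convert_pt2e(m)                 # fold BN back into conv_transpose and insert q/dq ops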

Test Plan:
python test/test_quantization.py TestQuantizePT2EQAT_ConvBn1d.test_qat_conv_transpose_bn
python test/test_quantization.py TestQuantizePT2EQAT_ConvBn1d.test_qat_conv_transpose_bn_relu
python test/test_quantization.py TestQuantizePT2EQAT_ConvBn2d.test_qat_conv_transpose_bn
python test/test_quantization.py TestQuantizePT2EQAT_ConvBn2d.test_qat_conv_transpose_bn_relu

Reviewers: jerryzh168

Subscribers: jerryzh168, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/122224

Differential Revision: [D55930704](https://our.internmc.facebook.com/intern/diff/D55930704)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123652
Approved by: https://github.com/jerryzh168
Author: andrewor14
Date: 2024-04-11 14:47:40 -07:00
Committed by: PyTorch MergeBot
Commit: 5c0a380bdf (parent: e4c887fbf6)
9 changed files with 150 additions and 43 deletions


@ -61,8 +61,12 @@ class PT2EQATTestCase(QuantizationTestCase):
**conv_kwargs,
):
super().__init__()
self.conv = conv_class(3, 3, 3, bias=has_conv_bias, **conv_kwargs)
self.bn = bn_class(3) if has_bn else None
conv_kwargs.setdefault("in_channels", 3)
conv_kwargs.setdefault("out_channels", 3)
conv_kwargs.setdefault("kernel_size", 3)
conv_kwargs.setdefault("bias", has_conv_bias)
self.conv = conv_class(**conv_kwargs)
self.bn = bn_class(conv_kwargs["out_channels"]) if has_bn else None
self.relu = torch.nn.ReLU() if has_relu else None
def forward(self, x):
@ -78,6 +82,7 @@ class PT2EQATTestCase(QuantizationTestCase):
has_conv_bias: bool = True,
has_bn: bool = True,
has_relu: bool = False,
transpose: bool = False,
**conv_kwargs,
):
"""
@ -86,7 +91,7 @@ class PT2EQATTestCase(QuantizationTestCase):
conv-bn model with conv bias.
"""
return self._BaseConvBnModel(
self.conv_class,
self.conv_transpose_class if transpose else self.conv_class,
self.bn_class,
has_conv_bias,
has_bn,
@ -179,6 +184,8 @@ class PT2EQATTestCase(QuantizationTestCase):
has_bias: bool = True,
is_cuda: bool = False,
expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
# TODO: set this to true by default
verify_convert: bool = False,
):
self._verify_symmetric_xnnpack_qat_graph_helper(
m,
@ -188,6 +195,7 @@ class PT2EQATTestCase(QuantizationTestCase):
has_bias=has_bias,
is_cuda=is_cuda,
expected_conv_literal_args=expected_conv_literal_args,
verify_convert=verify_convert,
)
self._verify_symmetric_xnnpack_qat_graph_helper(
m,
@ -197,6 +205,7 @@ class PT2EQATTestCase(QuantizationTestCase):
has_bias=has_bias,
is_cuda=is_cuda,
expected_conv_literal_args=expected_conv_literal_args,
verify_convert=verify_convert,
)
def _verify_symmetric_xnnpack_qat_graph_helper(
@ -208,6 +217,7 @@ class PT2EQATTestCase(QuantizationTestCase):
has_bias: bool = True,
is_cuda: bool = False,
expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
verify_convert: bool = False,
):
"""
Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
@ -267,6 +277,7 @@ class PT2EQATTestCase(QuantizationTestCase):
else:
div_scale_factor_node = bn_node.args[0]
(conv_node, scale_factor_reshape_node) = div_scale_factor_node.args
conv_op = conv_node.target
self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor)
self.assertTrue(_is_conv_node(conv_node))
self.assertEqual(
@ -347,12 +358,72 @@ class PT2EQATTestCase(QuantizationTestCase):
self.assertTrue("bn_running_var" in bn_running_var_node.target)
self.assertEqual(eps, 1e-5)
# Optionally check the converted graph
if verify_convert:
m = convert_pt2e(m)
m(*example_inputs)
if is_per_channel:
conv_weight_dq_op = (
torch.ops.quantized_decomposed.dequantize_per_channel.default
)
node_occurrence = {
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 1,
}
else:
conv_weight_dq_op = (
torch.ops.quantized_decomposed.dequantize_per_tensor.default
)
node_occurrence = {
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 2,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 3,
}
node_list = [
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
ns.call_function(conv_weight_dq_op),
ns.call_function(conv_op),
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
),
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
),
]
self.checkGraphModuleNodes(
m,
expected_node_list=node_list,
expected_node_occurrence=node_occurrence,
)
class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
"""
Base TestCase to be used for all conv-bn[-relu] fusion patterns.
"""
# TODO: how can we avoid adding every new test to dynamo/expected_test_failures?
# Otherwise it fails with the following error:
# torch._dynamo.exc.InternalTorchDynamoError:
# 'QuantizationConfig' object has no attribute '__bool__'
def setUp(self):
# NB: Skip the test if this is a base class, this is to handle the test
# discovery logic in buck which finds and runs all tests here including
@ -761,13 +832,36 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
self.assertEqual(dq_qmax, 2**31 - 1)
self.assertEqual(dq_dtype, torch.int32)
def _do_test_qat_conv_transpose_bn(self, has_relu: bool):
# Use different in/out channel sizes to test if conv weight is
# properly transposed in QAT pattern
m = self._get_conv_bn_model(
has_relu=has_relu,
transpose=True,
in_channels=3,
out_channels=5,
kernel_size=3,
)
self._verify_symmetric_xnnpack_qat_graph(
m,
self.example_inputs,
has_relu=has_relu,
verify_convert=True,
)
def test_qat_conv_transpose_bn(self):
self._do_test_qat_conv_transpose_bn(has_relu=False)
def test_qat_conv_transpose_bn_relu(self):
self._do_test_qat_conv_transpose_bn(has_relu=True)
# TODO: enable this in the next PR
@skipIfNoQNNPACK
class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
dim = 1
example_inputs = (torch.randn(1, 3, 5),)
conv_class = torch.nn.Conv1d
conv_transpose_class = torch.nn.ConvTranspose1d
bn_class = torch.nn.BatchNorm1d
@ -776,6 +870,7 @@ class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
dim = 2
example_inputs = (torch.randn(1, 3, 5, 5),)
conv_class = torch.nn.Conv2d
conv_transpose_class = torch.nn.ConvTranspose2d
bn_class = torch.nn.BatchNorm2d
@ -783,6 +878,10 @@ def _is_conv_node(n: torch.fx.Node):
return n.op == "call_function" and n.target in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.conv_transpose1d,
torch.ops.aten.conv_transpose1d.default,
torch.ops.aten.conv_transpose2d,
torch.ops.aten.conv_transpose2d.input,
]


@ -21,8 +21,9 @@ from torch.ao.quantization.quantizer import (
from .utils import (
_conv1d_bn_example_inputs,
_conv2d_bn_example_inputs,
_is_conv_node,
_is_bn_node,
_is_conv_or_conv_transpose_node,
_is_conv_transpose_fn,
fold_bn_weights_into_conv_node,
_get_aten_graph_module_for_pattern,
)
@ -113,7 +114,8 @@ def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable:
running_std = torch.sqrt(bn_running_var + bn_eps)
scale_factor = bn_weight / running_std
weight_shape = [1] * len(conv_weight.shape)
weight_shape[0] = -1
weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
weight_shape[weight_in_channel_axis] = -1
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
@ -144,7 +146,8 @@ def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable:
running_std = torch.sqrt(bn_running_var + bn_eps)
scale_factor = bn_weight / running_std
weight_shape = [1] * len(conv_weight.shape)
weight_shape[0] = -1
weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
weight_shape[weight_in_channel_axis] = -1
bias_shape = [1] * len(conv_weight.shape)
bias_shape[1] = -1
scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
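
Note on the axis change above: F.conv{1,2}d weights are laid out as
(out_channels, in_channels/groups, *kernel), while F.conv_transpose{1,2}d weights are
(in_channels, out_channels/groups, *kernel), so the per-output-channel BN scale factor
must broadcast along dim 1 for transposed convolutions. A standalone sketch (not part
of the diff) illustrating the reshape:

# Sketch: broadcasting a per-output-channel scale onto conv vs. conv_transpose
# weights, mirroring weight_in_channel_axis in the pattern above.
import torch
import torch.nn.functional as F

out_channels = 5
scale_factor = torch.rand(out_channels)             # one scale per output channel (from BN)

conv_weight = torch.randn(out_channels, 3, 3, 3)    # (out, in/groups, kH, kW)
convT_weight = torch.randn(3, out_channels, 3, 3)   # (in, out/groups, kH, kW)

def scaled(weight, conv_fn):
    shape = [1] * weight.dim()
    # output channels live on dim 1 for transposed conv weights, dim 0 otherwise
    shape[1 if conv_fn in (F.conv_transpose1d, F.conv_transpose2d) else 0] = -1
    return weight * scale_factor.reshape(shape)

print(scaled(conv_weight, F.conv2d).shape)             # torch.Size([5, 3, 3, 3])
print(scaled(convT_weight, F.conv_transpose2d).shape)  # torch.Size([3, 5, 3, 3])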
@ -271,7 +274,7 @@ def _has_conv_bias_filter(
the original graph has bias.
"""
for n in match.nodes_map.values():
if _is_conv_node(n):
if _is_conv_or_conv_transpose_node(n):
return len(n.args) > 2 and n.args[2] is not None
raise ValueError("Could not find conv node in matched conv + bn pattern")
@ -325,7 +328,7 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
for n in nodes:
if n.op != "call_function":
continue
if _is_conv_node(n):
if _is_conv_or_conv_transpose_node(n):
assert conv_node is None
conv_node = n
if _is_bn_node(n):
@ -440,8 +443,8 @@ def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
Note: Unlike other tensor args like conv weights and biases, literal args are
preserved in the original nodes after replacement, so we can access them here.
"""
assert _is_conv_node(original_node)
assert _is_conv_node(new_node)
assert _is_conv_or_conv_transpose_node(original_node)
assert _is_conv_or_conv_transpose_node(new_node)
# x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
new_args = list(new_node.args)
if len(new_args) < 3:
@ -457,8 +460,8 @@ def _update_conv_input_qspec_map_after_replacement(original_node: Node, replacem
so the keys in the `input_qspec_map` will need to be updated to reflect
the corresponding nodes in the replacement graph.
"""
assert _is_conv_node(original_node)
assert _is_conv_node(replacement_node)
assert _is_conv_or_conv_transpose_node(original_node)
assert _is_conv_or_conv_transpose_node(replacement_node)
if "quantization_annotation" not in original_node.meta:
return
original_input_qspec_map = original_node.meta["quantization_annotation"].input_qspec_map
@ -522,11 +525,12 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
if not has_bn:
return m
m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=False)
m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=False)
if torch.cuda.is_available():
m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=True)
m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=True)
is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
for is_cuda in is_cuda_options:
m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda)
m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda)
m = _fuse_conv_bn_qat_helper(m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda)
m = _fuse_conv_bn_qat_helper(m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda)
return m
def _fuse_conv_bn_qat_helper(
@ -609,7 +613,7 @@ def _fuse_conv_bn_qat_helper(
for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
# Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
replacement_node.meta = original_node.meta
if _is_conv_node(original_node):
if _is_conv_or_conv_transpose_node(original_node):
# Step (3b): Copy over conv literal args
_copy_over_literal_conv_args(original_node, replacement_node)
# Step (3c): Update old references in the conv node's input_qspec_map
@ -701,11 +705,12 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
if not has_bn:
return m
m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=False)
m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=False)
if torch.cuda.is_available():
m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=True)
m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=True)
is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
for is_cuda in is_cuda_options:
m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda)
m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda)
m = _fold_conv_bn_qat_helper(m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda)
m = _fold_conv_bn_qat_helper(m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda)
return m
def _fold_conv_bn_qat_helper(
@ -780,7 +785,7 @@ def _fold_conv_bn_qat_helper(
# Copy over literal args for conv
for original_node in _filter_nodes_map(r.nodes_map).values():
if _is_conv_node(original_node):
if _is_conv_or_conv_transpose_node(original_node):
_copy_over_literal_conv_args(original_node, conv_node)
m.graph.eliminate_dead_code()


@ -7,6 +7,7 @@ from torch.fx import (
GraphModule,
Node,
)
import torch.nn.functional as F
from torch.nn.utils.fusion import fuse_conv_bn_weights
from typing import Any, Callable, Dict, Optional, Tuple, List, Union
from torch.utils._pytree import LeafSpec
@ -169,7 +170,7 @@ def _is_supported_batch_norm_for_training(node: Node):
# TODO: move this to torch/ao/quantization/utils.py
def _is_conv_node(n: Node):
"""
Return whether the node refers to an aten conv or conv transpose op.
Return whether the node refers to an aten conv op.
"""
return n.op == "call_function" and n.target in [
torch.ops.aten.conv1d.default,
@ -182,10 +183,20 @@ def _is_conv_transpose_node(n: Node):
"""
return n.op == "call_function" and n.target in [
torch.ops.aten.conv_transpose1d,
torch.ops.aten.conv_transpose1d.default,
torch.ops.aten.conv_transpose2d,
torch.ops.aten.conv_transpose2d.input,
]
def _is_conv_or_conv_transpose_node(n: Node):
"""
Return whether the node refers to an aten conv or conv transpose op.
"""
return _is_conv_node(n) or _is_conv_transpose_node(n)
def _is_conv_transpose_fn(conv_fn: Callable):
return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]
def _is_bn_node(n: Node):
return _is_supported_batch_norm_for_training(n) or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
@ -270,7 +281,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
continue
bn_node = n
n = bn_node.args[0]
if not (_is_conv_node(n) or _is_conv_transpose_node(n)):
if not _is_conv_or_conv_transpose_node(n):
continue
conv_node = n
conv_weight_node = conv_node.args[1]


@ -267,6 +267,8 @@ class XNNPACKQuantizer(Quantizer):
STATIC_QAT_ONLY_OPS = [
"conv_bn_relu",
"conv_bn",
"conv_transpose_bn_relu",
"conv_transpose_bn",
]
# static quantization ops (both PTQ and QAT)
@ -276,6 +278,7 @@ class XNNPACKQuantizer(Quantizer):
"linear",
"conv_relu",
"conv",
"conv_transpose_relu",
"adaptive_avg_pool2d",
# TODO: move this to BoltNNQuantizer?
"gru_io_only",


@ -13,6 +13,8 @@ from torch.ao.quantization.pt2e.utils import (
_conv1d_bn_example_inputs,
_conv2d_bn_example_inputs,
_get_aten_graph_module_for_pattern,
_is_conv_node,
_is_conv_transpose_node,
)
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
@ -344,22 +346,9 @@ def _do_annotate_conv_relu(
continue
relu_node = n
maybe_conv_node = n.args[0]
# TODO: refactor with is_conv_node and is_conv_transpose_node
if is_conv_transpose:
conv_ops = [
torch.ops.aten.conv_transpose1d,
torch.ops.aten.conv_transpose2d.input,
]
else:
conv_ops = [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
]
if (
not isinstance(maybe_conv_node, Node)
or maybe_conv_node.op != "call_function"
or maybe_conv_node.target not in conv_ops
):
is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node
if not isinstance(maybe_conv_node, Node) or not is_conv_node(maybe_conv_node):
continue
conv_node = maybe_conv_node