[pt2e][qat] Support conv-transpose-bn[-relu] QAT fusion (#123652)
Summary: This commit adds support for QAT fusion for the [conv-transpose-bn] and [conv-transpose-bn-relu] patterns.

Test Plan:
python test/test_quantization.py TestQuantizePT2EQAT_ConvBn1d.test_qat_conv_transpose_bn
python test/test_quantization.py TestQuantizePT2EQAT_ConvBn1d.test_qat_conv_transpose_bn_relu
python test/test_quantization.py TestQuantizePT2EQAT_ConvBn2d.test_qat_conv_transpose_bn
python test/test_quantization.py TestQuantizePT2EQAT_ConvBn2d.test_qat_conv_transpose_bn_relu

Reviewers: jerryzh168

Subscribers: jerryzh168, supriyar

Tasks: https://github.com/pytorch/pytorch/issues/122224

Differential Revision: D55930704 (https://our.internmc.facebook.com/intern/diff/D55930704)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123652
Approved by: https://github.com/jerryzh168
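For orientation, here is a minimal sketch of the pt2e QAT flow that this fusion participates in. The model, channel sizes, and training loop are illustrative placeholders; the API names (capture_pre_autograd_graph, prepare_qat_pt2e, convert_pt2e, XNNPACKQuantizer) are assumed to match the PyTorch version this PR targets:

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
    from torch.ao.quantization.quantizer.xnnpack_quantizer import (
        XNNPACKQuantizer,
        get_symmetric_quantization_config,
    )

    # Hypothetical model exercising the conv-transpose-bn-relu pattern
    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv_t = torch.nn.ConvTranspose2d(3, 5, 3)
            self.bn = torch.nn.BatchNorm2d(5)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv_t(x)))

    example_inputs = (torch.randn(1, 3, 5, 5),)
    m = capture_pre_autograd_graph(M(), example_inputs)
    quantizer = XNNPACKQuantizer().set_global(
        get_symmetric_quantization_config(is_qat=True)
    )
    m = prepare_qat_pt2e(m, quantizer)  # conv-transpose-bn[-relu] fusion happens here
    # ... QAT fine-tuning loop over m ...
    m = convert_pt2e(m)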
Committed by PyTorch MergeBot. Commit: 5c0a380bdf (parent: e4c887fbf6).
@@ -61,8 +61,12 @@ class PT2EQATTestCase(QuantizationTestCase):
             **conv_kwargs,
         ):
             super().__init__()
-            self.conv = conv_class(3, 3, 3, bias=has_conv_bias, **conv_kwargs)
-            self.bn = bn_class(3) if has_bn else None
+            conv_kwargs.setdefault("in_channels", 3)
+            conv_kwargs.setdefault("out_channels", 3)
+            conv_kwargs.setdefault("kernel_size", 3)
+            conv_kwargs.setdefault("bias", has_conv_bias)
+            self.conv = conv_class(**conv_kwargs)
+            self.bn = bn_class(conv_kwargs["out_channels"]) if has_bn else None
             self.relu = torch.nn.ReLU() if has_relu else None
 
         def forward(self, x):
@@ -78,6 +82,7 @@ class PT2EQATTestCase(QuantizationTestCase):
         has_conv_bias: bool = True,
         has_bn: bool = True,
         has_relu: bool = False,
+        transpose: bool = False,
         **conv_kwargs,
     ):
         """
@@ -86,7 +91,7 @@ class PT2EQATTestCase(QuantizationTestCase):
         conv-bn model with conv bias.
         """
         return self._BaseConvBnModel(
-            self.conv_class,
+            self.conv_transpose_class if transpose else self.conv_class,
             self.bn_class,
             has_conv_bias,
             has_bn,
@@ -179,6 +184,8 @@ class PT2EQATTestCase(QuantizationTestCase):
         has_bias: bool = True,
         is_cuda: bool = False,
         expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
+        # TODO: set this to true by default
+        verify_convert: bool = False,
     ):
         self._verify_symmetric_xnnpack_qat_graph_helper(
             m,
@@ -188,6 +195,7 @@ class PT2EQATTestCase(QuantizationTestCase):
             has_bias=has_bias,
             is_cuda=is_cuda,
             expected_conv_literal_args=expected_conv_literal_args,
+            verify_convert=verify_convert,
         )
         self._verify_symmetric_xnnpack_qat_graph_helper(
             m,
@@ -197,6 +205,7 @@ class PT2EQATTestCase(QuantizationTestCase):
             has_bias=has_bias,
             is_cuda=is_cuda,
             expected_conv_literal_args=expected_conv_literal_args,
+            verify_convert=verify_convert,
         )
 
     def _verify_symmetric_xnnpack_qat_graph_helper(
@@ -208,6 +217,7 @@ class PT2EQATTestCase(QuantizationTestCase):
         has_bias: bool = True,
         is_cuda: bool = False,
         expected_conv_literal_args: Optional[Tuple[Any, ...]] = None,
+        verify_convert: bool = False,
     ):
         """
         Verify that the graph module matches the fused QAT [conv - bn (- relu)] pattern
@@ -267,6 +277,7 @@ class PT2EQATTestCase(QuantizationTestCase):
         else:
             div_scale_factor_node = bn_node.args[0]
         (conv_node, scale_factor_reshape_node) = div_scale_factor_node.args
+        conv_op = conv_node.target
         self.assertEqual(div_scale_factor_node.target, torch.ops.aten.div.Tensor)
         self.assertTrue(_is_conv_node(conv_node))
         self.assertEqual(
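The div node asserted in the hunk above comes from the standard QAT batchnorm-folding trick. Roughly (my paraphrase under assumed names, not code from this diff): the pattern scales the conv weight by the BN scale factor before fake-quantizing it, then divides the conv output by the same factor so the float numerics still match conv followed by BN:

    # s has one entry per output channel of the conv
    s = bn_weight / torch.sqrt(bn_running_var + bn_eps)
    y = conv_fn(x, fake_quantize(conv_weight * s.reshape(weight_shape)))
    y = y / s.reshape(bias_shape)  # this is the aten.div.Tensor node asserted above
    # ... BN statistics are then updated on y as usual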
@@ -347,12 +358,72 @@ class PT2EQATTestCase(QuantizationTestCase):
         self.assertTrue("bn_running_var" in bn_running_var_node.target)
         self.assertEqual(eps, 1e-5)
 
+        # Optionally check the converted graph
+        if verify_convert:
+            m = convert_pt2e(m)
+            m(*example_inputs)
+
+            if is_per_channel:
+                conv_weight_dq_op = (
+                    torch.ops.quantized_decomposed.dequantize_per_channel.default
+                )
+                node_occurrence = {
+                    ns.call_function(
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default
+                    ): 2,
+                    ns.call_function(
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                    ): 2,
+                    ns.call_function(
+                        torch.ops.quantized_decomposed.dequantize_per_channel.default
+                    ): 1,
+                }
+            else:
+                conv_weight_dq_op = (
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                )
+                node_occurrence = {
+                    ns.call_function(
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default
+                    ): 2,
+                    ns.call_function(
+                        torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                    ): 3,
+                }
+            node_list = [
+                ns.call_function(
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default
+                ),
+                ns.call_function(
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                ),
+                ns.call_function(conv_weight_dq_op),
+                ns.call_function(conv_op),
+                ns.call_function(
+                    torch.ops.quantized_decomposed.quantize_per_tensor.default
+                ),
+                ns.call_function(
+                    torch.ops.quantized_decomposed.dequantize_per_tensor.default
+                ),
+            ]
+
+            self.checkGraphModuleNodes(
+                m,
+                expected_node_list=node_list,
+                expected_node_occurrence=node_occurrence,
+            )
+
 
 class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     """
     Base TestCase to be used for all conv-bn[-relu] fusion patterns.
     """
 
     # TODO: how can we avoid adding every new test to dynamo/expected_test_failures?
     # Otherwise it fails with the following error:
     # torch._dynamo.exc.InternalTorchDynamoError:
     # 'QuantizationConfig' object has no attribute '__bool__'
 
     def setUp(self):
         # NB: Skip the test if this is a base class, this is to handle the test
         # discovery logic in buck which finds and runs all tests here including
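As a reading aid for the counts asserted in the hunk above (my interpretation, not part of the diff): the two quantize/dequantize_per_tensor pairs cover the conv input and output activations, and the remaining dequantize (per-channel or per-tensor depending on the config) is the folded conv weight:

    # x -> q_per_tensor -> dq_per_tensor --\
    #                                       conv[_transpose] -> q_per_tensor -> dq_per_tensor -> ...
    # w (quantized at convert time) -> dq -/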
@@ -761,13 +832,36 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
         self.assertEqual(dq_qmax, 2**31 - 1)
         self.assertEqual(dq_dtype, torch.int32)
 
+    def _do_test_qat_conv_transpose_bn(self, has_relu: bool):
+        # Use different in/out channel sizes to test if conv weight is
+        # properly transposed in QAT pattern
+        m = self._get_conv_bn_model(
+            has_relu=has_relu,
+            transpose=True,
+            in_channels=3,
+            out_channels=5,
+            kernel_size=3,
+        )
+        self._verify_symmetric_xnnpack_qat_graph(
+            m,
+            self.example_inputs,
+            has_relu=has_relu,
+            verify_convert=True,
+        )
+
+    def test_qat_conv_transpose_bn(self):
+        self._do_test_qat_conv_transpose_bn(has_relu=False)
+
+    def test_qat_conv_transpose_bn_relu(self):
+        self._do_test_qat_conv_transpose_bn(has_relu=True)
+
 
 # TODO: enable this in the next PR
 @skipIfNoQNNPACK
 class TestQuantizePT2EQAT_ConvBn1d(TestQuantizePT2EQAT_ConvBn_Base):
     dim = 1
     example_inputs = (torch.randn(1, 3, 5),)
     conv_class = torch.nn.Conv1d
+    conv_transpose_class = torch.nn.ConvTranspose1d
     bn_class = torch.nn.BatchNorm1d
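Why the unequal channel sizes in the new test matter: for a transposed conv the output channels live on axis 1 of the weight, so a pattern that reshaped the BN scale factor along axis 0 would only fail once in_channels != out_channels. A quick illustration (not from the diff):

    import torch
    w = torch.nn.ConvTranspose2d(3, 5, 3).weight  # shape (3, 5, 3, 3)
    s = torch.randn(5)                            # one BN scale per output channel
    ok = w * s.reshape(1, -1, 1, 1)               # broadcasts along axis 1
    # w * s.reshape(-1, 1, 1, 1) would raise a size-mismatch RuntimeError,
    # but only because 5 != 3; equal channel sizes would hide the bug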
@@ -776,6 +870,7 @@ class TestQuantizePT2EQAT_ConvBn2d(TestQuantizePT2EQAT_ConvBn_Base):
     dim = 2
     example_inputs = (torch.randn(1, 3, 5, 5),)
     conv_class = torch.nn.Conv2d
+    conv_transpose_class = torch.nn.ConvTranspose2d
     bn_class = torch.nn.BatchNorm2d
@@ -783,6 +878,10 @@ def _is_conv_node(n: torch.fx.Node):
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
+        torch.ops.aten.conv_transpose1d,
+        torch.ops.aten.conv_transpose1d.default,
+        torch.ops.aten.conv_transpose2d,
+        torch.ops.aten.conv_transpose2d.input,
     ]
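The list above mixes two spellings because an FX node's target can be either an overload packet or a specific overload, depending on how the graph was captured (my reading of why both entries are needed; the diff itself does not explain it):

    import torch
    packet = torch.ops.aten.conv_transpose2d          # OpOverloadPacket
    overload = torch.ops.aten.conv_transpose2d.input  # OpOverload ("input" is the overload name)
    print(type(packet).__name__, type(overload).__name__)  # OpOverloadPacket OpOverload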
@@ -21,8 +21,9 @@ from torch.ao.quantization.quantizer import (
 from .utils import (
     _conv1d_bn_example_inputs,
     _conv2d_bn_example_inputs,
-    _is_conv_node,
     _is_bn_node,
+    _is_conv_or_conv_transpose_node,
+    _is_conv_transpose_fn,
     fold_bn_weights_into_conv_node,
     _get_aten_graph_module_for_pattern,
 )
@@ -113,7 +114,8 @@ def _get_qat_conv_bn_pattern(conv_fn: Callable) -> Callable:
         running_std = torch.sqrt(bn_running_var + bn_eps)
         scale_factor = bn_weight / running_std
         weight_shape = [1] * len(conv_weight.shape)
-        weight_shape[0] = -1
+        weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
+        weight_shape[weight_in_channel_axis] = -1
         bias_shape = [1] * len(conv_weight.shape)
         bias_shape[1] = -1
         scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
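The axis-0 vs axis-1 distinction in the hunk above reflects the weight layouts of the two op families; a quick check (illustrative, not part of the diff):

    import torch
    conv = torch.nn.Conv2d(3, 5, 3)
    conv_t = torch.nn.ConvTranspose2d(3, 5, 3)
    assert tuple(conv.weight.shape) == (5, 3, 3, 3)    # (out, in/groups, kH, kW)
    assert tuple(conv_t.weight.shape) == (3, 5, 3, 3)  # (in, out/groups, kH, kW)
    # so the per-output-channel BN scale factor lands on weight axis 0 for conv
    # and on weight axis 1 for conv transpose, as computed above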
@@ -144,7 +146,8 @@ def _get_qat_conv_bn_pattern_no_conv_bias(conv_fn: Callable) -> Callable:
         running_std = torch.sqrt(bn_running_var + bn_eps)
         scale_factor = bn_weight / running_std
         weight_shape = [1] * len(conv_weight.shape)
-        weight_shape[0] = -1
+        weight_in_channel_axis = 1 if _is_conv_transpose_fn(conv_fn) else 0
+        weight_shape[weight_in_channel_axis] = -1
         bias_shape = [1] * len(conv_weight.shape)
         bias_shape[1] = -1
         scaled_weight = conv_weight * scale_factor.reshape(weight_shape)
@@ -271,7 +274,7 @@ def _has_conv_bias_filter(
     the original graph has bias.
     """
     for n in match.nodes_map.values():
-        if _is_conv_node(n):
+        if _is_conv_or_conv_transpose_node(n):
             return len(n.args) > 2 and n.args[2] is not None
     raise ValueError("Could not find conv node in matched conv + bn pattern")
@@ -325,7 +328,7 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Node]]:
     for n in nodes:
         if n.op != "call_function":
             continue
-        if _is_conv_node(n):
+        if _is_conv_or_conv_transpose_node(n):
             assert conv_node is None
             conv_node = n
         if _is_bn_node(n):
@@ -440,8 +443,8 @@ def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
     Note: Unlike other tensor args like conv weights and biases, literal args are
     preserved in the original nodes after replacement, so we can access them here.
     """
-    assert _is_conv_node(original_node)
-    assert _is_conv_node(new_node)
+    assert _is_conv_or_conv_transpose_node(original_node)
+    assert _is_conv_or_conv_transpose_node(new_node)
     # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
     new_args = list(new_node.args)
     if len(new_args) < 3:
@@ -457,8 +460,8 @@ def _update_conv_input_qspec_map_after_replacement(original_node: Node, replacement_node: Node):
     so the keys in the `input_qspec_map` will need to be updated to reflect
     the corresponding nodes in the replacement graph.
     """
-    assert _is_conv_node(original_node)
-    assert _is_conv_node(replacement_node)
+    assert _is_conv_or_conv_transpose_node(original_node)
+    assert _is_conv_or_conv_transpose_node(replacement_node)
     if "quantization_annotation" not in original_node.meta:
         return
     original_input_qspec_map = original_node.meta["quantization_annotation"].input_qspec_map
@@ -522,11 +525,12 @@ def _fuse_conv_bn_qat(m: GraphModule) -> GraphModule:
     has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
     if not has_bn:
         return m
-    m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=False)
-    m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=False)
-    if torch.cuda.is_available():
-        m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=True)
-        m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=True)
+    is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
+    for is_cuda in is_cuda_options:
+        m = _fuse_conv_bn_qat_helper(m, F.conv1d, _conv1d_bn_example_inputs, is_cuda=is_cuda)
+        m = _fuse_conv_bn_qat_helper(m, F.conv2d, _conv2d_bn_example_inputs, is_cuda=is_cuda)
+        m = _fuse_conv_bn_qat_helper(m, F.conv_transpose1d, _conv1d_bn_example_inputs, is_cuda=is_cuda)
+        m = _fuse_conv_bn_qat_helper(m, F.conv_transpose2d, _conv2d_bn_example_inputs, is_cuda=is_cuda)
     return m
 
 def _fuse_conv_bn_qat_helper(
@@ -609,7 +613,7 @@ def _fuse_conv_bn_qat_helper(
     for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
         # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
         replacement_node.meta = original_node.meta
-        if _is_conv_node(original_node):
+        if _is_conv_or_conv_transpose_node(original_node):
             # Step (3b): Copy over conv literal args
             _copy_over_literal_conv_args(original_node, replacement_node)
             # Step (3c): Update old references in the conv node's input_qspec_map
@@ -701,11 +705,12 @@ def _fold_conv_bn_qat(m: GraphModule) -> GraphModule:
     has_bn = any(_is_bn_node(n) for n in m.graph.nodes)
     if not has_bn:
         return m
-    m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=False)
-    m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=False)
-    if torch.cuda.is_available():
-        m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=True)
-        m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=True)
+    is_cuda_options = [True, False] if torch.cuda.is_available() else [False]
+    for is_cuda in is_cuda_options:
+        m = _fold_conv_bn_qat_helper(m, F.conv1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda)
+        m = _fold_conv_bn_qat_helper(m, F.conv2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda)
+        m = _fold_conv_bn_qat_helper(m, F.conv_transpose1d, _quantized_conv1d_bn_example_inputs, is_cuda=is_cuda)
+        m = _fold_conv_bn_qat_helper(m, F.conv_transpose2d, _quantized_conv2d_bn_example_inputs, is_cuda=is_cuda)
     return m
 
 def _fold_conv_bn_qat_helper(
@@ -780,7 +785,7 @@ def _fold_conv_bn_qat_helper(
 
         # Copy over literal args for conv
         for original_node in _filter_nodes_map(r.nodes_map).values():
-            if _is_conv_node(original_node):
+            if _is_conv_or_conv_transpose_node(original_node):
                 _copy_over_literal_conv_args(original_node, conv_node)
 
         m.graph.eliminate_dead_code()
@@ -7,6 +7,7 @@ from torch.fx import (
     GraphModule,
     Node,
 )
+import torch.nn.functional as F
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 from typing import Any, Callable, Dict, Optional, Tuple, List, Union
 from torch.utils._pytree import LeafSpec
@@ -169,7 +170,7 @@ def _is_supported_batch_norm_for_training(node: Node):
 # TODO: move this to torch/ao/quantization/utils.py
 def _is_conv_node(n: Node):
     """
-    Return whether the node refers to an aten conv or conv transpose op.
+    Return whether the node refers to an aten conv op.
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv1d.default,
@@ -182,10 +183,20 @@ def _is_conv_transpose_node(n: Node):
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv_transpose1d,
+        torch.ops.aten.conv_transpose1d.default,
+        torch.ops.aten.conv_transpose2d,
         torch.ops.aten.conv_transpose2d.input,
     ]
 
+def _is_conv_or_conv_transpose_node(n: Node):
+    """
+    Return whether the node refers to an aten conv or conv transpose op.
+    """
+    return _is_conv_node(n) or _is_conv_transpose_node(n)
+
+def _is_conv_transpose_fn(conv_fn: Callable):
+    return conv_fn in [F.conv_transpose1d, F.conv_transpose2d]
+
 def _is_bn_node(n: Node):
     return _is_supported_batch_norm_for_training(n) or n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
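A sketch of how these predicates get used when scanning a captured graph (gm is an assumed GraphModule; illustrative only, not from the diff):

    from torch.fx import GraphModule

    def find_conv_like_nodes(gm: GraphModule):
        # Collect both regular and transposed conv call_function nodes,
        # e.g. targets like aten.conv2d.default or aten.conv_transpose2d.input
        return [n for n in gm.graph.nodes if _is_conv_or_conv_transpose_node(n)]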
@@ -270,7 +281,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
             continue
         bn_node = n
         n = bn_node.args[0]
-        if not (_is_conv_node(n) or _is_conv_transpose_node(n)):
+        if not _is_conv_or_conv_transpose_node(n):
             continue
         conv_node = n
         conv_weight_node = conv_node.args[1]
@@ -267,6 +267,8 @@ class XNNPACKQuantizer(Quantizer):
     STATIC_QAT_ONLY_OPS = [
         "conv_bn_relu",
         "conv_bn",
+        "conv_transpose_bn_relu",
+        "conv_transpose_bn",
     ]
 
     # static quantization ops (both PTQ and QAT)
@@ -276,6 +278,7 @@ class XNNPACKQuantizer(Quantizer):
         "linear",
         "conv_relu",
         "conv",
+        "conv_transpose_relu",
         "adaptive_avg_pool2d",
         # TODO: move this to BoltNNQuantizer?
         "gru_io_only",
@@ -13,6 +13,8 @@ from torch.ao.quantization.pt2e.utils import (
     _conv1d_bn_example_inputs,
     _conv2d_bn_example_inputs,
     _get_aten_graph_module_for_pattern,
+    _is_conv_node,
+    _is_conv_transpose_node,
 )
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -344,22 +346,9 @@ def _do_annotate_conv_relu(
             continue
         relu_node = n
         maybe_conv_node = n.args[0]
-        # TODO: refactor with is_conv_node and is_conv_transpose_node
-        if is_conv_transpose:
-            conv_ops = [
-                torch.ops.aten.conv_transpose1d,
-                torch.ops.aten.conv_transpose2d.input,
-            ]
-        else:
-            conv_ops = [
-                torch.ops.aten.conv1d.default,
-                torch.ops.aten.conv2d.default,
-            ]
-        if (
-            not isinstance(maybe_conv_node, Node)
-            or maybe_conv_node.op != "call_function"
-            or maybe_conv_node.target not in conv_ops
-        ):
+        is_conv_node = _is_conv_transpose_node if is_conv_transpose else _is_conv_node
+        if not isinstance(maybe_conv_node, Node) or not is_conv_node(maybe_conv_node):
             continue
         conv_node = maybe_conv_node