[quant][pt2e] Add support for conv transpose + bn + {relu} weights fusion in PTQ (#122046)

Summary:

This also adds some annotator utils in xnnpack_quantizer_utils.py:
* annotate_conv_transpose_bn_relu and annotate_conv_transpose_bn -> these are for QAT
* annotate_conv_transpose_relu

The conv_transpose + bn weights fusion is performed automatically and cannot currently be disabled;
support for disabling this fusion can be added later if needed.
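
For reference, a minimal PTQ flow that exercises this fusion might look like the following sketch (the model and shapes here are illustrative, not part of this PR):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.convt = torch.nn.ConvTranspose2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(self.convt(x))

example_inputs = (torch.randn(1, 3, 5, 5),)
m = capture_pre_autograd_graph(M().eval(), example_inputs)

quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)  # bn weights are folded into conv_transpose here
m(*example_inputs)              # calibrate
m = convert_pt2e(m)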

Test Plan:
python test/test_quantization.py -k test_conv_transpose_bn_relu

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122046
Approved by: https://github.com/andrewor14
Author: Jerry Zhang
Date: 2024-03-18 21:44:40 -07:00
Committed by: PyTorch MergeBot
Parent: bc1fef113d
Commit: 901ba2be86

5 changed files with 168 additions and 32 deletions

test/quantization/pt2e/test_quantize_pt2e.py

@@ -1988,3 +1988,49 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().eval()
         example_inputs = (torch.randn(5, 5),)
         _ = dynamic_quantize_pt2e(m, example_inputs)
+
+    def test_conv_transpose_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                int8_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_symmetric,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                quantization_config = QuantizationConfig(
+                    input_activation=int8_qspec,
+                    weight=int8_qspec,
+                    bias=None,
+                    output_activation=int8_qspec,
+                )
+                # conv_transpose + bn is fused automatically in PTQ (not configurable),
+                # so annotating conv_transpose + relu is enough to cover the
+                # conv_transpose + bn + relu pattern
+                OP_TO_ANNOTATOR["conv_transpose_relu"](model, quantization_config)
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            # quantize: input and output activations; dequantize: input, weight, output
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv_transpose2d.input,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            TestHelperModules.ConvTWithBNRelu(relu=True, bn=True),
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )

torch/ao/quantization/pt2e/qat_utils.py

@@ -21,7 +21,7 @@ from torch.ao.quantization.quantizer import (
 from .utils import (
     _conv1d_bn_example_inputs,
     _conv2d_bn_example_inputs,
-    _is_conv,
+    _is_conv_node,
     _is_bn_node,
     fold_bn_weights_into_conv_node,
     _get_aten_graph_module_for_pattern,
@@ -271,7 +271,7 @@ def _has_conv_bias_filter(
     the original graph has bias.
     """
     for n in match.nodes_map.values():
-        if _is_conv(n):
+        if _is_conv_node(n):
             return len(n.args) > 2 and n.args[2] is not None
     raise ValueError("Could not find conv node in matched conv + bn pattern")
@@ -325,7 +325,7 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Node]]:
     for n in nodes:
         if n.op != "call_function":
             continue
-        if _is_conv(n):
+        if _is_conv_node(n):
             assert conv_node is None
             conv_node = n
         if _is_bn_node(n):
@@ -440,8 +440,8 @@ def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
     Note: Unlike other tensor args like conv weights and biases, literal args are
     preserved in the original nodes after replacement, so we can access them here.
     """
-    assert _is_conv(original_node)
-    assert _is_conv(new_node)
+    assert _is_conv_node(original_node)
+    assert _is_conv_node(new_node)
     # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
     new_args = list(new_node.args)
     if len(new_args) < 3:
@@ -457,8 +457,8 @@ def _update_conv_input_qspec_map_after_replacement(original_node: Node, replacement_node: Node):
     so the keys in the `input_qspec_map` will need to be updated to reflect
     the corresponding nodes in the replacement graph.
     """
-    assert _is_conv(original_node)
-    assert _is_conv(replacement_node)
+    assert _is_conv_node(original_node)
+    assert _is_conv_node(replacement_node)
     if "quantization_annotation" not in original_node.meta:
         return
     original_input_qspec_map = original_node.meta["quantization_annotation"].input_qspec_map
@@ -609,7 +609,7 @@ def _fuse_conv_bn_qat_helper(
     for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
         # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
         replacement_node.meta = original_node.meta
-        if _is_conv(original_node):
+        if _is_conv_node(original_node):
             # Step (3b): Copy over conv literal args
             _copy_over_literal_conv_args(original_node, replacement_node)
             # Step (3c): Update old references in the conv node's input_qspec_map
@@ -780,7 +780,7 @@ def _fold_conv_bn_qat_helper(
         # Copy over literal args for conv
         for original_node in _filter_nodes_map(r.nodes_map).values():
-            if _is_conv(original_node):
+            if _is_conv_node(original_node):
                 _copy_over_literal_conv_args(original_node, conv_node)

     m.graph.eliminate_dead_code()

torch/ao/quantization/pt2e/utils.py

@@ -166,24 +166,24 @@ def _is_supported_batch_norm_for_training(node: Node):
     ]
     return node.target in supported_ops

-# TODO: rename this to _is_conv_node
-def _is_conv(n: Node):
+# TODO: move this to torch/ao/quantization/utils.py
+def _is_conv_node(n: Node):
     """
-    Return whether the node refers to an aten conv op.
+    Return whether the node refers to an aten conv or conv transpose op.
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
     ]

-# TODO: rename this to _is_conv_transpose_node
-def _is_conv_transpose(n: Node):
+def _is_conv_transpose_node(n: Node):
     """
     Return whether the node refers to an aten conv_transpose op.
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv_transpose1d,
-        torch.ops.aten.conv_transpose2d,
+        torch.ops.aten.conv_transpose2d.input,
     ]

 def _is_bn_node(n: Node):
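
Note the overload change on the last target: in pre-autograd export, nn.ConvTranspose2d lowers to the aten.conv_transpose2d.input overload rather than the bare op packet. A quick sketch to confirm this (capture API as available at the time of this commit; purely illustrative):

import torch
from torch._export import capture_pre_autograd_graph

m = capture_pre_autograd_graph(
    torch.nn.ConvTranspose2d(3, 3, 3), (torch.randn(1, 3, 5, 5),)
)
# expect torch.ops.aten.conv_transpose2d.input among the call_function targets
print([n.target for n in m.graph.nodes if n.op == "call_function"])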
@@ -199,7 +199,7 @@ def fold_bn_weights_into_conv_node(
     # conv args: input, weight, bias, stride, padding, dilation, ...
     conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
     conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
-    transpose = _is_conv_transpose(conv_node)
+    transpose = _is_conv_transpose_node(conv_node)

     # eval bn args: input, weight, bias, running mean, running var, momentum, eps
     # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
@@ -270,7 +270,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
             continue
         bn_node = n
         n = bn_node.args[0]
-        if not _is_conv(n):
+        if not (_is_conv_node(n) or _is_conv_transpose_node(n)):
             continue
         conv_node = n
         conv_weight_node = conv_node.args[1]
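
The `transpose` flag matters because ConvTranspose weights are laid out as (in_channels, out_channels/groups, kH, kW), so the per-channel batchnorm scale must be applied along dim 1 instead of dim 0. A minimal sketch of the arithmetic (mirroring what torch.nn.utils.fusion.fuse_conv_bn_weights computes; the helper name here is hypothetical):

import torch

def fold_bn(conv_w, conv_b, bn_rm, bn_rv, bn_w, bn_b, eps, transpose):
    # per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn_w * torch.rsqrt(bn_rv + eps)
    # broadcast the scale over the weight's output-channel dimension:
    # dim 1 for conv_transpose weights, dim 0 for regular conv weights
    shape = [1] * conv_w.dim()
    shape[1 if transpose else 0] = -1
    fused_w = conv_w * scale.reshape(shape)
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    fused_b = (conv_b - bn_rm) * scale + bn_b
    return fused_w, fused_b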

torch/ao/quantization/quantizer/xnnpack_quantizer_utils.py

@@ -329,12 +329,12 @@ def _annotate_conv(
     return annotated_partitions


-@register_annotator("conv_relu")
-def _annotate_conv_relu(
+def _do_annotate_conv_relu(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
-) -> Optional[List[List[Node]]]:
+    is_conv_transpose: bool = False,
+):
     annotated_partitions = []
     for n in gm.graph.nodes:
         if n.op != "call_function" or n.target not in [
@@ -344,14 +344,21 @@ def _annotate_conv_relu(
             continue
         relu_node = n
         maybe_conv_node = n.args[0]
-        if (
-            not isinstance(maybe_conv_node, Node)
-            or maybe_conv_node.op != "call_function"
-            or maybe_conv_node.target
-            not in [
+
+        # TODO: refactor with is_conv_node and is_conv_transpose_node
+        if is_conv_transpose:
+            conv_ops = [
+                torch.ops.aten.conv_transpose1d,
+                torch.ops.aten.conv_transpose2d.input,
+            ]
+        else:
+            conv_ops = [
                 torch.ops.aten.conv1d.default,
                 torch.ops.aten.conv2d.default,
             ]
+        if (
+            not isinstance(maybe_conv_node, Node)
+            or maybe_conv_node.op != "call_function"
+            or maybe_conv_node.target not in conv_ops
         ):
             continue
         conv_node = maybe_conv_node
@@ -390,6 +397,28 @@ def _annotate_conv_relu(
     return annotated_partitions


+@register_annotator("conv_relu")
+def _annotate_conv_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    return _do_annotate_conv_relu(
+        gm, quantization_config, filter_fn, is_conv_transpose=False
+    )
+
+
+@register_annotator("conv_transpose_relu")
+def _annotate_conv_transpose_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    return _do_annotate_conv_relu(
+        gm, quantization_config, filter_fn, is_conv_transpose=True
+    )
+
+
 @register_annotator("conv_bn")
 def _annotate_conv_bn(
     gm: torch.fx.GraphModule,
@@ -416,11 +445,42 @@ def _annotate_conv_bn_relu(
     return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)


+@register_annotator("conv_transpose_bn")
+def _annotate_conv_transpose_bn(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    """
+    Find conv_transpose + batchnorm partitions
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(
+        gm, quantization_config, filter_fn, has_relu=False, is_conv_transpose=True
+    )
+
+
+@register_annotator("conv_transpose_bn_relu")
+def _annotate_conv_transpose_bn_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    """
+    Find conv_transpose + batchnorm + relu partitions
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(
+        gm, quantization_config, filter_fn, has_relu=True, is_conv_transpose=True
+    )
+
+
 def _do_annotate_conv_bn(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]],
     has_relu: bool,
+    is_conv_transpose: bool = False,
 ) -> List[List[Node]]:
     """
     Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern,
@@ -454,22 +514,28 @@ def _do_annotate_conv_bn(
     gm.recompile()

     matches = []
-    combinations = [
-        (F.conv1d, _conv1d_bn_example_inputs),
-        (F.conv2d, _conv2d_bn_example_inputs),
-    ]
+    if is_conv_transpose:
+        combinations = [
+            (F.conv_transpose1d, _conv1d_bn_example_inputs),
+            (F.conv_transpose2d, _conv2d_bn_example_inputs),
+        ]
+    else:
+        combinations = [
+            (F.conv1d, _conv1d_bn_example_inputs),  # type: ignore[list-item]
+            (F.conv2d, _conv2d_bn_example_inputs),  # type: ignore[list-item]
+        ]

     # Add `is_cuda` and `relu_is_inplace` dimensions
-    combinations = itertools.product(
+    combinations = itertools.product(  # type: ignore[assignment]
         combinations,
         [True, False] if torch.cuda.is_available() else [False],  # is_cuda
         [True, False] if has_relu else [False],  # relu_is_inplace
     )

     # Match against all conv dimensions and cuda variants
-    for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
-        pattern = get_pattern(conv_fn, relu_is_inplace)
-        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
+    for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:  # type: ignore[misc]
+        pattern = get_pattern(conv_fn, relu_is_inplace)  # type: ignore[has-type]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)  # type: ignore[has-type]
         pattern.graph.eliminate_dead_code()
         pattern.recompile()
         matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
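
A backend quantizer can opt into the new patterns through OP_TO_ANNOTATOR, which maps the names registered via register_annotator above. A sketch mirroring the test in this PR (the annotate function shown is illustrative):

from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR

def annotate(model, quantization_config, is_qat=False):
    if is_qat:
        # QAT: bn is still in the graph, so match conv_transpose + bn [+ relu]
        OP_TO_ANNOTATOR["conv_transpose_bn_relu"](model, quantization_config)
        OP_TO_ANNOTATOR["conv_transpose_bn"](model, quantization_config)
    else:
        # PTQ: bn has already been folded, so matching conv_transpose + relu suffices
        OP_TO_ANNOTATOR["conv_transpose_relu"](model, quantization_config)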

torch/testing/_internal/common_quantization.py

@@ -1175,6 +1175,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
         fx_qconfig_mapping=None,
         export_with_dynamic_shape=False,
         is_qat=False,
+        is_debug_mode=False,
     ):
         # resetting dynamo cache
         torch._dynamo.reset()
@@ -1199,6 +1200,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
         # Calibrate
         m(*example_inputs)
         m = convert_pt2e(m)
+        if is_debug_mode:
+            print("quantized model", m)
         pt2_quant_output = m(*example_inputs)
         ns = NodeSpec
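
The new flag is threaded through _test_quantizer, so a test can dump the converted model by passing it along, e.g. (sketch; the arguments other than is_debug_mode are as in the test above):

self._test_quantizer(
    TestHelperModules.ConvTWithBNRelu(relu=True, bn=True),
    example_inputs,
    BackendAQuantizer(),
    node_occurrence,
    node_list,
    is_debug_mode=True,  # prints the converted model before node checks
)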
@@ -2708,6 +2711,27 @@ class TestHelperModules:
             x = self.bn(x)
             return self.relu(x)

+    class ConvTWithBNRelu(torch.nn.Module):
+        def __init__(self, relu, dim=2, bn=True, bias=True):
+            super().__init__()
+            convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d}
+            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
+            self.convt = convts[dim](3, 3, 3, bias=bias)
+            if bn:
+                self.bn = bns[dim](3)
+            else:
+                self.bn = torch.nn.Identity()
+            if relu:
+                self.relu = torch.nn.ReLU()
+            else:
+                self.relu = torch.nn.Identity()
+
+        def forward(self, x):
+            x = self.convt(x)
+            x = self.bn(x)
+            return self.relu(x)
+
     class Conv2dThenConv1d(torch.nn.Module):
         def __init__(self):
             super().__init__()
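
A quick sanity check of the new helper module outside the test harness (a sketch; the shape math assumes ConvTranspose2d(3, 3, 3) with default stride and padding):

import torch
from torch.testing._internal.common_quantization import TestHelperModules

m = TestHelperModules.ConvTWithBNRelu(relu=True, dim=2, bn=True).eval()
x = torch.randn(1, 3, 5, 5)
# a kernel-3, stride-1 transposed conv upsamples 5 -> (5 - 1) + 3 = 7
assert m(x).shape == (1, 3, 7, 7)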