[quant][pt2e] Add support for conv transpose + bn + {relu} weights fusion in PTQ (#122046)
Summary: Adds support for conv transpose + bn + {relu} weights fusion in PTQ. Also adds some utils in xnnpack_quantizer_utils.py:
* annotate_conv_transpose_bn_relu and annotate_conv_transpose_bn -- these are for QAT
* annotate_conv_transpose_relu

The conv_transpose + bn weights fusion is performed automatically and cannot currently be disabled; we can add support for disabling this fusion later if needed.

Test Plan: python test/test_quantization.py -k test_conv_transpose_bn_fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122046
Approved by: https://github.com/andrewor14
committed by PyTorch MergeBot
parent bc1fef113d
commit 901ba2be86
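For orientation, below is a minimal sketch of the end-to-end PTQ flow this change targets. The module, example inputs, and the XNNPACKQuantizer configuration are illustrative assumptions, not part of this diff; the pt2e entry points are used as they existed around this commit.

# Hedged sketch: PTQ on a ConvTranspose2d + BN + ReLU module.
# Only the fusion behavior (conv_transpose + bn folded automatically during
# prepare) comes from this PR; everything else here is an assumed setup.
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class ConvTBNRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.convt = torch.nn.ConvTranspose2d(3, 3, 3)
        self.bn = torch.nn.BatchNorm2d(3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.convt(x)))

example_inputs = (torch.randn(1, 3, 5, 5),)
m = capture_pre_autograd_graph(ConvTBNRelu().eval(), example_inputs)
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)  # conv_transpose + bn weights are folded here
m(*example_inputs)  # calibrate
m = convert_pt2e(m)  # reference quantized model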
@@ -1988,3 +1988,49 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().eval()
         example_inputs = (torch.randn(5, 5),)
         _ = dynamic_quantize_pt2e(m, example_inputs)
+
+    def test_conv_transpose_bn_relu(self):
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                int8_qspec = QuantizationSpec(
+                    dtype=torch.int8,
+                    quant_min=-128,
+                    quant_max=127,
+                    qscheme=torch.per_tensor_symmetric,
+                    is_dynamic=False,
+                    observer_or_fake_quant_ctr=observer.default_weight_observer,
+                )
+                quantization_config = QuantizationConfig(
+                    input_activation=int8_qspec,
+                    weight=int8_qspec,
+                    bias=None,
+                    output_activation=int8_qspec,
+                )
+                # conv_transpose + bn is fused automatically in PTQ (not configurable),
+                # so we just need to annotate conv_transpose + relu for the
+                # conv_transpose + bn + relu pattern
+                OP_TO_ANNOTATOR["conv_transpose_relu"](model, quantization_config)
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        example_inputs = (torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            # two for the input of the first conv, one for the output of the first conv
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 2,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv_transpose2d.input,
+            torch.ops.aten.relu.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        ]
+        self._test_quantizer(
+            TestHelperModules.ConvTWithBNRelu(relu=True, bn=True),
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
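A note on the expected counts above: in the reference-quantized graph each activation is materialized as a quantize/dequantize pair, while the (bn-folded) weight becomes a quantized constant that only needs a dequantize. A sketch of the pattern the test checks for; this is a reading of the assertions above, not additional test code:

# x -> quantize_per_tensor -> dequantize_per_tensor -----\
# w (bn-folded int8 constant) -> dequantize_per_tensor ---> conv_transpose2d.input -> relu
#                                -> quantize_per_tensor -> dequantize_per_tensor -> out
#
# Totals: quantize_per_tensor = 2 (input, output),
#         dequantize_per_tensor = 3 (input, weight, output),
# matching `node_occurrence` in the test.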
@@ -21,7 +21,7 @@ from torch.ao.quantization.quantizer import (
 from .utils import (
     _conv1d_bn_example_inputs,
     _conv2d_bn_example_inputs,
-    _is_conv,
+    _is_conv_node,
     _is_bn_node,
     fold_bn_weights_into_conv_node,
     _get_aten_graph_module_for_pattern,
@@ -271,7 +271,7 @@ def _has_conv_bias_filter(
     the original graph has bias.
     """
     for n in match.nodes_map.values():
-        if _is_conv(n):
+        if _is_conv_node(n):
             return len(n.args) > 2 and n.args[2] is not None
     raise ValueError("Could not find conv node in matched conv + bn pattern")
 
@@ -325,7 +325,7 @@ def _get_conv_bn_pattern_nodes(r: ReplacedPatterns) -> Dict[str, Tuple[Node, Nod
     for n in nodes:
         if n.op != "call_function":
             continue
-        if _is_conv(n):
+        if _is_conv_node(n):
             assert conv_node is None
             conv_node = n
         if _is_bn_node(n):
@@ -440,8 +440,8 @@ def _copy_over_literal_conv_args(original_node: Node, new_node: Node):
     Note: Unlike other tensor args like conv weights and biases, literal args are
     preserved in the original nodes after replacement, so we can access them here.
     """
-    assert _is_conv(original_node)
-    assert _is_conv(new_node)
+    assert _is_conv_node(original_node)
+    assert _is_conv_node(new_node)
     # x, weight, bias, [stride, padding, dilation, transposed, output_padding, groups]
     new_args = list(new_node.args)
     if len(new_args) < 3:
@@ -457,8 +457,8 @@ def _update_conv_input_qspec_map_after_replacement(original_node: Node, replacem
     so the keys in the `input_qspec_map` will need to be updated to reflect
     the corresponding nodes in the replacement graph.
     """
-    assert _is_conv(original_node)
-    assert _is_conv(replacement_node)
+    assert _is_conv_node(original_node)
+    assert _is_conv_node(replacement_node)
     if "quantization_annotation" not in original_node.meta:
         return
     original_input_qspec_map = original_node.meta["quantization_annotation"].input_qspec_map
@@ -609,7 +609,7 @@ def _fuse_conv_bn_qat_helper(
     for original_node, replacement_node in _get_conv_bn_pattern_nodes(r).values():
         # Step (3a): Copy over metadata for all nodes in [conv - bn - getitem]
         replacement_node.meta = original_node.meta
-        if _is_conv(original_node):
+        if _is_conv_node(original_node):
             # Step (3b): Copy over conv literal args
             _copy_over_literal_conv_args(original_node, replacement_node)
             # Step (3c): Update old references in the conv node's input_qspec_map
@@ -780,7 +780,7 @@ def _fold_conv_bn_qat_helper(
 
         # Copy over literal args for conv
         for original_node in _filter_nodes_map(r.nodes_map).values():
-            if _is_conv(original_node):
+            if _is_conv_node(original_node):
                 _copy_over_literal_conv_args(original_node, conv_node)
 
         m.graph.eliminate_dead_code()
@@ -166,24 +166,24 @@ def _is_supported_batch_norm_for_training(node: Node):
     ]
     return node.target in supported_ops
 
-# TODO: rename this to _is_conv_node
-def _is_conv(n: Node):
+# TODO: move this to torch/ao/quantization/utils.py
+def _is_conv_node(n: Node):
     """
-    Return whether the node refers to an aten conv op.
+    Return whether the node refers to an aten conv or conv transpose op.
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
     ]
 
-# TODO: rename this to _is_conv_transpose_node
-def _is_conv_transpose(n: Node):
+def _is_conv_transpose_node(n: Node):
     """
     Return whether the node refers to an aten conv_transpose op.
     """
     return n.op == "call_function" and n.target in [
         torch.ops.aten.conv_transpose1d,
         torch.ops.aten.conv_transpose2d,
+        torch.ops.aten.conv_transpose2d.input,
     ]
 
 def _is_bn_node(n: Node):
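For context, a quick sketch of these predicates in use on an exported graph; the capture call and the one-op module are illustrative assumptions:

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.pt2e.utils import _is_conv_node, _is_conv_transpose_node

m = capture_pre_autograd_graph(
    torch.nn.ConvTranspose2d(3, 3, 3).eval(), (torch.randn(1, 3, 5, 5),)
)
# ConvTranspose2d lowers to aten.conv_transpose2d.input on export, so the
# second predicate should match exactly one node and the first should match none
convs = [n for n in m.graph.nodes if _is_conv_node(n)]
convts = [n for n in m.graph.nodes if _is_conv_transpose_node(n)]
print(len(convs), len(convts))  # expect: 0 1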
@@ -199,7 +199,7 @@ def fold_bn_weights_into_conv_node(
     # conv args: input, weight, bias, stride, padding, dilation, ...
     conv_w = _get_tensor_constant_from_node(conv_weight_node, m)
     conv_b = _get_tensor_constant_from_node(conv_bias_node, m)
-    transpose = _is_conv_transpose(conv_node)
+    transpose = _is_conv_transpose_node(conv_node)
 
     # eval bn args: input, weight, bias, running mean, running var, momentum, eps
     # train bn args: input, weight, bias, running mean, running var, training, momentum, eps
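The `transpose` flag matters because of weight layout: a conv weight is (out_channels, in_channels, *kernel) while a conv-transpose weight is (in_channels, out_channels, *kernel), so the per-output-channel BN scale must be applied along dim 1 instead of dim 0. A standalone sketch of the folding math for the transpose case; the helper name is illustrative, and the real code delegates to fuse_conv_bn_weights(..., transpose=...):

import torch

def fold_bn_into_convt(w, b, rm, rv, eps, gamma, beta):
    # y = gamma * (convt(x; w, b) - rm) / sqrt(rv + eps) + beta
    # => w' = w * scale along the output-channel dim (dim 1 for conv_transpose)
    #    b' = (b - rm) * scale + beta
    scale = gamma / torch.sqrt(rv + eps)
    w_folded = w * scale.reshape(1, -1, 1, 1)
    b_folded = (b - rm) * scale + beta
    return w_folded, b_folded

# Quick numeric check against the eager modules:
convt = torch.nn.ConvTranspose2d(3, 3, 3)
bn = torch.nn.BatchNorm2d(3).eval()
x = torch.randn(1, 3, 5, 5)
w2, b2 = fold_bn_into_convt(
    convt.weight, convt.bias, bn.running_mean, bn.running_var,
    bn.eps, bn.weight, bn.bias,
)
ref = bn(convt(x))
out = torch.nn.functional.conv_transpose2d(x, w2, b2)
print(torch.allclose(ref, out, atol=1e-5))  # expect: True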
@@ -270,7 +270,7 @@ def _fuse_conv_bn_(m: GraphModule) -> None:
             continue
         bn_node = n
         n = bn_node.args[0]
-        if not _is_conv(n):
+        if not (_is_conv_node(n) or _is_conv_transpose_node(n)):
             continue
         conv_node = n
         conv_weight_node = conv_node.args[1]
@@ -329,12 +329,12 @@ def _annotate_conv(
     return annotated_partitions
 
 
-@register_annotator("conv_relu")
-def _annotate_conv_relu(
+def _do_annotate_conv_relu(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]] = None,
-) -> Optional[List[List[Node]]]:
+    is_conv_transpose: bool = False,
+):
     annotated_partitions = []
     for n in gm.graph.nodes:
         if n.op != "call_function" or n.target not in [
@@ -344,14 +344,21 @@ def _annotate_conv_relu(
             continue
         relu_node = n
         maybe_conv_node = n.args[0]
+        # TODO: refactor with is_conv_node and is_conv_transpose_node
+        if is_conv_transpose:
+            conv_ops = [
+                torch.ops.aten.conv_transpose1d,
+                torch.ops.aten.conv_transpose2d.input,
+            ]
+        else:
+            conv_ops = [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+            ]
         if (
             not isinstance(maybe_conv_node, Node)
             or maybe_conv_node.op != "call_function"
-            or maybe_conv_node.target
-            not in [
-                torch.ops.aten.conv1d.default,
-                torch.ops.aten.conv2d.default,
-            ]
+            or maybe_conv_node.target not in conv_ops
         ):
             continue
         conv_node = maybe_conv_node
@@ -390,6 +397,28 @@ def _annotate_conv_relu(
     return annotated_partitions
 
 
+@register_annotator("conv_relu")
+def _annotate_conv_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    return _do_annotate_conv_relu(
+        gm, quantization_config, filter_fn, is_conv_transpose=False
+    )
+
+
+@register_annotator("conv_transpose_relu")
+def _annotate_conv_transpose_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    return _do_annotate_conv_relu(
+        gm, quantization_config, filter_fn, is_conv_transpose=True
+    )
+
+
 @register_annotator("conv_bn")
 def _annotate_conv_bn(
     gm: torch.fx.GraphModule,
@@ -416,11 +445,42 @@ def _annotate_conv_bn_relu(
     return _do_annotate_conv_bn(gm, quantization_config, filter_fn, has_relu=True)
 
 
+@register_annotator("conv_transpose_bn")
+def _annotate_conv_transpose_bn(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    """
+    Find conv_transpose + batchnorm partitions
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(
+        gm, quantization_config, filter_fn, has_relu=False, is_conv_transpose=True
+    )
+
+
+@register_annotator("conv_transpose_bn_relu")
+def _annotate_conv_transpose_bn_relu(
+    gm: torch.fx.GraphModule,
+    quantization_config: Optional[QuantizationConfig],
+    filter_fn: Optional[Callable[[Node], bool]] = None,
+) -> Optional[List[List[Node]]]:
+    """
+    Find conv_transpose + batchnorm + relu partitions
+    Note: This is only used for QAT. In PTQ, batchnorm should already be fused into the conv.
+    """
+    return _do_annotate_conv_bn(
+        gm, quantization_config, filter_fn, has_relu=True, is_conv_transpose=True
+    )
+
+
 def _do_annotate_conv_bn(
     gm: torch.fx.GraphModule,
     quantization_config: Optional[QuantizationConfig],
     filter_fn: Optional[Callable[[Node], bool]],
     has_relu: bool,
+    is_conv_transpose: bool = False,
 ) -> List[List[Node]]:
     """
     Given a function that takes in a `conv_fn` and returns a conv-bn[-relu] pattern,
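The two annotators above target the QAT flow, where batchnorm is still present in the captured graph. A hedged sketch of how a backend quantizer might invoke them; the quantizer class is hypothetical, while the OP_TO_ANNOTATOR keys come from this diff:

import torch
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    OP_TO_ANNOTATOR,
    QuantizationConfig,
)

class BackendAQATQuantizer(Quantizer):  # hypothetical backend quantizer
    def __init__(self, quantization_config: QuantizationConfig):
        super().__init__()
        self.config = quantization_config

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # In QAT the bn node is still in the graph, so annotate the whole
        # conv_transpose + bn + relu pattern; prepare_qat_pt2e then swaps in
        # the fused QAT pattern based on these annotations.
        OP_TO_ANNOTATOR["conv_transpose_bn_relu"](model, self.config)
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass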
@@ -454,22 +514,28 @@ def _do_annotate_conv_bn(
     gm.recompile()
 
     matches = []
-    combinations = [
-        (F.conv1d, _conv1d_bn_example_inputs),
-        (F.conv2d, _conv2d_bn_example_inputs),
-    ]
+    if is_conv_transpose:
+        combinations = [
+            (F.conv_transpose1d, _conv1d_bn_example_inputs),
+            (F.conv_transpose2d, _conv2d_bn_example_inputs),
+        ]
+    else:
+        combinations = [
+            (F.conv1d, _conv1d_bn_example_inputs),  # type: ignore[list-item]
+            (F.conv2d, _conv2d_bn_example_inputs),  # type: ignore[list-item]
+        ]
 
     # Add `is_cuda` and `relu_is_inplace` dimensions
-    combinations = itertools.product(
+    combinations = itertools.product(  # type: ignore[assignment]
         combinations,
         [True, False] if torch.cuda.is_available() else [False],  # is_cuda
         [True, False] if has_relu else [False],  # relu_is_inplace
     )
 
     # Match against all conv dimensions and cuda variants
-    for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:
-        pattern = get_pattern(conv_fn, relu_is_inplace)
-        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)
+    for (conv_fn, example_inputs), is_cuda, relu_is_inplace in combinations:  # type: ignore[misc]
+        pattern = get_pattern(conv_fn, relu_is_inplace)  # type: ignore[has-type]
+        pattern = _get_aten_graph_module_for_pattern(pattern, example_inputs, is_cuda)  # type: ignore[has-type]
         pattern.graph.eliminate_dead_code()
         pattern.recompile()
         matcher = SubgraphMatcherWithNameNodeMap(pattern, ignore_literals=True)
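The itertools.product call above just enumerates every (pattern variant, device, relu-inplace) combination to match against. A tiny standalone illustration with placeholder values:

import itertools

combinations = [("conv_transpose1d", "ex_1d"), ("conv_transpose2d", "ex_2d")]
for (conv_fn, example_inputs), is_cuda, relu_is_inplace in itertools.product(
    combinations,
    [False],  # is_cuda: becomes [True, False] when CUDA is available
    [True, False],  # relu_is_inplace: only varied when has_relu
):
    print(conv_fn, is_cuda, relu_is_inplace)
# conv_transpose1d False True
# conv_transpose1d False False
# conv_transpose2d False True
# conv_transpose2d False False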
@@ -1175,6 +1175,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
         fx_qconfig_mapping=None,
         export_with_dynamic_shape=False,
         is_qat=False,
+        is_debug_mode=False,
     ):
         # resetting dynamo cache
         torch._dynamo.reset()
@@ -1199,6 +1200,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
         # Calibrate
         m(*example_inputs)
         m = convert_pt2e(m)
+        if is_debug_mode:
+            print("quantized model", m)
 
         pt2_quant_output = m(*example_inputs)
         ns = NodeSpec
@@ -2708,6 +2711,27 @@ class TestHelperModules:
             x = self.bn(x)
             return self.relu(x)
 
+    class ConvTWithBNRelu(torch.nn.Module):
+        def __init__(self, relu, dim=2, bn=True, bias=True):
+            super().__init__()
+            convts = {1: torch.nn.ConvTranspose1d, 2: torch.nn.ConvTranspose2d}
+            bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d}
+            self.convt = convts[dim](3, 3, 3, bias=bias)
+
+            if bn:
+                self.bn = bns[dim](3)
+            else:
+                self.bn = torch.nn.Identity()
+            if relu:
+                self.relu = torch.nn.ReLU()
+            else:
+                self.relu = torch.nn.Identity()
+
+        def forward(self, x):
+            x = self.convt(x)
+            x = self.bn(x)
+            return self.relu(x)
+
     class Conv2dThenConv1d(torch.nn.Module):
         def __init__(self):
             super().__init__()