From 1b51d29b6630e2321804ab103ddec8bf29e475af Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 26 Sep 2023 19:09:30 -0700 Subject: [PATCH] [quant][pt2e] Enable constant folding for quantize ops (#109343) Summary: This PR added constant folding for quantize ops so that instead of storing fp32 weight in the quantized model, we'll get int8/int16 etc. weight Test Plan: python test/test_quantization.py TestQuantizePT2E.test_fold_quantize also will verify in executorch later Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D49399210](https://our.internmc.facebook.com/intern/diff/D49399210) Pull Request resolved: https://github.com/pytorch/pytorch/pull/109343 Approved by: https://github.com/kimishpatel, https://github.com/jgong5 --- test/quantization/pt2e/test_duplicate_dq.py | 27 +- .../pt2e/test_metadata_porting.py | 50 ++- test/quantization/pt2e/test_quantize_pt2e.py | 285 ++++++++++++------ test/quantization/pt2e/test_representation.py | 13 +- .../pt2e/test_x86inductor_quantizer.py | 54 ++-- test/test_quantization.py | 2 + torch/_inductor/constant_folding.py | 19 +- .../quantization/pt2e/port_metadata_pass.py | 13 + torch/ao/quantization/pt2e/utils.py | 2 +- torch/ao/quantization/quantize_pt2e.py | 37 +++ .../quantizer/xnnpack_quantizer.py | 5 +- 11 files changed, 359 insertions(+), 148 deletions(-) diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py index bf6be1c9ace8..d51a9b4e1e91 100644 --- a/test/quantization/pt2e/test_duplicate_dq.py +++ b/test/quantization/pt2e/test_duplicate_dq.py @@ -1,5 +1,6 @@ # Owner(s): ["oncall: quantization"] import copy +import unittest import torch import torch._export as export @@ -9,7 +10,6 @@ from torch.ao.quantization.observer import ( MinMaxObserver, PlaceholderObserver, ) -from torch.ao.quantization.pt2e.utils import _find_q_dq_node_for_user from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.ao.quantization.quantizer import ( QuantizationAnnotation, @@ -26,6 +26,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import ( ) from torch.testing._internal.common_quantization import QuantizationTestCase +from torch.testing._internal.common_utils import IS_WINDOWS class TestHelperModules: @@ -81,6 +82,14 @@ class TestHelperModules: return w, add_output, extra_output +_DEQUANTIZE_OPS = [ + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, +] + + +@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") class TestDuplicateDQPass(QuantizationTestCase): def _test_duplicate_dq( self, @@ -106,19 +115,9 @@ class TestDuplicateDQPass(QuantizationTestCase): for n in m.graph.nodes: annotation = n.meta.get("quantization_annotation", None) if annotation is not None: - input_qspec_map = annotation.input_qspec_map - for input_node, qspec in input_qspec_map.items(): - if ( - qspec is not None - and hasattr(qspec, "dtype") - and qspec.dtype != torch.float - ): - q_node, dq_node = _find_q_dq_node_for_user(input_node, n) - if dq_node is None: - raise ValueError( - f"No dq node found for {n}, even though {n} annotated for quantization." 
- ) - self.assertEqual(len(dq_node.users.keys()), 1) + for arg in n.args: + if isinstance(arg, torch.fx.Node) and arg.target in _DEQUANTIZE_OPS: + self.assertEqual(len(arg.users.keys()), 1) def test_no_need_for_duplicate_dq(self): """ diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py index 900e60733f43..c3641047626e 100644 --- a/test/quantization/pt2e/test_metadata_porting.py +++ b/test/quantization/pt2e/test_metadata_porting.py @@ -16,6 +16,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTA from torch.fx import Node from torch.testing._internal.common_quantization import QuantizationTestCase +from torch.testing._internal.common_utils import IS_WINDOWS class TestHelperModules: @@ -57,6 +58,8 @@ _QUANT_OPS = { } +# TODO: rename to TestPortMetadataPass to align with the util name? +@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") class TestMetaDataPorting(QuantizationTestCase): def _test_metadata_porting( self, @@ -77,23 +80,33 @@ class TestMetaDataPorting(QuantizationTestCase): m = prepare_pt2e(m, quantizer) # Calibrate m(*example_inputs) - m = convert_pt2e(m) + m = convert_pt2e(m, fold_quantize=True) pt2_quant_output = m(*example_inputs) recorded_node_tags = {} for n in m.graph.nodes: + if "quantization_tag" not in n.meta: + continue + if n.op == "call_function" and n.target in _QUANT_OPS: + key = n.target + elif n.op == "get_attr": + key = "get_attr" + else: + continue + + if key not in recorded_node_tags: + recorded_node_tags[key] = set() + if ( n.op == "call_function" - and n.target in _QUANT_OPS - and "quantization_tag" in n.meta + and n.meta["quantization_tag"] in recorded_node_tags[key] ): - if n.target not in recorded_node_tags: - recorded_node_tags[n.target] = set() - if n.meta["quantization_tag"] in recorded_node_tags[n.target]: - raise ValueError( - f"{n} has tag {n.meta['quantization_tag']} that is associated with another node of the same type" - ) - recorded_node_tags[n.target].add(n.meta["quantization_tag"]) + raise ValueError( + f"{key} {n.format_node()} has tag {n.meta['quantization_tag']} that " + "is associated with another node of the same type" + ) + recorded_node_tags[key].add(n.meta["quantization_tag"]) + self.assertEqual(set(recorded_node_tags.keys()), set(node_tags.keys())) for k, v in recorded_node_tags.items(): self.assertEqual(v, node_tags[k]) @@ -130,6 +143,10 @@ class TestMetaDataPorting(QuantizationTestCase): pass example_inputs = (torch.randn(1, 3, 5, 5),) + get_attr_tags = { + "BackendA_conv2d_0", + "BackendA_linear_0", + } quantize_per_tensor_tags = { "BackendA_conv2d_0", "BackendA_adaptive_avg_pool2d_0", @@ -142,6 +159,7 @@ class TestMetaDataPorting(QuantizationTestCase): } dequantize_per_channel_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} node_tags = { + "get_attr": get_attr_tags, torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags, torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags, torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, @@ -180,10 +198,12 @@ class TestMetaDataPorting(QuantizationTestCase): pass example_inputs = (torch.randn(1, 3, 5, 5),) + get_attr_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} quantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} dequantize_per_tensor_tags = {"BackendA_conv2d_0", "BackendA_linear_0"} dequantize_per_channel_tags = {"BackendA_conv2d_0", 
"BackendA_linear_0"} node_tags = { + "get_attr": get_attr_tags, torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags, torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags, torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, @@ -236,6 +256,8 @@ class TestMetaDataPorting(QuantizationTestCase): pass example_inputs = (torch.randn(1, 3, 5, 5),) + # TODO: add get_attr_tags when the test is re-enabled + get_attr_tags = {} quantize_per_tensor_tags = { "BackendA_conv2d_0", "BackendA_adaptive_avg_pool2d_0", @@ -252,6 +274,7 @@ class TestMetaDataPorting(QuantizationTestCase): "BackendA_linear_dynamic_0", } node_tags = { + "get_attr": get_attr_tags, torch.ops.quantized_decomposed.quantize_per_tensor.default: quantize_per_tensor_tags, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags, torch.ops.quantized_decomposed.dequantize_per_tensor.default: dequantize_per_tensor_tags, @@ -295,6 +318,10 @@ class TestMetaDataPorting(QuantizationTestCase): pass example_inputs = (torch.randn(1, 3, 5, 5),) + get_attr_tags = { + "BackendA_conv2d_dynamic_0", + "BackendA_linear_dynamic_0", + } choose_qparams_tensor_tags = { "BackendA_conv2d_dynamic_0", "BackendA_linear_dynamic_0", @@ -312,6 +339,7 @@ class TestMetaDataPorting(QuantizationTestCase): "BackendA_linear_dynamic_0", } node_tags = { + "get_attr": get_attr_tags, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags, torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, @@ -349,11 +377,13 @@ class TestMetaDataPorting(QuantizationTestCase): pass example_inputs = (torch.randn(1, 3, 5, 5),) + get_attr_tags = {"BackendA_linear_dynamic_0"} choose_qparams_tensor_tags = {"BackendA_linear_dynamic_0"} quantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} dequantize_per_tensor_tensor_tags = {"BackendA_linear_dynamic_0"} dequantize_per_channel_tags = {"BackendA_linear_dynamic_0"} node_tags = { + "get_attr": get_attr_tags, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: quantize_per_tensor_tensor_tags, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: dequantize_per_tensor_tensor_tags, torch.ops.quantized_decomposed.dequantize_per_channel.default: dequantize_per_channel_tags, diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py index 59baeaf162cb..d526ea3a750d 100644 --- a/test/quantization/pt2e/test_quantize_pt2e.py +++ b/test/quantization/pt2e/test_quantize_pt2e.py @@ -91,6 +91,31 @@ class PT2EQuantizationTestCase(QuantizationTestCase): torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, } + def _quantize(self, m, quantizer, example_inputs): + m = capture_pre_autograd_graph( + m, + example_inputs, + ) + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m, fold_quantize=True) + return m + + def _get_pt2e_quantized_linear(self, is_per_channel=False) -> torch.fx.GraphModule: + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel) + quantizer.set_global(operator_config) + 
example_inputs = (torch.randn(2, 2),) + m = M().eval() + return self._quantize(m, quantizer, example_inputs) def _test_quantizer( self, @@ -118,7 +143,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase): m = prepare_pt2e(m, quantizer) # Calibrate m(*example_inputs) - m = convert_pt2e(m) + m = convert_pt2e(m, fold_quantize=True) pt2_quant_output = m(*example_inputs) node_occurrence = { @@ -209,7 +234,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): example_inputs = (torch.randn(1, 3, 5, 5),) node_occurrence = { # two for input of the first conv, one for output for the first conv - torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, } node_list = [ @@ -277,18 +302,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): m = torch.nn.Conv2d(2, 2, 1) x = torch.rand(1, 2, 14, 14) example_inputs = (x,) - # program capture - m = capture_pre_autograd_graph( - m, - example_inputs, - ) - m = prepare_pt2e(m, BackendAQuantizer()) - m(*example_inputs) - m = convert_pt2e(m) + m = self._quantize(m, BackendAQuantizer(), example_inputs) # Ensure the conv has no observer inserted at output node_occurrence = { # two for input of conv - ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 1, ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, } node_list = [ @@ -369,19 +387,12 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): m = TestHelperModules.ConvMaxPool2d() x = torch.rand(1, 2, 14, 14) example_inputs = (x,) - # program capture - m = capture_pre_autograd_graph( - m, - example_inputs, - ) - m = prepare_pt2e(m, BackendAQuantizer()) - m(*example_inputs) - m = convert_pt2e(m) + m = self._quantize(m, BackendAQuantizer(), example_inputs) node_occurrence = { # two for input of conv # one for input of maxpool # one for output of maxpool - ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 4, + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 3, ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 4, } node_list = [ @@ -469,19 +480,13 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval() example_inputs = (torch.randn(1, 3, 5, 5),) - # program capture - m = capture_pre_autograd_graph( - m, - example_inputs, - ) - m = prepare_pt2e(m, BackendAQuantizer()) - m(*example_inputs) - m = convert_pt2e(m) + m = self._quantize(m, BackendAQuantizer(), example_inputs) node_occurrence = { # input, weight, bias, output for the conv + # note: quantize op for weight and bias are const propagated ns.call_function( torch.ops.quantized_decomposed.quantize_per_tensor.default - ): 4, + ): 2, ns.call_function( torch.ops.quantized_decomposed.dequantize_per_tensor.default ): 4, @@ -575,15 +580,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): m = TestHelperModules.ConvWithBNRelu(relu=False, bn=False).eval() example_inputs = (torch.randn(1, 3, 5, 5),) - # program capture - m = capture_pre_autograd_graph( - m, - example_inputs, - ) - m = prepare_pt2e(m, BackendAQuantizer()) - m(*example_inputs) - m = convert_pt2e(m) - m(*example_inputs) + m = self._quantize(m, BackendAQuantizer(), example_inputs) node_occurrence = { # input, output for the conv @@ -594,9 +591,10 @@ class 
TestQuantizePT2E(PT2EQuantizationTestCase): torch.ops.quantized_decomposed.dequantize_per_tensor.default ): 2, # weight and bias for conv + # note: quantize op for weight and bias are const propagated ns.call_function( torch.ops.quantized_decomposed.quantize_per_channel.default - ): 2, + ): 0, ns.call_function( torch.ops.quantized_decomposed.dequantize_per_channel.default ): 2, @@ -653,14 +651,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): m = M().eval() example_inputs = (torch.randn(1, 3, 5, 5),) - # program capture - m = capture_pre_autograd_graph( - m, - example_inputs, - ) - m = prepare_pt2e(m, BackendAQuantizer()) - m(*example_inputs) - m = convert_pt2e(m) + m = self._quantize(m, BackendAQuantizer(), example_inputs) fixed_scale = 1.0 / 256.0 fixed_zero_point = 0 for n in m.graph.nodes: @@ -805,13 +796,13 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): assert conv_output_obs[0] == conv_output_obs[1] m(*example_inputs) - m = convert_pt2e(m) + m = convert_pt2e(m, fold_quantize=True) node_occurrence = { # two for input of the first conv, one for output for the first conv ns.call_function( torch.ops.quantized_decomposed.quantize_per_tensor.default - ): 7, + ): 5, ns.call_function( torch.ops.quantized_decomposed.dequantize_per_tensor.default ): 7, @@ -873,8 +864,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): quantizer = Int16ActQuantizer() node_occurrence = { - # two for input of the first conv, one for output for the first conv - torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, + # one for input of the first conv, one for output for the first conv + torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, } node_list = [ @@ -892,6 +883,127 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): node_list, ) + def test_fold_quantize(self): + """Test to make sure the quantized model gets quantized weight (quantize_per_tensor op is folded) + """ + m = self._get_pt2e_quantized_linear() + node_occurrence = { + # quantize op for weight node is folded + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_fold_quantize_per_channel(self): + """Test to make sure the quantized model gets quantized weight (quantize_per_channel op is folded) + """ + m = self._get_pt2e_quantized_linear(is_per_channel=True) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_channel.default): 1, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 2, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_dont_fold_other_constant(self): + """Make sure the constant propagation does not apply to things unrelated to + quantization + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + self.dont_fold_me = torch.nn.Parameter(torch.randn(2, 2)) + + def forward(self, x): + t = self.dont_fold_me.t() + return self.linear(x) + t + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + # only quantize linear, so add is not quantized and the constant Tensor + # should not be folded + 
quantizer.set_module_type(torch.nn.Linear, operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = self._quantize(m, quantizer, example_inputs) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3, + # transpose op not folded + ns.call_function(torch.ops.aten.t.default): 1, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_fold_all_ops_before_quantize(self): + """Test folding all ops that's before quantized operator: + Before: + get_attr(weight) -> transpose -> quantize -> dequantize + After: + get_attr(folded_weight) -> dequantize + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(2, 2) + + def forward(self, x): + t = self.weight.t() + return torch.nn.functional.linear(x, t) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config(is_per_channel=False) + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = self._quantize(m, quantizer, example_inputs) + node_occurrence = { + # quantize op for weight node is folded + ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor.default): 2, + ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor.default): 3, + } + self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + + def test_constant_prop_preserve_metadata(self): + """Test to make sure the get_attr node for const propagated weight Tensor gets the correct + metadata (from original get_attr node from weight) + """ + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(2, 2) + + def forward(self, x): + return self.linear(x) + + quantizer = XNNPACKQuantizer() + operator_config = get_symmetric_quantization_config() + quantizer.set_global(operator_config) + example_inputs = (torch.randn(2, 2),) + m = M().eval() + m = capture_pre_autograd_graph( + m, + example_inputs, + ) + weight_meta = None + for n in m.graph.nodes: + if n.op == "get_attr" and list(n.users)[0].target == torch.ops.aten.linear.default: + weight_meta = n.meta + break + assert weight_meta is not None, "Expect to find metadata for weight node" + + m = prepare_pt2e(m, quantizer) + m(*example_inputs) + m = convert_pt2e(m, fold_quantize=True) + + for n in m.graph.nodes: + if n.op == "get_attr" and "frozen_param" in n.target: + self.assertIn("stack_trace", n.meta) + for key in n.meta: + self.assertEqual(n.meta[key], weight_meta[key]) + def test_add_and_inplace_add(self): quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config(is_per_channel=True) @@ -922,27 +1034,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): def test_save_load(self): """Test save/load a quantized model """ - class M(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(2, 2) - - def forward(self, x): - return self.linear(x) - - quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=False) - quantizer.set_global(operator_config) + m = self._get_pt2e_quantized_linear() example_inputs = (torch.randn(2, 2),) - m = M().eval() - m = capture_pre_autograd_graph( - m, - example_inputs, - ) - m = prepare_pt2e(m, quantizer) - # Calibrate - m(*example_inputs) - m = convert_pt2e(m) ref_res = 
m(*example_inputs) with TemporaryFileName() as fname: @@ -991,7 +1084,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): # input and output are using quantize_per_tensor and weight is using quantize_per_channel torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -1021,7 +1115,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): # input and output are using quantize_per_tensor and weight is using quantize_per_channel torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.quantize_per_channel.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } qconfig = default_per_channel_symmetric_qnnpack_qconfig @@ -1045,7 +1140,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): # input and output are using quantize_per_tensor and weight is using quantize_per_channel torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, - torch.ops.quantized_decomposed.quantize_per_channel.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, } qconfig = default_per_channel_symmetric_qnnpack_qconfig @@ -1072,7 +1168,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, - torch.ops.quantized_decomposed.quantize_per_channel.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, } qconfig = default_per_channel_symmetric_qnnpack_qconfig @@ -1099,7 +1196,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): # input and output are using quantize_per_tensor and weight is using quantize_per_channel torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.quantize_per_channel.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } qconfig = default_per_channel_symmetric_qnnpack_qconfig @@ -1125,7 +1223,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): # input and output are using quantize_per_tensor and weight is using quantize_per_channel torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 5, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -1252,7 +1351,7 @@ class 
TestQuantizePT2E(PT2EQuantizationTestCase): self.assertEqual( id(m.activation_post_process_3), id(m.activation_post_process_4) ) - m = convert_pt2e(m) + m = convert_pt2e(m, fold_quantize=True) node_occurrence = { # input and output are using quantize_per_tensor and weight is using quantize_per_channel ns.call_function( @@ -1261,9 +1360,10 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): ns.call_function( torch.ops.quantized_decomposed.dequantize_per_tensor.default ): 5, + # note: quantize op for weights are const propagated ns.call_function( torch.ops.quantized_decomposed.quantize_per_channel.default - ): 2, + ): 0, ns.call_function( torch.ops.quantized_decomposed.dequantize_per_channel.default ): 2, @@ -1282,7 +1382,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): # input and output are using quantize_per_tensor and weight is using quantize_per_channel torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 2, - torch.ops.quantized_decomposed.quantize_per_channel.default: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } act_affine_quant_obs = observer.PlaceholderObserver.with_args( @@ -1324,7 +1425,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): # input and output are using quantize_per_tensor and weight is using quantize_per_channel torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, - torch.ops.quantized_decomposed.quantize_per_tensor.default: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1, } act_affine_quant_obs = observer.PlaceholderObserver.with_args( @@ -1372,9 +1474,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, - torch.ops.quantized_decomposed.quantize_per_tensor.default: 4, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } act_affine_quant_obs = observer.PlaceholderObserver.with_args( @@ -1473,11 +1577,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): quantizer = EmbeddingQuantizer() node_occurrence = { - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ - torch.ops.quantized_decomposed.quantize_per_channel.default, torch.ops.quantized_decomposed.dequantize_per_channel.default, torch.ops.aten.embedding.default, ] @@ -1577,7 +1681,8 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, torch.ops.quantized_decomposed.quantize_per_tensor.tensor: 1, torch.ops.quantized_decomposed.dequantize_per_tensor.tensor: 1, - 
torch.ops.quantized_decomposed.quantize_per_channel.default: 3, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, } self._test_quantizer( @@ -1644,7 +1749,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase): m.train() # After convert: still not OK - m = convert_pt2e(m) + m = convert_pt2e(m, fold_quantize=True) with self.assertRaises(NotImplementedError): m.eval() with self.assertRaises(NotImplementedError): @@ -1705,7 +1810,7 @@ class TestQuantizePT2EOps(QuantizationTestCase): quantizer.set_global(quantization_config) model_graph = prepare_pt2e(model_graph, quantizer) model_graph(*example_inputs) - model_graph = convert_pt2e(model_graph) + model_graph = convert_pt2e(model_graph, fold_quantize=True) self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) @@ -1767,7 +1872,7 @@ class TestQuantizePT2EOps(QuantizationTestCase): quantizer.set_global(quantization_config) model_graph = prepare_pt2e(model_graph, quantizer) model_graph(*example_inputs) - model_graph = convert_pt2e(model_graph) + model_graph = convert_pt2e(model_graph, fold_quantize=True) self.assertEqual(model_fx(*example_inputs), model_graph(*example_inputs)) @@ -1798,7 +1903,7 @@ class TestQuantizePT2EModels(PT2EQuantizationTestCase): id(m.activation_post_process_3), id(m.activation_post_process_2) ) after_prepare_result = m(*example_inputs) - m = convert_pt2e(m) + m = convert_pt2e(m, fold_quantize=True) after_quant_result = m(*example_inputs) diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py index 142edbfee4e6..198146a321e8 100644 --- a/test/quantization/pt2e/test_representation.py +++ b/test/quantization/pt2e/test_representation.py @@ -29,7 +29,7 @@ class TestPT2ERepresentation(QuantizationTestCase): ref_node_occurrence: Dict[ns, int], non_ref_node_occurrence: Dict[ns, int], fixed_output_tol: float = None, - output_scale_idx: int = 3, + output_scale_idx: int = 2, ) -> torch.nn.Module: # resetting dynamo cache torch._dynamo.reset() @@ -42,7 +42,9 @@ class TestPT2ERepresentation(QuantizationTestCase): model = prepare_pt2e(model, quantizer) # Calibrate model(*example_inputs) - model = convert_pt2e(model, use_reference_representation=True) + model = convert_pt2e( + model, use_reference_representation=True, fold_quantize=True + ) self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence) # make sure it runs pt2e_quant_output = model(*example_inputs) @@ -52,7 +54,9 @@ class TestPT2ERepresentation(QuantizationTestCase): model_copy = prepare_pt2e(model_copy, quantizer) # Calibrate model_copy(*example_inputs) - model_copy = convert_pt2e(model_copy, use_reference_representation=False) + model_copy = convert_pt2e( + model_copy, use_reference_representation=False, fold_quantize=True + ) self.checkGraphModuleNodes( model_copy, expected_node_occurrence=non_ref_node_occurrence ) @@ -253,9 +257,10 @@ class TestPT2ERepresentation(QuantizationTestCase): ): 0, } non_ref_node_occurrence = { + # quantize_per_channel is folded ns.call_function( torch.ops.quantized_decomposed.quantize_per_channel.default - ): 1, + ): 0, ns.call_function( torch.ops.quantized_decomposed.dequantize_per_channel.default ): 1, diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 004e2dd0814a..fd596dcdcc97 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py 
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -253,7 +253,7 @@ class X86InductorQuantTestCase(QuantizationTestCase): # Calibrate m(*example_inputs) prepare_model = copy.deepcopy(m) - m = convert_pt2e(m) + m = convert_pt2e(m, fold_quantize=True) convert_model = copy.deepcopy(m) pt2_quant_output = m(*example_inputs) node_occurrence = { @@ -284,7 +284,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): # one for input and weight of the conv, one for output for the conv torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -321,7 +322,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): # one for input and weight of the conv, one for output for the relu torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -361,19 +363,21 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): # one for extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: node_occurrence = { - # one for input and weight of the conv - # one for input and weight of another conv + # one for input of the conv + # one for input of another conv # one for output for the add # 2 conv will share same input quant/dequant # one for extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, - torch.ops.quantized_decomposed.quantize_per_channel.default: 2, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ @@ -410,24 +414,26 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): ).eval() if conv2d_type != Conv2DType.both: node_occurrence = { - # one for input and weight of the conv + # one for input for conv # one for output for the relu # one for extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } else: node_occurrence = { - # one for input and weight of the conv - # one for input and weight of another conv + # one for input of the conv + # one for input of another conv # one for output for the relu # 2 conv will share same input quant/dequant # one for 
extra input node of add torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4, - torch.ops.quantized_decomposed.quantize_per_channel.default: 2, + # note: quantize op for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } node_list = [ @@ -458,7 +464,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 5, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7, - torch.ops.quantized_decomposed.quantize_per_channel.default: 4, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 4, } node_list = [ @@ -498,7 +505,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): # one for input and weight of the conv, two for input/output for the maxpool2d torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -558,7 +566,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 7, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 7, - torch.ops.quantized_decomposed.quantize_per_channel.default: 3, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 3, } node_list = [ @@ -632,7 +641,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -688,7 +698,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -739,7 +750,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): node_occurrence = { torch.ops.quantized_decomposed.quantize_per_tensor.default: 3, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 3, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -799,7 +811,8 @@ class 
TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): # one for input and weight, one for output torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ @@ -841,7 +854,8 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): # one for input and weight of the conv, one for output for the relu torch.ops.quantized_decomposed.quantize_per_tensor.default: 2, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 2, - torch.ops.quantized_decomposed.quantize_per_channel.default: 1, + # quantize_per_channel for weights are const propagated + torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, } node_list = [ diff --git a/test/test_quantization.py b/test/test_quantization.py index 0e7d62dfa666..a60727f0db06 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -83,6 +83,8 @@ except ImportError as e: try: # To be moved to compiler side later from quantization.pt2e.test_graph_utils import TestGraphUtils # noqa: F401 + from quantization.pt2e.test_duplicate_dq import TestDuplicateDQPass # noqa: F401 + from quantization.pt2e.test_metadata_porting import TestMetaDataPorting # noqa: F401 from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E # noqa: F401 from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EOps # noqa: F401 from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EModels # noqa: F401 diff --git a/torch/_inductor/constant_folding.py b/torch/_inductor/constant_folding.py index 278b97bd7b27..05d16f1c7afe 100644 --- a/torch/_inductor/constant_folding.py +++ b/torch/_inductor/constant_folding.py @@ -1,5 +1,5 @@ import collections -from typing import Any, Dict +from typing import Any, Callable, Dict, Optional import torch import torch.utils._pytree as pytree @@ -51,11 +51,14 @@ class ConstantFolder(torch.fx.Interpreter): self.user_to_last_uses = self.node_to_last_non_output_use() def is_impure(self, node: torch.fx.node.Node): - if node.target == torch.ops.quantized_decomposed.dequantize_per_channel.default: - # For the pattern fp32_weight -> quantized_decomposed.quantize_per_channel.default - # -> quantized_decomposed.dequantize_per_channel.default - # We only folding fp32_weight -> quantized_decomposed.quantize_per_channel.default into - # int8_weight and leave quantized_decomposed.dequantize_per_channel.default in graph to be fused + if node.target in [ + torch.ops.quantized_decomposed.dequantize_per_channel.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + ]: + # For the pattern fp32_weight -> q -> dq + # We only folding fp32_weight -> q + # int8_weight and leave dq in graph to be fused return True return False @@ -163,11 +166,13 @@ class ConstantFolder(torch.fx.Interpreter): @torch.utils._python_dispatch._disable_current_modes() -def constant_fold(gm): +def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None): cf = ConstantFolder(gm, skip_constructors=True) cf.run() for node, constant in cf.node_replacements.items(): + if constraint_fn is not None and not constraint_fn(node): + continue 
replace_node_with_constant(gm, node, constant) erased_params = [] diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index 6829d8ef389f..3e5420b3c433 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -98,6 +98,19 @@ def _port_metadata_for_input_quant_nodes( q_node, dq_node = _find_q_dq_node_for_user(input_node, node) if q_node is None or dq_node is None: return + # add metadata for all the node between q_node and get_attr node + # if the q_node can be traced back to get_attr node + q_to_get_attr_nodes = [q_node] + q_node_input = q_node.args[0] + while isinstance(q_node_input, torch.fx.Node) and q_node_input.op not in [ + "placeholder", + "get_attr", + ]: + q_to_get_attr_nodes.append(q_node_input) + q_node_input = q_node_input.args[0] + if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr": + for n in q_to_get_attr_nodes: + _add_metadata(n, q_node_input) _add_metadata(dq_node, node) diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index a228629a657e..b545ec121cd9 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -56,7 +56,7 @@ def _find_q_dq_node_for_user( produer: torch.fx.Node, user: torch.fx.Node ) -> Tuple[Any, Any]: """ - Find d, dq pair corresponding to [producer ... -> q -> dq -> user] + Find q, dq pair corresponding to [producer -> q -> dq -> user] Utils works by finding dq arg of user and ensuring it is connected to producer """ diff --git a/torch/ao/quantization/quantize_pt2e.py b/torch/ao/quantization/quantize_pt2e.py index 1c4baa11632f..f5dbd5910aad 100644 --- a/torch/ao/quantization/quantize_pt2e.py +++ b/torch/ao/quantization/quantize_pt2e.py @@ -1,4 +1,6 @@ +import torch from torch.fx import GraphModule +from torch.fx import Node from .pt2e.prepare import prepare from .pt2e.qat_utils import ( @@ -24,6 +26,7 @@ from torch.ao.quantization.quantizer import ( # noqa: F401 from torch.fx.passes.infra.pass_manager import PassManager from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch._inductor.constant_folding import constant_fold __all__ = [ "prepare_pt2e", @@ -66,10 +69,40 @@ def prepare_qat_pt2e( model = _disallow_eval_train(model) return model +_QUANT_OPS = [ + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, +] +def _quant_node_constraint(n: Node) -> bool: + """If there is any pure ops between get_attr and quantize op they will be const propagated + e.g. 
get_attr(weight) -> transpose -> quantize -> dequantize* + (Note: dequantize op is not going to be constant propagated) + + This filter is added because we don't want to constant fold the things that are not + related to quantization + """ + return n.op == "call_function" and n.target in _QUANT_OPS + def convert_pt2e( model: GraphModule, use_reference_representation: bool = False, + fold_quantize: bool = False, ) -> GraphModule: + """Convert a calibrated/trained model to a quantized model + + Args: + model: calibrated/trained model + use_reference_representation: boolean flag to indicate whether to produce referece representation or not + fold_quantize: boolean flag to indicate whether fold the quantize op or not + + Note: please set `fold_quantize` to True whenever you can, we'll deprecate this flag and + make True the default option in the future, to make sure the change doesn't break BC for you, it's + better to set the flag to True now. + + Returns: + quantized model, either in q/dq representation or reference representation + """ original_graph_meta = model.meta model = _convert_to_reference_decomposed_fx(model) model = _fold_conv_bn_qat(model) @@ -78,6 +111,10 @@ def convert_pt2e( pm = PassManager([PortNodeMetaForQDQ()]) model = pm(model).graph_module + + if fold_quantize: + constant_fold(model, _quant_node_constraint) + if use_reference_representation: model = reference_representation_rewrite(model) diff --git a/torch/ao/quantization/quantizer/xnnpack_quantizer.py b/torch/ao/quantization/quantizer/xnnpack_quantizer.py index f1b9937d0fcd..96c2f4c07750 100644 --- a/torch/ao/quantization/quantizer/xnnpack_quantizer.py +++ b/torch/ao/quantization/quantizer/xnnpack_quantizer.py @@ -201,7 +201,8 @@ def _get_module_name_filter(module_name: str): # 'L__self___sub': ("L['self'].sub", ), # 'L__self___sub_linear': ("L['self'].sub.linear", ) # } - nn_module_stack = n.meta["nn_module_stack"] + # get_attr nodes doesn't have nn_module_stack? + nn_module_stack = n.meta.get("nn_module_stack", {}) names = [ n[len("L__self___") :].replace("_", ".") for n in nn_module_stack.keys() ] @@ -228,7 +229,7 @@ def _get_module_type_filter(tp: Callable): # 'L__self___sub': ("L['self'].sub", ), # 'L__self___sub_linear': ("L['self'].sub.linear", ) # } - nn_module_stack = n.meta["nn_module_stack"] + nn_module_stack = n.meta.get("nn_module_stack", {}) types = [t for _, t in nn_module_stack.values()] return tp in types
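
Usage sketch: the snippet below exercises the new `fold_quantize` flag end to end, mirroring the `_get_pt2e_quantized_linear` test helper added in this PR. The toy module, input shapes, and quantizer configuration are illustrative assumptions for the sketch, not taken verbatim from the PR.

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)

example_inputs = (torch.randn(2, 2),)
m = M().eval()

# program capture
m = capture_pre_autograd_graph(m, example_inputs)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibrate

# With fold_quantize=True the weight's quantize op is constant folded, so the
# converted graph stores the int8 weight as a frozen get_attr followed by a
# dequantize op, instead of fp32_weight -> quantize -> dequantize.
m = convert_pt2e(m, fold_quantize=True)
print(m.graph)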