diff --git a/test/fx/test_matcher_utils.py b/test/fx/test_matcher_utils.py
index 26caf91485e2..578e0ab07a6a 100644
--- a/test/fx/test_matcher_utils.py
+++ b/test/fx/test_matcher_utils.py
@@ -173,7 +173,7 @@ class TestMatcher(JitTestCase):
             torch.randn(3, 3, 3, 3),
         )
         pattern_gm = export_for_training(
-            WrapperModule(pattern), example_inputs
+            WrapperModule(pattern), example_inputs, strict=True
         ).module()
         before_split_res = pattern_gm(*example_inputs)
         pattern_gm, _ = _split_to_graph_and_name_node_map(pattern_gm)
@@ -204,11 +204,11 @@ class TestMatcher(JitTestCase):
             torch.randn(3, 3, 3, 3),
         )
         pattern_gm = export_for_training(
-            WrapperModule(pattern), example_inputs
+            WrapperModule(pattern), example_inputs, strict=True
         ).module()
         matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
         target_gm = export_for_training(
-            WrapperModule(target_graph), example_inputs
+            WrapperModule(target_graph), example_inputs, strict=True
         ).module()
         internal_matches = matcher.match(target_gm.graph)
         for internal_match in internal_matches:
@@ -248,9 +248,11 @@ class TestMatcher(JitTestCase):
                 return linear, {"linear": linear, "x": x}

         example_inputs = (torch.randn(3, 5),)
-        pattern_gm = export_for_training(Pattern(), example_inputs).module()
+        pattern_gm = export_for_training(
+            Pattern(), example_inputs, strict=True
+        ).module()
         matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
-        target_gm = export_for_training(M(), example_inputs).module()
+        target_gm = export_for_training(M(), example_inputs, strict=True).module()
         internal_matches = matcher.match(target_gm.graph)
         for internal_match in internal_matches:
             name_node_map = internal_match.name_node_map
diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index 501fbb49c2b4..973e720c7eb9 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -1837,7 +1837,9 @@ class AOTInductorTestsTemplate:
         with config.patch(
             {"freezing": True, "aot_inductor.force_mmap_weights": True}
         ), torch.no_grad():
-            exported_model = export_for_training(model, example_inputs).module()
+            exported_model = export_for_training(
+                model, example_inputs, strict=True
+            ).module()
             quantizer = X86InductorQuantizer()
             quantizer.set_global(
                 xiq.get_default_x86_inductor_quantization_config(reduce_range=True)
diff --git a/test/quantization/pt2e/test_duplicate_dq.py b/test/quantization/pt2e/test_duplicate_dq.py
index 54456ab37b15..4a5cb6edaeb6 100644
--- a/test/quantization/pt2e/test_duplicate_dq.py
+++ b/test/quantization/pt2e/test_duplicate_dq.py
@@ -101,10 +101,7 @@ class TestDuplicateDQPass(QuantizationTestCase):

         # program capture
         m = copy.deepcopy(m_eager)
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()

         m = prepare_pt2e(m, quantizer)
         # Calibrate
diff --git a/test/quantization/pt2e/test_metadata_porting.py b/test/quantization/pt2e/test_metadata_porting.py
index 4f6eb4f56d3a..96eff3a789f2 100644
--- a/test/quantization/pt2e/test_metadata_porting.py
+++ b/test/quantization/pt2e/test_metadata_porting.py
@@ -98,10 +98,7 @@ class TestMetaDataPorting(QuantizationTestCase):

         # program capture
         m = copy.deepcopy(m_eager)
-        m = torch.export.export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = torch.export.export_for_training(m, example_inputs, strict=True).module()

         m = prepare_pt2e(m, quantizer)
         # Calibrate
diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py
index b5ada0cc3d59..deff8e4987e5 100644
--- a/test/quantization/pt2e/test_numeric_debugger.py
+++ b/test/quantization/pt2e/test_numeric_debugger.py
@@ -81,7 +81,7 @@ class TestNumericDebugger(TestCase):
     def test_simple(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         self._assert_each_node_has_debug_handle(ep)
         debug_handle_map = self._extract_debug_handles(ep)
@@ -91,7 +91,7 @@ class TestNumericDebugger(TestCase):
     def test_control_flow(self):
         m = TestHelperModules.ControlFlow()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)

         self._assert_each_node_has_debug_handle(ep)
@@ -102,7 +102,7 @@ class TestNumericDebugger(TestCase):
     def test_quantize_pt2e_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()

@@ -162,14 +162,14 @@ class TestNumericDebugger(TestCase):
     def test_re_export_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()

         self._assert_each_node_has_debug_handle(ep)
         debug_handle_map_ref = self._extract_debug_handles(ep)

-        ep_reexport = export_for_training(m, example_inputs)
+        ep_reexport = export_for_training(m, example_inputs, strict=True)

         self._assert_each_node_has_debug_handle(ep_reexport)
         debug_handle_map = self._extract_debug_handles(ep_reexport)
@@ -179,7 +179,7 @@ class TestNumericDebugger(TestCase):
     def test_run_decompositions_same_handle_id(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)

         self._assert_each_node_has_debug_handle(ep)
@@ -204,7 +204,7 @@ class TestNumericDebugger(TestCase):

         for m in test_models:
             example_inputs = m.example_inputs()
-            ep = export_for_training(m, example_inputs)
+            ep = export_for_training(m, example_inputs, strict=True)
             generate_numeric_debug_handle(ep)
             self._assert_each_node_has_debug_handle(ep)

@@ -227,7 +227,7 @@ class TestNumericDebugger(TestCase):
     def test_prepare_for_propagation_comparison(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_logger = prepare_for_propagation_comparison(m)
@@ -244,7 +244,7 @@ class TestNumericDebugger(TestCase):
     def test_extract_results_from_loggers(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)
@@ -269,7 +269,7 @@ class TestNumericDebugger(TestCase):
     def test_extract_results_from_loggers_list_output(self):
         m = TestHelperModules.Conv2dWithSplit()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         m = ep.module()
         m_ref_logger = prepare_for_propagation_comparison(m)
@@ -299,7 +299,7 @@ class TestNumericDebugger(TestCase):
     def test_added_node_gets_unique_id(self) -> None:
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = export_for_training(m, example_inputs)
+        ep = export_for_training(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
         ref_handles = self._extract_debug_handles(ep)
         ref_counter = Counter(ref_handles.values())
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 2bc87f72fc25..08ffecc3aabd 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -767,10 +767,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))

         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, BackendAQuantizer())
         # make sure the two observers for input are shared
         conv_output_obs = []
@@ -830,10 +827,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         )

         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
         # make sure the two input observers and output are shared
@@ -1152,10 +1146,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         )

         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = BackendAQuantizer()
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
@@ -1305,7 +1296,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().eval()

         example_inputs = torch.randn(1, 2, 3, 3)
-        m = export_for_training(m, (example_inputs,)).module()
+        m = export_for_training(m, (example_inputs,), strict=True).module()
         with self.assertRaises(Exception):
             m = prepare_pt2e(m, BackendAQuantizer())

@@ -1428,10 +1419,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         quantizer.set_global(operator_config)
         example_inputs = (torch.randn(2, 2),)
         m = M().eval()
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         weight_meta = None
         for n in m.graph.nodes:
             if (
@@ -1518,7 +1506,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().eval()
         quantizer = TestQuantizer()
         example_inputs = (torch.randn(1, 2, 3, 3),)
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
         node_occurrence = {
@@ -1569,7 +1557,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             torch.randn(1, 2, 3, 3),
             torch.randn(1, 2, 3, 3),
         )
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
         node_occurrence = {
@@ -1824,7 +1812,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

         example_inputs = (torch.randn(1),)
         m = M().train()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         if inplace:
             target = torch.ops.aten.dropout_.default
         else:
@@ -1889,7 +1877,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().train()
         example_inputs = (torch.randn(1, 3, 3, 3),)
         bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()

         # Assert that batch norm op exists and is in train mode
         bn_node = self._get_node(m, bn_train_op)
@@ -1920,7 +1908,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m.train()

         # After export: this is not OK
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         with self.assertRaises(NotImplementedError):
             m.eval()
         with self.assertRaises(NotImplementedError):
@@ -1961,7 +1949,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = M().train()
         example_inputs = (torch.randn(1, 3, 3, 3),)
         bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()

         def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
             targets = [n.target for n in m.graph.nodes]
@@ -2027,7 +2015,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

         m = M().train()
         example_inputs = (torch.randn(1, 3, 3, 3),)
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         torch.ao.quantization.allow_exported_model_train_eval(m)

         # Mock m.recompile() to count how many times it's been called
@@ -2059,7 +2047,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
     def test_model_is_exported(self):
         m = TestHelperModules.ConvWithBNRelu(relu=True)
         example_inputs = (torch.rand(3, 3, 5, 5),)
-        exported_gm = export_for_training(m, example_inputs).module()
+        exported_gm = export_for_training(m, example_inputs, strict=True).module()
         fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs)
         self.assertTrue(
             torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm)
@@ -2077,7 +2065,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
         )
-        m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module()
+        m.conv_bn_relu = export_for_training(
+            m.conv_bn_relu, example_inputs, strict=True
+        ).module()
         m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
         m(*example_inputs)
         m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)
@@ -2085,7 +2075,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         quantizer = XNNPACKQuantizer().set_module_type(
             torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
         )
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_pt2e(m, quantizer)
         m = convert_pt2e(m)

@@ -2257,7 +2247,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

         def dynamic_quantize_pt2e(model, example_inputs):
             torch._dynamo.reset()
-            model = export_for_training(model, example_inputs).module()
+            model = export_for_training(model, example_inputs, strict=True).module()
             # Per channel quantization for weight
             # Dynamic quantization for activation
             # Please read a detail: https://fburl.com/code/30zds51q
@@ -2360,7 +2350,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

         example_inputs = (torch.randn(1, 3, 5, 5),)
         m = M()
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = XNNPACKQuantizer().set_global(
             get_symmetric_quantization_config(),
         )
@@ -2442,7 +2432,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 edge_or_node_to_obs_or_fq[x] = new_observer

         example_inputs = (torch.rand(1, 32, 16, 16),)
-        gm = export_for_training(Model().eval(), example_inputs).module()
+        gm = export_for_training(Model().eval(), example_inputs, strict=True).module()
         gm = prepare_pt2e(gm, BackendAQuantizer())
         gm = convert_pt2e(gm)
         for n in gm.graph.nodes:
@@ -2469,7 +2459,9 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
                 "ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1]
             )

-        m.conv_bn_relu = export_for_training(m.conv_bn_relu, example_inputs).module()
+        m.conv_bn_relu = export_for_training(
+            m.conv_bn_relu, example_inputs, strict=True
+        ).module()
         for node in m.conv_bn_relu.graph.nodes:
             if node.op not in ["placeholder", "output", "get_attr"]:
                 check_nn_module(node)
diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py
index abc9849aee82..b52f34c68c5b 100644
--- a/test/quantization/pt2e/test_quantize_pt2e_qat.py
+++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py
@@ -140,8 +140,7 @@ class PT2EQATTestCase(QuantizationTestCase):
             )
         )
         model_pt2e = export_for_training(
-            model_pt2e,
-            example_inputs,
+            model_pt2e, example_inputs, strict=True
         ).module()
         model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
         torch.manual_seed(MANUAL_SEED)
@@ -229,10 +228,7 @@ class PT2EQATTestCase(QuantizationTestCase):
         quantizer.set_global(
             get_symmetric_quantization_config(is_per_channel, is_qat=True)
         )
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)

@@ -621,7 +617,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
         m = M(self.conv_class, self.bn_class, backbone)
         quantizer = XNNPACKQuantizer()
         quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
         m = convert_pt2e(m)
@@ -679,7 +675,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     def test_qat_conv_bn_bias_derived_qspec(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = ConvBnDerivedBiasQuantizer()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -726,7 +722,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
     def test_qat_per_channel_weight_custom_dtype(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = ConvBnInt32WeightQuantizer()
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -780,7 +776,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
    def test_qat_conv_bn_per_channel_weight_bias(self):
         m = self._get_conv_bn_model()
         example_inputs = self.example_inputs
-        m = export_for_training(m, example_inputs).module()
+        m = export_for_training(m, example_inputs, strict=True).module()
         quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
         m = prepare_qat_pt2e(m, quantizer)
         m(*example_inputs)
@@ -837,7 +833,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
         it into conv in `convert_pt2e` even in train mode.
""" m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False) - m = export_for_training(m, self.example_inputs).module() + m = export_for_training(m, self.example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_global( get_symmetric_quantization_config(is_per_channel=False, is_qat=True), @@ -1085,7 +1081,9 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase): in_channels = child.linear1.weight.size(1) example_input = (torch.rand((1, in_channels)),) - traced_child = export_for_training(child, example_input).module() + traced_child = export_for_training( + child, example_input, strict=True + ).module() quantizer = XNNPACKQuantizer() quantization_config = get_symmetric_quantization_config( is_per_channel=True, is_qat=True @@ -1116,10 +1114,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase): self._convert_qat_linears(model) model(*example_inputs) - model_pt2e = export_for_training( - model, - example_inputs, - ).module() + model_pt2e = export_for_training(model, example_inputs, strict=True).module() quantizer = XNNPACKQuantizer() quantizer.set_module_type(torch.nn.Linear, None) diff --git a/test/quantization/pt2e/test_representation.py b/test/quantization/pt2e/test_representation.py index c6eed1ed8260..3648ac352dc4 100644 --- a/test/quantization/pt2e/test_representation.py +++ b/test/quantization/pt2e/test_representation.py @@ -33,10 +33,7 @@ class TestPT2ERepresentation(QuantizationTestCase): ) -> torch.nn.Module: # resetting dynamo cache torch._dynamo.reset() - model = export_for_training( - model, - example_inputs, - ).module() + model = export_for_training(model, example_inputs, strict=True).module() model_copy = copy.deepcopy(model) model = prepare_pt2e(model, quantizer) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 51b7ce72f74f..1c14ded72fe9 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -665,10 +665,7 @@ class X86InductorQuantTestCase(QuantizationTestCase): # program capture m = copy.deepcopy(m_eager) - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() # QAT Model failed to deepcopy export_model = m if is_qat else copy.deepcopy(m) @@ -2344,7 +2341,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. diff --git a/test/quantization/pt2e/test_xnnpack_quantizer.py b/test/quantization/pt2e/test_xnnpack_quantizer.py index 36209e5aad10..4e14dfd27ae2 100644 --- a/test/quantization/pt2e/test_xnnpack_quantizer.py +++ b/test/quantization/pt2e/test_xnnpack_quantizer.py @@ -361,7 +361,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase): ) example_inputs = (torch.randn(2, 2),) m = M().eval() - m = export_for_training(m, example_inputs).module() + m = export_for_training(m, example_inputs, strict=True).module() m = prepare_pt2e(m, quantizer) # Use a linear count instead of names because the names might change, but # the order should be the same. 
@@ -497,10 +497,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
         example_inputs = (torch.randn(1, 3, 5, 5),)

         # program capture
-        m = export_for_training(
-            m,
-            example_inputs,
-        ).module()
+        m = export_for_training(m, example_inputs, strict=True).module()

         m = prepare_pt2e(m, quantizer)
         m(*example_inputs)
@@ -766,8 +763,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):

         with torchdynamo.config.patch(allow_rnn=True):
             model_graph = export_for_training(
-                model_graph,
-                example_inputs,
+                model_graph, example_inputs, strict=True
             ).module()
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(
@@ -829,8 +825,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):

         with torchdynamo.config.patch(allow_rnn=True):
             model_graph = export_for_training(
-                model_graph,
-                example_inputs,
+                model_graph, example_inputs, strict=True
             ).module()
         quantizer = XNNPACKQuantizer()
         quantization_config = get_symmetric_quantization_config(
@@ -1039,10 +1034,7 @@ class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
            m = torchvision.models.resnet18().eval()
            m_copy = copy.deepcopy(m)
            # program capture
-            m = export_for_training(
-                m,
-                example_inputs,
-            ).module()
+            m = export_for_training(m, example_inputs, strict=True).module()
            quantizer = XNNPACKQuantizer()

            quantization_config = get_symmetric_quantization_config(is_per_channel=True)
diff --git a/test/test_model_exports_to_core_aten.py b/test/test_model_exports_to_core_aten.py
index aae14c28b8d6..3d1c25939ec4 100644
--- a/test/test_model_exports_to_core_aten.py
+++ b/test/test_model_exports_to_core_aten.py
@@ -27,7 +27,9 @@ class TestQuantizePT2EModels(TestCase):
         m = m.eval()
         input_shape = (1, 3, 224, 224)
         example_inputs = (torch.randn(input_shape),)
-        m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module()
+        m = torch.export.export_for_training(
+            m, copy.deepcopy(example_inputs), strict=True
+        ).module()
         m(*example_inputs)
         m = export.export(m, copy.deepcopy(example_inputs))
         ops = _get_ops_list(m.graph_module)
diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py
index 47e939f7596a..86304247d151 100644
--- a/torch/ao/quantization/pt2e/utils.py
+++ b/torch/ao/quantization/pt2e/utils.py
@@ -355,6 +355,7 @@ def _get_aten_graph_module_for_pattern(
         pattern,  # type: ignore[arg-type]
         example_inputs,
         kwargs,
+        strict=True,
     ).module()
     aten_pattern.graph.eliminate_dead_code()  # type: ignore[operator, union-attr]

diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py
index 416965e80ba3..4e1b9676d7ca 100644
--- a/torch/distributed/pipelining/_IR.py
+++ b/torch/distributed/pipelining/_IR.py
@@ -1003,9 +1003,7 @@ class Pipe(torch.nn.Module):
         logger.info("Tracing model ...")
         try:
             ep = torch.export.export_for_training(
-                mod,
-                example_args,
-                example_kwargs,
+                mod, example_args, example_kwargs, strict=True
             )
         except Exception as e:
             raise RuntimeError(
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index 07e7da55eafc..e114a37b04df 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -4,24 +4,48 @@ r"""Importing this file includes common utility methods and base clases for
 checking quantization api and properties of resulting modules.
""" -from functorch.experimental import control_flow - import torch -import torch.nn as nn -import torch.nn.functional as F import torch.ao.nn.intrinsic.quantized.dynamic as nniqd import torch.ao.nn.quantized as nnq import torch.ao.nn.quantized.dynamic as nnqd -from torch.ao.nn.intrinsic import _FusedModule import torch.distributed as dist -from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM - -from torch.export import export_for_training +import torch.nn as nn +import torch.nn.functional as F +from functorch.experimental import control_flow +from torch.ao.nn.intrinsic import _FusedModule from torch.ao.quantization import ( - QuantType, + convert, default_dynamic_qat_qconfig, + default_dynamic_qconfig, + default_dynamic_quant_observer, default_embedding_qat_qconfig, + default_observer, + default_per_channel_qconfig, + default_qconfig, default_symmetric_qnnpack_qat_qconfig, + default_weight_observer, + DeQuantStub, + float_qparams_weight_only_qconfig, + get_default_qat_qconfig, + get_default_qat_qconfig_mapping, + get_default_qconfig, + get_default_qconfig_mapping, + PerChannelMinMaxObserver, + propagate_qconfig_, + QConfig, + QConfigMapping, + quantize, + quantize_dynamic_jit, + quantize_jit, + QuantStub, + QuantType, + QuantWrapper, +) +from torch.ao.quantization.backend_config import get_executorch_backend_config +from torch.ao.quantization.quantization_mappings import ( + get_default_dynamic_quant_module_mappings, + get_default_qat_module_mappings, + get_default_qconfig_propagation_list, ) from torch.ao.quantization.quantize_pt2e import ( _convert_to_reference_decomposed_fx, @@ -29,83 +53,75 @@ from torch.ao.quantization.quantize_pt2e import ( prepare_pt2e, prepare_qat_pt2e, ) -from torch.ao.quantization.backend_config import ( - get_executorch_backend_config, -) from torch.ao.quantization.quantizer.xnnpack_quantizer import ( - XNNPACKQuantizer, get_symmetric_quantization_config, + XNNPACKQuantizer, ) -from torch.ao.quantization import QuantWrapper, QuantStub, DeQuantStub, \ - default_qconfig, default_dynamic_qconfig, default_per_channel_qconfig, QConfig, default_observer, default_weight_observer, \ - propagate_qconfig_, convert, get_default_qconfig, quantize_dynamic_jit, quantize_jit, float_qparams_weight_only_qconfig, \ - get_default_qat_qconfig, PerChannelMinMaxObserver, default_dynamic_quant_observer, quantize, \ - QConfigMapping, get_default_qconfig_mapping, get_default_qat_qconfig_mapping -from torch.ao.quantization.quantization_mappings import ( - get_default_dynamic_quant_module_mappings, - get_default_qconfig_propagation_list, - get_default_qat_module_mappings, -) -from torch.testing._internal.common_quantized import ( - override_quantized_engine, -) + +from torch.export import export_for_training from torch.jit.mobile import _load_for_lite_interpreter +from torch.testing._internal.common_quantized import override_quantized_engine +from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase try: + from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph + # graph mode quantization based on fx from torch.ao.quantization.quantize_fx import ( - prepare_fx, - prepare_qat_fx, convert_fx, convert_to_reference_fx, + prepare_fx, + prepare_qat_fx, ) - from torch.ao.ns.fx.ns_types import NSSingleResultValuesType, NSSubgraph - from torch.fx.graph import Node from torch.fx import GraphModule + from torch.fx.graph import Node + HAS_FX = True except ImportError: HAS_FX = False +import contextlib import copy -import io import functools 
+import io
 import os
 import unittest
+from typing import Any, Callable, Optional, Union
+
 import numpy as np
-from torch.testing import FileCheck
-from typing import Callable, Any, Union, Optional

 import torch._dynamo as torchdynamo
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
 from torch.ao.quantization.quantizer.xpu_inductor_quantizer import XPUInductorQuantizer
-import contextlib
+from torch.testing import FileCheck
+

 class NodeSpec:
-    ''' Used for checking GraphModule Node
-    '''
+    """Used for checking GraphModule Node"""
+
     def __init__(self, op, target):
-        '''
+        """
         op: call_function | call_module
         target:
           for call_function, target would be a function
           for call_module, target would be the type of PyTorch module
-        '''
+        """
         self.op = op
         self.target = target

     @classmethod
     def call_function(cls, target):
-        return NodeSpec('call_function', target)
+        return NodeSpec("call_function", target)

     @classmethod
     def call_method(cls, target):
-        return NodeSpec('call_method', target)
+        return NodeSpec("call_method", target)

     @classmethod
     def call_module(cls, target):
-        return NodeSpec('call_module', target)
+        return NodeSpec("call_module", target)

     def __hash__(self):
         return hash((self.op, self.target))
@@ -119,8 +135,12 @@ class NodeSpec:
     def __repr__(self):
         return repr(self.op) + " " + repr(self.target)

+
 def get_supported_device_types():
-    return ['cpu', 'cuda'] if torch.cuda.is_available() and not TEST_WITH_ROCM else ['cpu']
+    return (
+        ["cpu", "cuda"] if torch.cuda.is_available() and not TEST_WITH_ROCM else ["cpu"]
+    )
+

 def test_only_eval_fn(model, calib_data):
     r"""
@@ -130,7 +150,10 @@ def test_only_eval_fn(model, calib_data):
     for inp in calib_data:
         model(*inp)

+
 _default_loss_fn = torch.nn.CrossEntropyLoss()
+
+
 def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn):
     r"""
     Default train function takes a torch.utils.data.Dataset and train the model
@@ -153,9 +176,11 @@ def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn):
             correct += (predicted == target).sum().item()
     return train_loss, correct, total

+
 class AverageMeter:
     """Computes and stores the average and current value"""
-    def __init__(self, name, fmt=':f'):
+
+    def __init__(self, name, fmt=":f"):
         self.name = name
         self.fmt = fmt
         self.reset()
@@ -173,7 +198,7 @@ class AverageMeter:
         self.avg = self.sum / self.count

     def __str__(self):
-        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
         return fmtstr.format(**self.__dict__)


@@ -193,10 +218,11 @@ def accuracy(output, target, topk=(1,)):
             res.append(correct_k.mul_(100.0 / batch_size))
         return res

+
 def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_batches):
     model.train()
     for cnt, (image, target) in enumerate(data_loader, start=1):
-        print('.', end='')
+        print(".", end="")
         image, target = image.to(device), target.to(device)
         output = model(image)
         loss = criterion(output, target)
@@ -208,16 +234,19 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_bat
             return
     return

+
 def ddp_setup(rank, world_size):
-    os.environ['MASTER_ADDR'] = 'localhost'
-    os.environ['MASTER_PORT'] = '12355'
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"

     # initialize the process group
     dist.init_process_group("gloo", rank=rank, world_size=world_size)

+
 def ddp_cleanup():
     dist.destroy_process_group()

+
 def run_ddp(rank, world_size, prepared):
     ddp_setup(rank, world_size)
     prepared.cuda()
@@ -232,24 +261,42 @@ def run_ddp(rank, world_size, prepared):
 def convert_dynamic(module):
     convert(module, get_default_dynamic_quant_module_mappings(), inplace=True)

+
 def prepare_dynamic(model, qconfig_dict=None):
     propagate_qconfig_(model, qconfig_dict)

+
 def _make_conv_test_input(
-    batch_size, in_channels_per_group, input_feature_map_size,
-    out_channels_per_group, groups, kernel_size, X_scale, X_zero_point, W_scale,
-    W_zero_point, use_bias, use_channelwise,
+    batch_size,
+    in_channels_per_group,
+    input_feature_map_size,
+    out_channels_per_group,
+    groups,
+    kernel_size,
+    X_scale,
+    X_zero_point,
+    W_scale,
+    W_zero_point,
+    use_bias,
+    use_channelwise,
 ):
     in_channels = in_channels_per_group * groups
     out_channels = out_channels_per_group * groups

     (X_value_min, X_value_max) = (0, 4)
     X_init = torch.randint(
-        X_value_min, X_value_max,
-        (batch_size, in_channels,) + input_feature_map_size)
+        X_value_min,
+        X_value_max,
+        (
+            batch_size,
+            in_channels,
+        )
+        + input_feature_map_size,
+    )
     X = X_scale * (X_init - X_zero_point).float()
     X_q = torch.quantize_per_tensor(
-        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8)
+        X, scale=X_scale, zero_point=X_zero_point, dtype=torch.quint8
+    )

     W_scale = W_scale * out_channels
     W_zero_point = W_zero_point * out_channels
@@ -266,109 +313,132 @@ def _make_conv_test_input(
     # The operator expects them in the format
     # (out_channels, in_channels/groups,) + kernel_size
     W_init = torch.randint(
-        W_value_min, W_value_max,
-        (out_channels, in_channels_per_group,) + kernel_size)
+        W_value_min,
+        W_value_max,
+        (
+            out_channels,
+            in_channels_per_group,
+        )
+        + kernel_size,
+    )
     b_init = torch.randint(0, 10, (out_channels,))

     if use_channelwise:
         W_shape = (-1, 1) + (1,) * len(kernel_size)
         W_scales_tensor = torch.tensor(W_scale, dtype=torch.float)
         W_zero_points_tensor = torch.tensor(W_zero_point, dtype=torch.float)
-        W = W_scales_tensor.reshape(*W_shape) * (
-            W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
+        W = (
+            W_scales_tensor.reshape(*W_shape)
+            * (W_init.float() - W_zero_points_tensor.reshape(*W_shape)).float()
+        )
         b = X_scale * W_scales_tensor * b_init.float()
         W_q = torch.quantize_per_channel(
-            W, W_scales_tensor.double(), W_zero_points_tensor.long(), 0,
-            dtype=torch.qint8)
+            W,
+            W_scales_tensor.double(),
+            W_zero_points_tensor.long(),
+            0,
+            dtype=torch.qint8,
+        )
     else:
         W = W_scale[0] * (W_init - W_zero_point[0]).float()
         b = X_scale * W_scale[0] * b_init.float()
         W_q = torch.quantize_per_tensor(
-            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8)
+            W, scale=W_scale[0], zero_point=W_zero_point[0], dtype=torch.qint8
+        )

     return (X, X_q, W, W_q, b if use_bias else None)

+
 def _make_conv_add_extra_input_tensor(scale, zero_point, sizes):
     (X_value_min, X_value_max) = (0, 4)
     X_init = torch.randint(
         X_value_min,
         X_value_max,
-        sizes  # Infer the size of tensor to do the add
+        sizes,  # Infer the size of tensor to do the add
     )
     X = scale * (X_init - zero_point).float()
     X_q = torch.quantize_per_tensor(
-        X, scale=scale, zero_point=zero_point, dtype=torch.quint8)
+        X, scale=scale, zero_point=zero_point, dtype=torch.quint8
+    )
     return X, X_q

+
 def skipIfNoFBGEMM(fn):
-    reason = 'Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer.'
+    reason = "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs with instruction set support AVX2 or newer."
     if isinstance(fn, type):
-        if 'fbgemm' not in torch.backends.quantized.supported_engines:
+        if "fbgemm" not in torch.backends.quantized.supported_engines:
             fn.__unittest_skip__ = True
             fn.__unittest_skip_why__ = reason
         return fn

     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
-        if 'fbgemm' not in torch.backends.quantized.supported_engines:
+        if "fbgemm" not in torch.backends.quantized.supported_engines:
             raise unittest.SkipTest(reason)
         else:
             fn(*args, **kwargs)
+
     return wrapper

+
 def skipIfNoQNNPACK(fn):
-    reason = 'Quantized operations require QNNPACK.'
+    reason = "Quantized operations require QNNPACK."
     if isinstance(fn, type):
-        if 'qnnpack' not in torch.backends.quantized.supported_engines:
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
             fn.__unittest_skip__ = True
             fn.__unittest_skip_why__ = reason
         return fn

     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
-        if 'qnnpack' not in torch.backends.quantized.supported_engines:
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
             raise unittest.SkipTest(reason)
         else:
             fn(*args, **kwargs)
+
     return wrapper

+
 def withQNNPACKBackend(fn):
     # TODO(future PR): consider combining with skipIfNoQNNPACK,
     # will require testing of existing callsites
-    reason = 'Quantized operations require QNNPACK.'
+    reason = "Quantized operations require QNNPACK."
     if isinstance(fn, type):
-        if 'qnnpack' not in torch.backends.quantized.supported_engines:
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
             fn.__unittest_skip__ = True
             fn.__unittest_skip_why__ = reason
         return fn

     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
-        if 'qnnpack' not in torch.backends.quantized.supported_engines:
+        if "qnnpack" not in torch.backends.quantized.supported_engines:
             raise unittest.SkipTest(reason)
-        with override_quantized_engine('qnnpack'):
+        with override_quantized_engine("qnnpack"):
             fn(*args, **kwargs)

     return wrapper

+
 def skipIfNoONEDNN(fn):
-    reason = 'Quantized operations require ONEDNN.'
+    reason = "Quantized operations require ONEDNN."
     if isinstance(fn, type):
-        if 'onednn' not in torch.backends.quantized.supported_engines:
+        if "onednn" not in torch.backends.quantized.supported_engines:
             fn.__unittest_skip__ = True
             fn.__unittest_skip_why__ = reason
         return fn

     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
-        if 'onednn' not in torch.backends.quantized.supported_engines:
+        if "onednn" not in torch.backends.quantized.supported_engines:
             raise unittest.SkipTest(reason)
         else:
             fn(*args, **kwargs)
+
     return wrapper

+
 def skipIfNoONEDNNBF16(fn):
-    reason = 'Quantized operations require BF16 support.'
+    reason = "Quantized operations require BF16 support."
     if isinstance(fn, type):
         if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
             fn.__unittest_skip__ = True
@@ -381,24 +451,28 @@ def skipIfNoONEDNNBF16(fn):
             raise unittest.SkipTest(reason)
         else:
             fn(*args, **kwargs)
+
     return wrapper

+
 def skipIfNoX86(fn):
-    reason = 'Quantized operations require X86.'
+    reason = "Quantized operations require X86."
     if isinstance(fn, type):
-        if 'x86' not in torch.backends.quantized.supported_engines:
+        if "x86" not in torch.backends.quantized.supported_engines:
             fn.__unittest_skip__ = True
             fn.__unittest_skip_why__ = reason
         return fn

     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
-        if 'x86' not in torch.backends.quantized.supported_engines:
+        if "x86" not in torch.backends.quantized.supported_engines:
             raise unittest.SkipTest(reason)
         else:
             fn(*args, **kwargs)
+
     return wrapper

+
 def skipIfNoDynamoSupport(fn):
     reason = "dynamo doesn't support."
     if isinstance(fn, type):
@@ -413,8 +487,10 @@ def skipIfNoDynamoSupport(fn):
             raise unittest.SkipTest(reason)
         else:
             fn(*args, **kwargs)
+
     return wrapper

+
 def skipIfNoInductorSupport(fn):
     reason = "inductor doesn't support."
     if isinstance(fn, type):
@@ -429,18 +505,23 @@ def skipIfNoInductorSupport(fn):
             raise unittest.SkipTest(reason)
         else:
             fn(*args, **kwargs)
+
     return wrapper

+
 try:
     import torchvision  # noqa: F401
+
     HAS_TORCHVISION = True
 except ImportError:
     HAS_TORCHVISION = False
 skip_if_no_torchvision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

+
 def get_script_module(model, tracing, data):
     return torch.jit.trace(model, data) if tracing else torch.jit.script(model)

+
 def lengths_to_offsets(t, offset_type=np.int64, use_begin_offset=True):
     """
     Convert lengths to offsets for embedding_bag
@@ -464,7 +545,7 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
     max_val = to_quant.amax(dim=1, keepdim=True)
     min_val = to_quant.amin(dim=1, keepdim=True)

-    max_int = 2 ** n_bit - 1
+    max_int = 2**n_bit - 1
     min_int = 0
     scales = (max_val - min_val).clamp(min=1e-6) / max_int
     assert torch.isnan(scales).sum() == 0
@@ -476,7 +557,7 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
     assert torch.isnan(out).sum() == 0

     out = out.to(dtype=torch.int32).reshape(w.shape)
-    if out.device != torch.device('cpu'):
+    if out.device != torch.device("cpu"):
         out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8)

     # Scales and zeros for the same q-group should be contiguous, so we can
@@ -490,15 +571,15 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
                 zeros.reshape(zeros.size(0), zeros.size(1), 1),
             ],
             2,
-        ).transpose(0, 1).contiguous()
+        )
+        .transpose(0, 1)
+        .contiguous()
     )

     return out, scales_and_zeros


-def _group_quantize_tensor_symmetric(
-    w, n_bit=4, groupsize=32
-):
+def _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=32):
     # W is of shape [K x N]
     # We transpose W as Quantization is applied on [N x K]
     w = w.transpose(0, 1).contiguous()
@@ -566,26 +647,47 @@ class QuantizationTestCase(TestCase):
     def setUp(self):
         super().setUp()
         self.calib_data = [[torch.rand(2, 5, dtype=torch.float)] for _ in range(2)]
-        self.train_data = [[torch.rand(2, 5, dtype=torch.float), torch.randint(0, 1, (2,), dtype=torch.long)] for _ in range(2)]
-        self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)]
-                            for _ in range(2)]
-        self.img_data_2d = [[torch.rand(1, 3, 10, 10, dtype=torch.float)]
-                            for _ in range(2)]
-        self.img_data_3d = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float)]
-                            for _ in range(2)]
-        self.img_data_1d_train = [[torch.rand(2, 3, 10, dtype=torch.float),
-                                   torch.randint(0, 1, (1,), dtype=torch.long)]
-                                  for _ in range(2)]
-        self.img_data_2d_train = [[torch.rand(1, 3, 10, 10, dtype=torch.float),
-                                   torch.randint(0, 1, (1,), dtype=torch.long)]
-                                  for _ in range(2)]
-        self.img_data_3d_train = [[torch.rand(1, 3, 5, 5, 5, dtype=torch.float),
-                                   torch.randint(0, 1, (1,), dtype=torch.long)]
-                                  for _ in range(2)]
+        self.train_data = [
+            [
+                torch.rand(2, 5, dtype=torch.float),
+                torch.randint(0, 1, (2,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+        self.img_data_1d = [[torch.rand(2, 3, 10, dtype=torch.float)] for _ in range(2)]
+        self.img_data_2d = [
+            [torch.rand(1, 3, 10, 10, dtype=torch.float)] for _ in range(2)
+        ]
+        self.img_data_3d = [
+            [torch.rand(1, 3, 5, 5, 5, dtype=torch.float)] for _ in range(2)
+        ]
+        self.img_data_1d_train = [
+            [
+                torch.rand(2, 3, 10, dtype=torch.float),
+                torch.randint(0, 1, (1,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+        self.img_data_2d_train = [
+            [
+                torch.rand(1, 3, 10, 10, dtype=torch.float),
+                torch.randint(0, 1, (1,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]
+        self.img_data_3d_train = [
+            [
+                torch.rand(1, 3, 5, 5, 5, dtype=torch.float),
+                torch.randint(0, 1, (1,), dtype=torch.long),
+            ]
+            for _ in range(2)
+        ]

-        self.img_data_dict = {1 : self.img_data_1d,
-                              2 : self.img_data_2d,
-                              3 : self.img_data_3d}
+        self.img_data_dict = {
+            1: self.img_data_1d,
+            2: self.img_data_2d,
+            3: self.img_data_3d,
+        }

         # Quant types that produce statically quantized ops
         self.static_quant_types = [QuantType.STATIC, QuantType.QAT]
@@ -594,75 +696,92 @@ class QuantizationTestCase(TestCase):

     def checkNoPrepModules(self, module):
         r"""Checks the module does not contain child
-            modules for quantization preparation, e.g.
-            quant, dequant and observer
+        modules for quantization preparation, e.g.
+        quant, dequant and observer
         """
-        self.assertFalse(hasattr(module, 'quant'))
-        self.assertFalse(hasattr(module, 'dequant'))
+        self.assertFalse(hasattr(module, "quant"))
+        self.assertFalse(hasattr(module, "dequant"))

     def checkNoQconfig(self, module):
-        r"""Checks the module does not contain qconfig
-        """
-        self.assertFalse(hasattr(module, 'qconfig'))
+        r"""Checks the module does not contain qconfig"""
+        self.assertFalse(hasattr(module, "qconfig"))

         for child in module.children():
             self.checkNoQconfig(child)

     def checkHasPrepModules(self, module):
         r"""Checks the module contains child
-            modules for quantization preparation, e.g.
-            quant, dequant and observer
+        modules for quantization preparation, e.g.
+        quant, dequant and observer
         """
-        self.assertTrue(hasattr(module, 'module'))
-        self.assertTrue(hasattr(module, 'quant'))
-        self.assertTrue(hasattr(module, 'dequant'))
+        self.assertTrue(hasattr(module, "module"))
+        self.assertTrue(hasattr(module, "quant"))
+        self.assertTrue(hasattr(module, "dequant"))

-    def checkObservers(self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None):
+    def checkObservers(
+        self, module, propagate_qconfig_list=None, prepare_custom_config_dict=None
+    ):
         r"""Checks the module or module's leaf descendants
-            have observers in preparation for quantization
+        have observers in preparation for quantization
         """
         if propagate_qconfig_list is None:
             propagate_qconfig_list = get_default_qconfig_propagation_list()
         if prepare_custom_config_dict is None:
             prepare_custom_config_dict = {}
-        float_to_observed_module_class_mapping = prepare_custom_config_dict.get("float_to_observed_custom_module_class", {})
+        float_to_observed_module_class_mapping = prepare_custom_config_dict.get(
+            "float_to_observed_custom_module_class", {}
+        )

         # check if a module is a leaf module, ignoring activation_post_process attribute
         def is_leaf_module(module):
             submodule_name_count = 0
             for name, _ in module.named_children():
-                if name != 'activation_post_process':
+                if name != "activation_post_process":
                     submodule_name_count += 1
             return submodule_name_count == 0

-        if hasattr(module, 'qconfig') and module.qconfig is not None and \
-           ((is_leaf_module(module) and not isinstance(module, torch.nn.Sequential)
-            and type(module) in propagate_qconfig_list) or
-           type(module) in float_to_observed_module_class_mapping.keys()) and \
-           not isinstance(module, torch.ao.quantization.DeQuantStub):
-            self.assertTrue(hasattr(module, 'activation_post_process'),
-                            'module: ' + str(type(module)) + ' do not have observer')
+        if (
+            hasattr(module, "qconfig")
+            and module.qconfig is not None
+            and (
+                (
+                    is_leaf_module(module)
+                    and not isinstance(module, torch.nn.Sequential)
+                    and type(module) in propagate_qconfig_list
+                )
+                or type(module) in float_to_observed_module_class_mapping.keys()
+            )
+            and not isinstance(module, torch.ao.quantization.DeQuantStub)
+        ):
+            self.assertTrue(
+                hasattr(module, "activation_post_process"),
+                "module: " + str(type(module)) + " do not have observer",
+            )
         # we don't need to check observers for child modules of the
         # qat modules
-        if type(module) not in get_default_qat_module_mappings().values() and \
-           type(module) not in float_to_observed_module_class_mapping.values() and \
-           not isinstance(module, _FusedModule):
+        if (
+            type(module) not in get_default_qat_module_mappings().values()
+            and type(module) not in float_to_observed_module_class_mapping.values()
+            and not isinstance(module, _FusedModule)
+        ):
             for child in module.children():
                 if type(child) in [nn.Dropout]:
                     continue
-                self.checkObservers(child, propagate_qconfig_list, prepare_custom_config_dict)
+                self.checkObservers(
+                    child, propagate_qconfig_list, prepare_custom_config_dict
+                )

     def checkQuantDequant(self, mod):
         r"""Checks that mod has nn.Quantize and
-            nn.DeQuantize submodules inserted
+        nn.DeQuantize submodules inserted
         """
         self.assertEqual(type(mod.quant), nnq.Quantize)
         self.assertEqual(type(mod.dequant), nnq.DeQuantize)

     def checkWrappedQuantizedLinear(self, mod):
         r"""Checks that mod has been swapped for an nnq.Linear
-            module, the bias is qint32, and that the module
-            has Quantize and DeQuantize submodules
+        module, the bias is qint32, and that the module
+        has Quantize and DeQuantize submodules
         """
         self.assertEqual(type(mod.module), nnq.Linear)
         self.checkQuantDequant(mod)
@@ -672,14 +791,14 @@ class QuantizationTestCase(TestCase):

     def checkDynamicQuantizedLinear(self, mod, dtype):
         r"""Checks that mod has been swapped for an nnqd.Linear
-            module, the bias is float.
+        module, the bias is float.
         """
         self.assertEqual(type(mod), nnqd.Linear)
         self.assertEqual(mod._packed_params.dtype, dtype)

     def checkDynamicQuantizedLinearRelu(self, mod, dtype):
         r"""Checks that mod has been swapped for an nnqd.Linear
-            module, the bias is float.
+        module, the bias is float.
         """
         self.assertEqual(type(mod), nniqd.LinearReLU)
         self.assertEqual(mod._packed_params.dtype, dtype)
@@ -721,25 +840,35 @@ class QuantizationTestCase(TestCase):

     def checkDynamicQuantizedLSTM(self, mod, reference_module_type, dtype):
         r"""Checks that mod has been swapped for an nnqd.LSTM type
-            module, the bias is float.
+        module, the bias is float.
         """
-        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
+        wt_dtype_map = {
+            torch.qint8: "quantized_dynamic",
+            torch.float16: "quantized_fp16",
+        }
         self.assertEqual(type(mod), reference_module_type)
         for packed_params in mod._all_weight_values:
-            self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])
+            self.assertEqual(
+                packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]
+            )

     def checkLinear(self, mod):
         self.assertEqual(type(mod), torch.nn.Linear)

     def checkDynamicQuantizedModule(self, mod, reference_module_type, dtype):
         r"""Checks that mod has been swapped for an nnqd.Linear
-            module, the bias is float.
+        module, the bias is float.
         """
-        wt_dtype_map = {torch.qint8: 'quantized_dynamic', torch.float16: 'quantized_fp16'}
+        wt_dtype_map = {
+            torch.qint8: "quantized_dynamic",
+            torch.float16: "quantized_fp16",
+        }
         self.assertEqual(type(mod), reference_module_type)
-        if hasattr(mod, '_all_weight_values'):
+        if hasattr(mod, "_all_weight_values"):
             for packed_params in mod._all_weight_values:
-                self.assertEqual(packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype])
+                self.assertEqual(
+                    packed_params.param.__getstate__()[0][0], wt_dtype_map[dtype]
+                )

     def checkScriptable(self, orig_mod, calib_data, check_save_load=False):
         scripted = torch.jit.script(orig_mod)
@@ -770,20 +899,29 @@ class QuantizationTestCase(TestCase):
                 scripted_output = test_mod(*inp)
                 self.assertEqual(scripted_output, ref_output)

-
-    def checkGraphModeOp(self, module, inputs, quantized_op, tracing=False, debug=False,
-                         check=True, eval_mode=True, dynamic=False, qconfig=None):
+    def checkGraphModeOp(
+        self,
+        module,
+        inputs,
+        quantized_op,
+        tracing=False,
+        debug=False,
+        check=True,
+        eval_mode=True,
+        dynamic=False,
+        qconfig=None,
+    ):
         if debug:
-            print('Testing:', str(module))
-        qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
+            print("Testing:", str(module))
+        qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}

         if eval_mode:
             module = module.eval()
         if dynamic:
-            qconfig_dict = {'': default_dynamic_qconfig if qconfig is None else qconfig}
+            qconfig_dict = {"": default_dynamic_qconfig if qconfig is None else qconfig}
         model = get_script_module(module, tracing, inputs[0]).eval()
         if debug:
-            print('input graph:', model.graph)
+            print("input graph:", model.graph)
         models = {}
         outputs = {}
         for debug in [True, False]:
@@ -796,31 +934,37 @@ class QuantizationTestCase(TestCase):
                 # input data staying constant for comparisons
                 inputs_copy = copy.deepcopy(inputs)
                 models[debug] = quantize_jit(
-                    model, qconfig_dict, test_only_eval_fn, [inputs_copy], inplace=False,
-                    debug=debug)
+                    model,
+                    qconfig_dict,
+                    test_only_eval_fn,
+                    [inputs_copy],
+                    inplace=False,
+                    debug=debug,
+                )
                 # make sure it runs
                 outputs[debug] = models[debug](*inputs[0])

         if debug:
-            print('debug graph:', models[True].graph)
-            print('non debug graph:', models[False].graph)
+            print("debug graph:", models[True].graph)
+            print("non debug graph:", models[False].graph)

         if check:
             # debug and non-debug option should have the same numerics
             self.assertEqual(outputs[True], outputs[False])

             # non debug graph should produce quantized op
-            FileCheck().check(quantized_op) \
-                       .run(models[False].graph)
+            FileCheck().check(quantized_op).run(models[False].graph)

         return models[False]

     def checkGraphModuleNodes(
-            self, graph_module,
-            expected_node=None,
-            expected_node_occurrence=None,
-            expected_node_list=None):
-        """ Check if GraphModule contains the target node
+        self,
+        graph_module,
+        expected_node=None,
+        expected_node_occurrence=None,
+        expected_node_list=None,
+    ):
+        """Check if GraphModule contains the target node
         Args:
             graph_module: the GraphModule instance we want to check
             expected_node, expected_node_occurrence, expected_node_list:
@@ -831,9 +975,9 @@ class QuantizationTestCase(TestCase):
         modules = dict(graph_module.named_modules(remove_duplicate=False))
         for node in graph_module.graph.nodes:
             n = None
-            if node.op == 'call_function' or node.op == 'call_method':
+            if node.op == "call_function" or node.op == "call_method":
                 n = NodeSpec(node.op, node.target)
-            elif node.op == 'call_module':
+            elif node.op == "call_module":
                 n = NodeSpec(node.op, type(modules[node.target]))

             if n is not None:
@@ -844,26 +988,34 @@ class QuantizationTestCase(TestCase):
                     nodes_in_graph[n] = 1

         if expected_node is not None:
-            self.assertTrue(expected_node in nodes_in_graph, 'node:' + str(expected_node) +
-                            ' not found in the graph module')
+            self.assertTrue(
+                expected_node in nodes_in_graph,
+                "node:" + str(expected_node) + " not found in the graph module",
+            )

         if expected_node_occurrence is not None:
             for expected_node, occurrence in expected_node_occurrence.items():
                 if occurrence != 0:
                     self.assertTrue(
                         expected_node in nodes_in_graph,
-                        'Check failed for node:' + str(expected_node) +
-                        ' not found')
+                        "Check failed for node:" + str(expected_node) + " not found",
+                    )
                     self.assertTrue(
                         nodes_in_graph[expected_node] == occurrence,
-                        'Check failed for node:' + str(expected_node) +
-                        ' Expected occurrence:' + str(occurrence) +
-                        ' Found occurrence:' + str(nodes_in_graph[expected_node]))
+                        "Check failed for node:"
+                        + str(expected_node)
+                        + " Expected occurrence:"
+                        + str(occurrence)
+                        + " Found occurrence:"
+                        + str(nodes_in_graph[expected_node]),
+                    )
                 else:
                     self.assertTrue(
                         expected_node not in nodes_in_graph,
-                        'Check failed for node:' + str(expected_node) +
-                        ' expected no occurrence but found')
+                        "Check failed for node:"
+                        + str(expected_node)
+                        + " expected no occurrence but found",
+                    )

         if expected_node_list is not None:
             cur_index = 0
@@ -874,20 +1026,21 @@ class QuantizationTestCase(TestCase):
                     cur_index += 1
             self.assertTrue(
                 cur_index == len(expected_node_list),
-                "Check failed for graph:" +
-                self.printGraphModule(graph_module, print_str=False) +
-                "Expected ordered list:" +
-                str(expected_node_list))
+                "Check failed for graph:"
+                + self.printGraphModule(graph_module, print_str=False)
+                + "Expected ordered list:"
+                + str(expected_node_list),
+            )

     def printGraphModule(self, graph_module, print_str=True):
         modules = dict(graph_module.named_modules(remove_duplicate=False))
node_infos = [] for n in graph_module.graph.nodes: - node_info = ' '.join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) - if n.op == 'call_module': - node_info += ' module type: ' + repr(type(modules[n.target])) + node_info = " ".join(map(repr, [n.op, n.name, n.target, n.args, n.kwargs])) + if n.op == "call_module": + node_info += " module type: " + repr(type(modules[n.target])) node_infos.append(node_info) - str_to_print = '\n'.join(node_infos) + str_to_print = "\n".join(node_infos) if print_str: print(str_to_print) return str_to_print @@ -897,7 +1050,9 @@ class QuantizationTestCase(TestCase): def assert_types_for_matched_subgraph_pairs( self, matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]], - expected_types: dict[str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]], + expected_types: dict[ + str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]] + ], gm_a: GraphModule, gm_b: GraphModule, ) -> None: @@ -917,16 +1072,16 @@ class QuantizationTestCase(TestCase): def _get_underlying_op_type( node: Node, gm: GraphModule ) -> Union[Callable, str]: - if node.op == 'call_module': + if node.op == "call_module": mod = getattr(gm, node.target) return type(mod) else: - assert node.op in ('call_function', 'call_method') + assert node.op in ("call_function", "call_method") return node.target self.assertTrue( len(matched_subgraph_pairs) == len(expected_types), - f'Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}' + f"Expected length of results to match, but got {len(matched_subgraph_pairs)} and {len(expected_types)}", ) for k, v in expected_types.items(): expected_types_a, expected_types_b = v @@ -938,14 +1093,16 @@ class QuantizationTestCase(TestCase): act_type_start_b = _get_underlying_op_type(subgraph_b.start_node, gm_b) act_type_end_a = _get_underlying_op_type(subgraph_a.end_node, gm_a) act_type_end_b = _get_underlying_op_type(subgraph_b.end_node, gm_b) - types_match = (exp_type_start_a is act_type_start_a) and \ - (exp_type_end_a is act_type_end_a) and \ - (exp_type_start_b is act_type_start_b) and \ - (exp_type_end_b is act_type_end_b) + types_match = ( + (exp_type_start_a is act_type_start_a) + and (exp_type_end_a is act_type_end_a) + and (exp_type_start_b is act_type_start_b) + and (exp_type_end_b is act_type_end_b) + ) self.assertTrue( types_match, - f'Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, ' - f'got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}' + f"Type mismatch at {k}: expected {(exp_type_start_a, exp_type_end_a, exp_type_start_b, exp_type_end_b)}, " + f"got {(act_type_start_a, act_type_end_a, act_type_start_b, act_type_end_b)}", ) def assert_ns_compare_dict_valid( @@ -962,48 +1119,53 @@ class QuantizationTestCase(TestCase): for result_type, layer_data in result_type_to_data.items(): self.assertTrue( len(layer_data) == 2, - f"Layer {layer_name} does not have exactly two model results.") + f"Layer {layer_name} does not have exactly two model results.", + ) model_name_0, model_name_1 = layer_data.keys() for res_idx in range(len(layer_data[model_name_0])): layer_data_0 = layer_data[model_name_0][res_idx] layer_data_1 = layer_data[model_name_1][res_idx] self.assertTrue( - layer_data_0['type'] == layer_data_0['type'], - f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same type.") + layer_data_0["type"] == layer_data_0["type"], + f"Layer {layer_name}, {model_name_0} and {model_name_1} do 
not have the same type.", + ) self.assertTrue( - len(layer_data_0['values']) == - len(layer_data_1['values']), - f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.") + len(layer_data_0["values"]) == len(layer_data_1["values"]), + f"Layer {layer_name}, {model_name_0} and {model_name_1} do not have the same number of seen Tensors.", + ) # F.conv1d weight has rank 3, and toq.conv1d unpacked weight # has rank 4. For now, skip the length check for conv1d only. is_weight_functional_conv1d = ( - result_type == NSSingleResultValuesType.WEIGHT.value and - ( - 'conv1d' in layer_data_0['prev_node_target_type'] or - 'conv1d' in layer_data_1['prev_node_target_type'] + result_type == NSSingleResultValuesType.WEIGHT.value + and ( + "conv1d" in layer_data_0["prev_node_target_type"] + or "conv1d" in layer_data_1["prev_node_target_type"] ) ) if not is_weight_functional_conv1d: - for idx in range(len(layer_data_0['values'])): - values_0 = layer_data_0['values'][idx] - values_1 = layer_data_1['values'][idx] + for idx in range(len(layer_data_0["values"])): + values_0 = layer_data_0["values"][idx] + values_1 = layer_data_1["values"][idx] if isinstance(values_0, torch.Tensor): self.assertTrue( values_0.shape == values_1.shape, - f"Layer {layer_name}, {model_name_0} and {model_name_1} " + - f"have a shape mismatch at idx {idx}.") + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) elif isinstance(values_0, list): values_0 = values_0[0] values_1 = values_1[0] self.assertTrue( values_0.shape == values_1.shape, - f"Layer {layer_name}, {model_name_0} and {model_name_1} " + - f"have a shape mismatch at idx {idx}.") + f"Layer {layer_name}, {model_name_0} and {model_name_1} " + + f"have a shape mismatch at idx {idx}.", + ) else: - assert isinstance(values_0, tuple), \ - f"unhandled type {type(values_0)}" + assert isinstance( + values_0, tuple + ), f"unhandled type {type(values_0)}" assert len(values_0) == 2 assert len(values_0[1]) == 2 assert values_0[0].shape == values_1[0].shape @@ -1011,80 +1173,91 @@ class QuantizationTestCase(TestCase): assert values_0[1][1].shape == values_1[1][1].shape # verify that ref_node_name is valid - ref_node_name_0 = layer_data_0['ref_node_name'] - ref_node_name_1 = layer_data_1['ref_node_name'] - prev_node_name_0 = layer_data_0['prev_node_name'] - prev_node_name_1 = layer_data_1['prev_node_name'] - if layer_data_0['type'] == NSSingleResultValuesType.NODE_OUTPUT.value: + ref_node_name_0 = layer_data_0["ref_node_name"] + ref_node_name_1 = layer_data_1["ref_node_name"] + prev_node_name_0 = layer_data_0["prev_node_name"] + prev_node_name_1 = layer_data_1["prev_node_name"] + if ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_OUTPUT.value + ): self.assertTrue(ref_node_name_0 == prev_node_name_0) self.assertTrue(ref_node_name_1 == prev_node_name_1) - elif layer_data_0['type'] == NSSingleResultValuesType.NODE_INPUT.value: + elif ( + layer_data_0["type"] + == NSSingleResultValuesType.NODE_INPUT.value + ): self.assertTrue(ref_node_name_0 != prev_node_name_0) self.assertTrue(ref_node_name_1 != prev_node_name_1) def checkGraphModeFxOp( - self, - model, - inputs, - quant_type, - expected_node=None, - expected_node_occurrence=None, - expected_node_list=None, - is_reference=False, - print_debug_info=False, - custom_qconfig_dict=None, - prepare_expected_node=None, - prepare_expected_node_occurrence=None, - prepare_expected_node_list=None, - prepare_custom_config=None, - 
backend_config=None): - """ Quantizes model with graph mode quantization on fx and check if the - quantized model contains the quantized_node + self, + model, + inputs, + quant_type, + expected_node=None, + expected_node_occurrence=None, + expected_node_list=None, + is_reference=False, + print_debug_info=False, + custom_qconfig_dict=None, + prepare_expected_node=None, + prepare_expected_node_occurrence=None, + prepare_expected_node_list=None, + prepare_custom_config=None, + backend_config=None, + ): + """Quantizes model with graph mode quantization on fx and checks if the + quantized model contains the expected quantized_node - Args: - model: floating point torch.nn.Module - inputs: one positional sample input arguments for model - expected_node: NodeSpec - e.g. NodeSpec.call_function(torch.quantize_per_tensor) - expected_node_occurrence: a dict from NodeSpec to - expected number of occurrences (int) - e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, - NodeSpec.call_method('dequantize'): 1} - expected_node_list: a list of NodeSpec, used to check the order - of the occurrence of Node - e.g. [NodeSpec.call_function(torch.quantize_per_tensor), - NodeSpec.call_module(nnq.Conv2d), - NodeSpec.call_function(F.hardtanh_), - NodeSpec.call_method('dequantize')] - is_reference: if True, enables reference mode - print_debug_info: if True, prints debug info - custom_qconfig_dict: overrides default qconfig_dict - prepare_expected_node: same as expected_node, but for prepare - prepare_expected_node_occurrence: same as - expected_node_occurrence, but for prepare - prepare_expected_node_list: same as expected_node_list, but - for prepare + Args: + model: floating point torch.nn.Module + inputs: a tuple of positional sample inputs for the model + expected_node: NodeSpec + e.g. NodeSpec.call_function(torch.quantize_per_tensor) + expected_node_occurrence: a dict from NodeSpec to + expected number of occurrences (int) + e.g. {NodeSpec.call_function(torch.quantize_per_tensor) : 1, + NodeSpec.call_method('dequantize'): 1} + expected_node_list: a list of NodeSpec, used to check the order + of the occurrence of Node + e.g. 
[NodeSpec.call_function(torch.quantize_per_tensor), + NodeSpec.call_module(nnq.Conv2d), + NodeSpec.call_function(F.hardtanh_), + NodeSpec.call_method('dequantize')] + is_reference: if True, enables reference mode + print_debug_info: if True, prints debug info + custom_qconfig_dict: overrides default qconfig_dict + prepare_expected_node: same as expected_node, but for prepare + prepare_expected_node_occurrence: same as + expected_node_occurrence, but for prepare + prepare_expected_node_list: same as expected_node_list, but + for prepare - Returns: - A dictionary with the following structure: - { - "prepared": ..., # the prepared model - "quantized": ..., # the quantized non-reference model - "quantized_reference": ..., # the quantized reference model - "result": ..., # the result for either quantized or - # quantized_reference model depending on the - # is_reference argument - } + Returns: + A dictionary with the following structure: + { + "prepared": ..., # the prepared model + "quantized": ..., # the quantized non-reference model + "quantized_reference": ..., # the quantized reference model + "result": ..., # the result for either quantized or + # quantized_reference model depending on the + # is_reference argument + } """ # TODO: make img_data a single example instead of a list if type(inputs) == list: inputs = inputs[0] if quant_type == QuantType.QAT: - qconfig_mapping = get_default_qat_qconfig_mapping(torch.backends.quantized.engine) + qconfig_mapping = get_default_qat_qconfig_mapping( + torch.backends.quantized.engine + ) model.train() elif quant_type == QuantType.STATIC: - qconfig_mapping = get_default_qconfig_mapping(torch.backends.quantized.engine) + qconfig_mapping = get_default_qconfig_mapping( + torch.backends.quantized.engine + ) model.eval() else: qconfig = default_dynamic_qconfig @@ -1098,30 +1271,37 @@ class QuantizationTestCase(TestCase): # overwrite qconfig_dict with custom_qconfig_dict if custom_qconfig_dict is not None: - assert type(custom_qconfig_dict) in (QConfigMapping, dict), \ - 'custom_qconfig_dict should be a QConfigMapping or a dict' + assert type(custom_qconfig_dict) in ( + QConfigMapping, + dict, + ), "custom_qconfig_dict should be a QConfigMapping or a dict" if isinstance(custom_qconfig_dict, QConfigMapping): qconfig_mapping = custom_qconfig_dict else: qconfig_mapping = QConfigMapping.from_dict(custom_qconfig_dict) prepared = prepare( - model, qconfig_mapping, + model, + qconfig_mapping, example_inputs=inputs, prepare_custom_config=prepare_custom_config, - backend_config=backend_config) + backend_config=backend_config, + ) if not quant_type == QuantType.DYNAMIC: prepared(*inputs) if print_debug_info: print() - print('quant type:\n', quant_type) - print('original model:\n', model) + print("quant type:\n", quant_type) + print("original model:\n", model) print() - print('prepared model:\n', prepared) + print("prepared model:\n", prepared) self.checkGraphModuleNodes( - prepared, prepare_expected_node, - prepare_expected_node_occurrence, prepare_expected_node_list) + prepared, + prepare_expected_node, + prepare_expected_node_occurrence, + prepare_expected_node_list, + ) prepared_copy = copy.deepcopy(prepared) qgraph = convert_fx(copy.deepcopy(prepared)) @@ -1134,20 +1314,34 @@ class QuantizationTestCase(TestCase): qgraph_to_check = qgraph_reference if is_reference else qgraph if print_debug_info: print() - print('quantized model:\n', qgraph_to_check) + print("quantized model:\n", qgraph_to_check) self.printGraphModule(qgraph_to_check) print() 
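        # Editor's sketch -- a minimal, hypothetical call of this helper,
        # reusing the NodeSpec examples from the docstring above (`ns` and
        # `nnq` are the aliases already used elsewhere in this file):
        #
        #     results = self.checkGraphModeFxOp(
        #         ConvModel(),
        #         ConvModel().get_example_inputs(),
        #         QuantType.STATIC,
        #         expected_node_list=[
        #             ns.call_function(torch.quantize_per_tensor),
        #             ns.call_module(nnq.Conv2d),
        #             ns.call_method("dequantize"),
        #         ],
        #     )
        #     quantized = results["quantized"]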
self.checkGraphModuleNodes( - qgraph_to_check, expected_node, expected_node_occurrence, expected_node_list) - return {"prepared": prepared_copy, - "quantized": qgraph_copy, - "quantized_reference": qgraph_reference_copy, - "quantized_output": result, - "quantized_reference_output": result_reference} + qgraph_to_check, + expected_node, + expected_node_occurrence, + expected_node_list, + ) + return { + "prepared": prepared_copy, + "quantized": qgraph_copy, + "quantized_reference": qgraph_reference_copy, + "quantized_output": result, + "quantized_reference_output": result_reference, + } - - def checkEmbeddingSerialization(self, qemb, num_embeddings, embedding_dim, indices, offsets, - set_qconfig, is_emb_bag, dtype=torch.quint8): + def checkEmbeddingSerialization( + self, + qemb, + num_embeddings, + embedding_dim, + indices, + offsets, + set_qconfig, + is_emb_bag, + dtype=torch.quint8, + ): # Test serialization of dynamic EmbeddingBag module using state_dict if is_emb_bag: inputs = [indices, offsets] @@ -1169,33 +1363,49 @@ class QuantizationTestCase(TestCase): # Check state dict serialization and torch.save APIs if is_emb_bag: - loaded_qemb = nnq.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, mode='sum', dtype=dtype) + loaded_qemb = nnq.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + mode="sum", + dtype=dtype, + ) else: - loaded_qemb = nnq.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype) + loaded_qemb = nnq.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, dtype=dtype + ) self.check_eager_serialization(qemb, loaded_qemb, inputs) loaded_qemb.load_state_dict(loaded_dict) - self.assertEqual(embedding_unpack(qemb._packed_params._packed_weight), - embedding_unpack(loaded_qemb._packed_params._packed_weight)) - + self.assertEqual( + embedding_unpack(qemb._packed_params._packed_weight), + embedding_unpack(loaded_qemb._packed_params._packed_weight), + ) # Test JIT serialization self.checkScriptable(qemb, [inputs], check_save_load=True) # Test from_float call if is_emb_bag: - float_embedding = torch.nn.EmbeddingBag(num_embeddings=num_embeddings, embedding_dim=embedding_dim, - include_last_offset=True, scale_grad_by_freq=False, mode='sum') + float_embedding = torch.nn.EmbeddingBag( + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) else: - float_embedding = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim) + float_embedding = torch.nn.Embedding( + num_embeddings=num_embeddings, embedding_dim=embedding_dim + ) if set_qconfig: - float_qparams_observer = PerChannelMinMaxObserver.with_args(dtype=dtype, - qscheme=torch.per_channel_affine_float_qparams, - ch_axis=0) - float_embedding.qconfig = QConfig(activation=default_dynamic_quant_observer, - weight=float_qparams_observer) + float_qparams_observer = PerChannelMinMaxObserver.with_args( + dtype=dtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0 + ) + float_embedding.qconfig = QConfig( + activation=default_dynamic_quant_observer, weight=float_qparams_observer + ) prepare_dynamic(float_embedding) @@ -1211,6 +1421,7 @@ class QuantizationTestCase(TestCase): self.assertTrue(expected_name in str(q_embeddingbag)) + class QuantizationLiteTestCase(QuantizationTestCase): def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs): # Creates quantized 
model for testing mobile script modules @@ -1223,9 +1434,7 @@ class QuantizationLiteTestCase(QuantizationTestCase): return model - def _compare_script_and_mobile(self, - model: torch.nn.Module, - input: torch.Tensor): + def _compare_script_and_mobile(self, model: torch.nn.Module, input: torch.Tensor): # Compares the numerical outputs for script and lite modules qengine = "qnnpack" with override_quantized_engine(qengine): @@ -1236,18 +1445,28 @@ class QuantizationLiteTestCase(QuantizationTestCase): for retry in range(1, max_retry + 1): # retries `max_retry` times; breaks iff succeeds else throws exception try: - buffer = io.BytesIO(script_module._save_to_buffer_for_lite_interpreter()) + buffer = io.BytesIO( + script_module._save_to_buffer_for_lite_interpreter() + ) buffer.seek(0) mobile_module = _load_for_lite_interpreter(buffer) mobile_module_result = mobile_module(input) - torch.testing.assert_close(script_module_result, mobile_module_result) + torch.testing.assert_close( + script_module_result, mobile_module_result + ) mobile_module_forward_result = mobile_module.forward(input) - torch.testing.assert_close(script_module_result, mobile_module_forward_result) + torch.testing.assert_close( + script_module_result, mobile_module_forward_result + ) - mobile_module_run_method_result = mobile_module.run_method("forward", input) - torch.testing.assert_close(script_module_result, mobile_module_run_method_result) + mobile_module_run_method_result = mobile_module.run_method( + "forward", input + ) + torch.testing.assert_close( + script_module_result, mobile_module_run_method_result + ) except AssertionError as e: if retry == max_retry: raise e @@ -1260,6 +1479,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase): """ Base QuantizationTestCase for PT2 with some helper methods. 
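    A typical helper flow (editor's sketch, mirroring the _quantize helper
    below -- capture with export_for_training(..., strict=True), prepare,
    calibrate, then convert):

        m = export_for_training(m, example_inputs, strict=True).module()
        m = prepare_pt2e(m, quantizer)  # or prepare_qat_pt2e(m, quantizer) for QAT
        m(*example_inputs)  # calibrate
        m = convert_pt2e(m)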
""" + _MAP_TO_FX_TRACED_OPS = { torch.ops.quantized_decomposed.quantize_per_tensor: torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_tensor: torch.ops.quantized_decomposed.dequantize_per_tensor.default, @@ -1297,6 +1517,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase): m, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, ).module() if is_qat: @@ -1337,6 +1558,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase): m_fx, example_inputs, dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None, + strict=True, ).module() node_occurrence = {} for k, v in PT2EQuantizationTestCase._MAP_TO_FX_TRACED_OPS.items(): @@ -1344,7 +1566,8 @@ class PT2EQuantizationTestCase(QuantizationTestCase): node_occurrence[ns.call_function(v)] = expected_node_occurrence[k] if training_ir_node_occurrence is not None: node_occurrence = { - ns.call_function(k): v for k, v in training_ir_node_occurrence.items() + ns.call_function(k): v + for k, v in training_ir_node_occurrence.items() } self.checkGraphModuleNodes(m_fx, expected_node_occurrence=node_occurrence) fx_quant_output = m_fx(*example_inputs) @@ -1355,10 +1578,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase): # resetting dynamo cache torch._dynamo.reset() - m = export_for_training( - m, - example_inputs, - ).module() + m = export_for_training(m, example_inputs, strict=True).module() if is_qat: m = prepare_qat_pt2e(m, quantizer) else: @@ -1377,14 +1597,18 @@ class PT2EQuantizationTestCase(QuantizationTestCase): return self.linear(x) quantizer = XNNPACKQuantizer() - operator_config = get_symmetric_quantization_config(is_per_channel=is_per_channel) + operator_config = get_symmetric_quantization_config( + is_per_channel=is_per_channel + ) quantizer.set_global(operator_config) example_inputs = (torch.randn(2, 2),) m = M().eval() return self._quantize(m, quantizer, example_inputs) + # Below are a series of toy models to use in testing quantization + class SingleLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1397,8 +1621,9 @@ class SingleLayerLinearModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class AnnotatedSingleLayerLinearModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.fc1 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) @@ -1410,8 +1635,9 @@ class AnnotatedSingleLayerLinearModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class SingleLayerLinearDynamicModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) @@ -1423,6 +1649,7 @@ class SingleLayerLinearDynamicModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearAddModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -1438,38 +1665,41 @@ class LinearAddModel(nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class RNNDynamicModel(torch.nn.Module): def __init__(self, mod_type): super().__init__() self.qconfig = default_dynamic_qconfig - if mod_type == 'GRU': + if mod_type == "GRU": 
self.mod = torch.nn.GRU(2, 2).to(dtype=torch.float) - if mod_type == 'LSTM': + if mod_type == "LSTM": self.mod = torch.nn.LSTM(2, 2).to(dtype=torch.float) def forward(self, x): x = self.mod(x) return x + class RNNCellDynamicModel(torch.nn.Module): def __init__(self, mod_type): super().__init__() self.qconfig = default_dynamic_qconfig - if mod_type == 'GRUCell': + if mod_type == "GRUCell": self.mod = torch.nn.GRUCell(2, 2).to(dtype=torch.float) - if mod_type == 'LSTMCell': + if mod_type == "LSTMCell": self.mod = torch.nn.LSTMCell(2, 2).to(dtype=torch.float) - if mod_type == 'RNNReLU': - self.mod = torch.nn.RNNCell(2, 2, nonlinearity='relu').to(dtype=torch.float) - if mod_type == 'RNNTanh': - self.mod = torch.nn.RNNCell(2, 2, nonlinearity='tanh').to(dtype=torch.float) + if mod_type == "RNNReLU": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="relu").to(dtype=torch.float) + if mod_type == "RNNTanh": + self.mod = torch.nn.RNNCell(2, 2, nonlinearity="tanh").to(dtype=torch.float) def forward(self, x): x = self.mod(x) return x + class LSTMwithHiddenDynamicModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float) @@ -1478,6 +1708,7 @@ class LSTMwithHiddenDynamicModel(torch.nn.Module): x, hid = self.lstm(x, hid) return x, hid + class ConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1490,6 +1721,7 @@ class ConvModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class ConvTransposeModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1502,6 +1734,7 @@ class ConvTransposeModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -1519,6 +1752,7 @@ class AnnotatedConvModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvTransposeModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -1536,6 +1770,7 @@ class AnnotatedConvTransposeModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class ConvBnModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1550,6 +1785,7 @@ class ConvBnModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvBnModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1569,6 +1805,7 @@ class AnnotatedConvBnModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class ConvBnReLUModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1585,8 +1822,9 @@ class ConvBnReLUModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class AnnotatedConvBnReLUModel(torch.nn.Module): - def __init__(self, qengine='fbgemm'): + def __init__(self, qengine="fbgemm"): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) self.conv = torch.nn.Conv2d(3, 5, 3, bias=False).to(dtype=torch.float) @@ -1606,13 +1844,18 @@ class AnnotatedConvBnReLUModel(torch.nn.Module): def fuse_model(self): # TODO: remove this check and define two fuse_modules function on this module if 
self.training: - torch.ao.quantization.fuse_modules_qat(self, [['conv', 'bn', 'relu']], inplace=True) + torch.ao.quantization.fuse_modules_qat( + self, [["conv", "bn", "relu"]], inplace=True + ) else: - torch.ao.quantization.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True) + torch.ao.quantization.fuse_modules( + self, [["conv", "bn", "relu"]], inplace=True + ) def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class TwoLayerConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1627,6 +1870,7 @@ class TwoLayerConvModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class TwoLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1641,6 +1885,7 @@ class TwoLayerLinearModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearModelWithSubmodule(nn.Module): def __init__(self) -> None: super().__init__() @@ -1655,6 +1900,7 @@ class LinearModelWithSubmodule(nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return self.subm.get_example_inputs() + class AnnotatedTwoLayerLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1670,6 +1916,7 @@ class AnnotatedTwoLayerLinearModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class ActivationsTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1686,6 +1933,7 @@ class ActivationsTestModel(torch.nn.Module): x = self.dequant(x) return x + class LinearReluModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1716,6 +1964,7 @@ class LinearReluLinearModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearReluAddModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1734,6 +1983,7 @@ class LinearReluAddModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearBnLeakyReluModel(torch.nn.Module): def __init__(self, with_bn=True): super().__init__() @@ -1752,6 +2002,7 @@ class LinearBnLeakyReluModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class LinearTanhModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1766,13 +2017,16 @@ class LinearTanhModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class ConvBnAddReluModel(torch.nn.Module): - def __init__(self, - with_bn=True, - with_relu=True, - left_conv=True, - two_conv=True, - use_torch_add=True): + def __init__( + self, + with_bn=True, + with_relu=True, + left_conv=True, + two_conv=True, + use_torch_add=True, + ): super().__init__() self.conv = nn.Conv2d(5, 5, (2, 2)) self.conv2 = nn.Conv2d(5, 5, (2, 2)) @@ -1826,6 +2080,7 @@ class ConvBnAddReluModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5, 3, 3), torch.rand(1, 5, 2, 2)) + # TODO: self.fc should be self.conv class ConvReluModel(torch.nn.Module): def __init__(self) -> None: @@ -1840,6 +2095,7 @@ class ConvReluModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + # TODO: self.fc should be self.conv class ConvReluConvModel(torch.nn.Module): def __init__(self) -> None: @@ -1857,6 +2113,7 @@ class ConvReluConvModel(torch.nn.Module): def get_example_inputs(self) -> 
tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + # TODO: self.fc should be self.conv class ConvReluAddModel(torch.nn.Module): def __init__(self) -> None: @@ -1876,6 +2133,7 @@ class ConvReluAddModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class NormalizationTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1897,6 +2155,7 @@ class NormalizationTestModel(torch.nn.Module): x = self.instance_norm3d(x.unsqueeze(-1)) return x + class NestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1910,6 +2169,7 @@ class NestedModel(torch.nn.Module): x = self.fc3(x) return x + class AnnotatedNestedModel(torch.nn.Module): def __init__(self, qengine): super().__init__() @@ -1918,7 +2178,7 @@ class AnnotatedNestedModel(torch.nn.Module): self.fc3 = QuantWrapper(torch.nn.Linear(5, 5).to(dtype=torch.float)) self.fc3.qconfig = default_qconfig self.sub2.fc1 = QuantWrapper(self.sub2.fc1) - if qengine == 'fbgemm': + if qengine == "fbgemm": self.sub2.fc1.qconfig = default_per_channel_qconfig else: self.sub2.fc1.qconfig = default_qconfig @@ -1929,6 +2189,7 @@ class AnnotatedNestedModel(torch.nn.Module): x = self.fc3(x) return x + class AnnotatedSubNestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1944,6 +2205,7 @@ class AnnotatedSubNestedModel(torch.nn.Module): x = self.fc3(x) return x + class AnnotatedCustomConfigNestedModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1953,12 +2215,11 @@ class AnnotatedCustomConfigNestedModel(torch.nn.Module): self.fc3.qconfig = default_qconfig self.sub2.qconfig = default_qconfig - custom_options = { - 'dtype': torch.quint8, - 'qscheme': torch.per_tensor_affine - } - custom_qconfig = QConfig(activation=default_observer.with_args(**custom_options), - weight=default_weight_observer) + custom_options = {"dtype": torch.quint8, "qscheme": torch.per_tensor_affine} + custom_qconfig = QConfig( + activation=default_observer.with_args(**custom_options), + weight=default_weight_observer, + ) self.sub2.fc1.qconfig = custom_qconfig self.sub2.fc1 = QuantWrapper(self.sub2.fc1) @@ -1970,6 +2231,7 @@ class AnnotatedCustomConfigNestedModel(torch.nn.Module): x = self.fc3(x) return x + class QuantSubModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -1985,6 +2247,7 @@ class QuantSubModel(torch.nn.Module): x = self.fc3(x) return x + class InnerModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2004,14 +2267,14 @@ class InnerModule(torch.nn.Module): if idx >= len(named_children) - 1: break if isinstance(named_children[idx + 1][1], torch.nn.ReLU): - fusable_layers.append([current_name, - named_children[idx + 1][0]]) + fusable_layers.append([current_name, named_children[idx + 1][0]]) # TODO: remove this check and define two fuse_modules function on this module if self.training: torch.ao.quantization.fuse_modules_qat(self, fusable_layers, inplace=True) else: torch.ao.quantization.fuse_modules(self, fusable_layers, inplace=True) + class FunctionalLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2024,6 +2287,7 @@ class FunctionalLinear(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 5),) + class SingleLayerFunctionalLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2036,6 +2300,7 @@ class SingleLayerFunctionalLinearModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: 
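        # Editor's note: each toy model exposes get_example_inputs() so tests
        # can capture it uniformly; a hypothetical sketch of the common pattern:
        #     m = SingleLayerFunctionalLinearModel()
        #     gm = export_for_training(m, m.get_example_inputs(), strict=True).module()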
return self.linear1.get_example_inputs() + class TwoLayerFunctionalLinearModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2050,6 +2315,7 @@ class TwoLayerFunctionalLinearModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class FunctionalLinearAddModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2065,6 +2331,7 @@ class FunctionalLinearAddModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class FunctionalLinearReluModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2078,6 +2345,7 @@ class FunctionalLinearReluModel(nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear.get_example_inputs() + class FunctionalLinearReluLinearModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2094,6 +2362,7 @@ class FunctionalLinearReluLinearModel(nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return self.linear1.get_example_inputs() + class FunctionalConv2d(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2105,11 +2374,20 @@ class FunctionalConv2d(torch.nn.Module): self.groups = 1 def forward(self, x): - return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return F.conv2d( + x, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) def get_example_inputs(self) -> tuple[Any, ...]: return (torch.rand(1, 3, 5, 5),) + class SingleLayerFunctionalConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2122,6 +2400,7 @@ class SingleLayerFunctionalConvModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() + class TwoLayerFunctionalConvModel(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2136,6 +2415,7 @@ class TwoLayerFunctionalConvModel(torch.nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() + class FunctionalConvReluModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2149,6 +2429,7 @@ class FunctionalConvReluModel(nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv.get_example_inputs() + class FunctionalConvReluConvModel(nn.Module): def __init__(self) -> None: super().__init__() @@ -2165,10 +2446,12 @@ class FunctionalConvReluConvModel(nn.Module): def get_example_inputs(self) -> tuple[Any, ...]: return self.conv1.get_example_inputs() + class SkipQuantModel(torch.nn.Module): r"""We can skip quantization by explicitly setting qconfig of a submodule to None """ + def __init__(self) -> None: super().__init__() self.sub = InnerModule() @@ -2180,10 +2463,12 @@ class SkipQuantModel(torch.nn.Module): def fuse_modules(self): self.sub.fuse_modules() + class AnnotatedSkipQuantModel(torch.nn.Module): r"""We can skip quantization by explicitly setting qconfig of a submodule to None """ + def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qconfig(qengine) @@ -2198,9 +2483,10 @@ class AnnotatedSkipQuantModel(torch.nn.Module): def fuse_modules(self): self.sub.module.fuse_modules() + class QuantStubModel(torch.nn.Module): - r"""A Module with manually inserted `QuantStub` and `DeQuantStub` - """ + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + def __init__(self) -> None: super().__init__() self.qconfig = 
torch.ao.quantization.get_default_qconfig("qnnpack") @@ -2213,9 +2499,10 @@ class QuantStubModel(torch.nn.Module): x = self.fc(x) return self.dequant(x) + class ManualLinearQATModel(torch.nn.Module): - r"""A Module with manually inserted `QuantStub` and `DeQuantStub` - """ + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) @@ -2230,9 +2517,10 @@ class ManualLinearQATModel(torch.nn.Module): x = self.fc2(x) return self.dequant(x) + class ManualDropoutQATModel(torch.nn.Module): - r"""A Module with manually inserted `QuantStub` and `DeQuantStub` - """ + r"""A Module with manually inserted `QuantStub` and `DeQuantStub`""" + def __init__(self, qengine): super().__init__() self.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine) @@ -2247,9 +2535,10 @@ class ManualDropoutQATModel(torch.nn.Module): x = self.dropout(x) return self.dequant(x) + class ManualLinearDynamicQATModel(torch.nn.Module): - r"""A Module that uses a dynamic QAT by default. - """ + r"""A Module that uses a dynamic QAT by default.""" + def __init__(self, qconfig=None): super().__init__() self.qconfig = qconfig or default_dynamic_qat_qconfig @@ -2261,13 +2550,19 @@ class ManualLinearDynamicQATModel(torch.nn.Module): x = self.fc2(x) return x + class ManualConvLinearQATModel(torch.nn.Module): r"""A module with manually inserted `QuantStub` and `DeQuantStub` and contains both linear and conv modules """ + def __init__(self, qconfig=None): super().__init__() - self.qconfig = qconfig if qconfig else torch.ao.quantization.get_default_qat_qconfig("qnnpack") + self.qconfig = ( + qconfig + if qconfig + else torch.ao.quantization.get_default_qat_qconfig("qnnpack") + ) self.quant = QuantStub() self.dequant = DeQuantStub() self.conv = torch.nn.Conv2d(3, 1, kernel_size=3).to(dtype=torch.float) @@ -2282,30 +2577,38 @@ class ManualConvLinearQATModel(torch.nn.Module): x = self.fc2(x) return self.dequant(x) + class ManualConvLinearSymmQATModel(ManualConvLinearQATModel): r"""Same as ManualConvLinearQATModule but with Symmetric Quantization. Supported only with qnnpack. 
""" + def __init__(self) -> None: super().__init__(default_symmetric_qnnpack_qat_qconfig) + class ManualEmbeddingBagLinear(nn.Module): def __init__(self) -> None: super().__init__() - self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode='sum') + self.emb = nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, mode="sum") self.emb.qconfig = default_embedding_qat_qconfig self.quant = QuantStub() self.dequant = DeQuantStub() self.linear = nn.Linear(12, 1).to(dtype=torch.float) self.qconfig = get_default_qat_qconfig("qnnpack") - def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, - per_sample_weights: Optional[torch.Tensor] = None): + def forward( + self, + input: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + per_sample_weights: Optional[torch.Tensor] = None, + ): x = self.emb(input, offsets, per_sample_weights) x = self.quant(x) x = self.linear(x) return self.dequant(x) + class DeFusedEmbeddingBagLinear(nn.Module): r"""A module to simulate QAT embedding bag with a linear layer, this module uses a separate embedding and bagging op, similar @@ -2313,6 +2616,7 @@ class DeFusedEmbeddingBagLinear(nn.Module): https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html """ + def __init__(self) -> None: super().__init__() self.emb = nn.Embedding(num_embeddings=10, embedding_dim=12) @@ -2329,6 +2633,7 @@ class DeFusedEmbeddingBagLinear(nn.Module): x = self.linear(x) return self.dequant(x) + class SubModelForFusion(nn.Module): def __init__(self) -> None: super().__init__() @@ -2350,6 +2655,7 @@ class SubModelWithoutFusion(nn.Module): def forward(self, x): return self.relu(self.conv(x)) + class ModelForFusion(nn.Module): def __init__(self, qconfig): super().__init__() @@ -2396,14 +2702,14 @@ class ModelForFusion(nn.Module): y = self.dequant(y) return x + class ConvBNReLU(nn.Sequential): def __init__(self) -> None: super().__init__( - nn.Conv2d(3, 3, 1, 1, bias=False), - nn.BatchNorm2d(3), - nn.ReLU(inplace=False) + nn.Conv2d(3, 3, 1, 1, bias=False), nn.BatchNorm2d(3), nn.ReLU(inplace=False) ) + class ModelWithSequentialFusion(nn.Module): def __init__(self) -> None: super().__init__() @@ -2428,6 +2734,7 @@ class ModelWithSequentialFusion(nn.Module): x = self.dequant(x) return x + class ModelForFusionWithBias(nn.Module): def __init__(self) -> None: super().__init__() @@ -2449,6 +2756,7 @@ class ModelForFusionWithBias(nn.Module): x = self.dequant(x) return x + class ModelForLinearBNFusion(nn.Module): def __init__(self) -> None: super().__init__() @@ -2460,6 +2768,7 @@ class ModelForLinearBNFusion(nn.Module): def forward(self, x): return self.bn(self.fc(x)) + class DummyObserver(torch.nn.Module): def calculate_qparams(self): return 1.0, 0 @@ -2543,9 +2852,14 @@ class ResNetBase(torch.nn.Module): def fuse_model(self): # TODO: remove this check and define two fuse_model function on this module if self.training: - torch.ao.quantization.fuse_modules_qat(self, [['conv1', 'bn1', 'relu1']], inplace=True) + torch.ao.quantization.fuse_modules_qat( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) else: - torch.ao.quantization.fuse_modules(self, [['conv1', 'bn1', 'relu1']], inplace=True) + torch.ao.quantization.fuse_modules( + self, [["conv1", "bn1", "relu1"]], inplace=True + ) + class ModelMultipleOps(torch.nn.Module): def __init__(self) -> None: @@ -2578,6 +2892,7 @@ class ModelMultipleOps(torch.nn.Module): out = self.fc(out) return out + # Model to ensure consistency of fake quant with true quant # Average pooling and mean operations are not 
modelled # accurately with fake-quant so this model does not @@ -2612,15 +2927,22 @@ class ModelMultipleOpsNoAvgPool(torch.nn.Module): out = self.fc(out) return out + class EmbeddingBagModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.emb = torch.nn.EmbeddingBag(num_embeddings=10, embedding_dim=12, - include_last_offset=True, scale_grad_by_freq=False, mode='sum') + self.emb = torch.nn.EmbeddingBag( + num_embeddings=10, + embedding_dim=12, + include_last_offset=True, + scale_grad_by_freq=False, + mode="sum", + ) def forward(self, indices, offsets, per_sample_weights): return self.emb(indices, offsets, per_sample_weights) + class EmbeddingModule(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2629,6 +2951,7 @@ class EmbeddingModule(torch.nn.Module): def forward(self, indices): return self.emb(indices) + class EmbeddingWithStaticLinear(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -2647,9 +2970,11 @@ class EmbeddingWithStaticLinear(torch.nn.Module): features = torch.cat([fc] + [emb], dim=1) return features -class DenseTopMLP(nn.Module): - def __init__(self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out) -> None: +class DenseTopMLP(nn.Module): + def __init__( + self, dense_dim, dense_out, embedding_dim, top_out_in, top_out_out + ) -> None: super().__init__() self.dense_mlp = nn.Sequential( @@ -2671,16 +2996,18 @@ class DenseTopMLP(nn.Module): out = self.top_mlp(features) return out + # thin wrapper around embedding bag, because tracing inside nn.Embedding # bag is not supported at the moment and this is top level class EmbBagWrapper(nn.Module): def __init__(self, num_embeddings, embedding_dim): super().__init__() - self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode='sum') + self.emb_bag = nn.EmbeddingBag(num_embeddings, embedding_dim, mode="sum") def forward(self, indices, offsets): return self.emb_bag(indices, offsets) + class SparseNNModel(nn.Module): _NUM_EMBEDDINGS = 10 _EMBEDDING_DIM = 5 @@ -2695,8 +3022,12 @@ class SparseNNModel(nn.Module): self.model_sparse = EmbBagWrapper(self._NUM_EMBEDDINGS, self._EMBEDDING_DIM) self.dense_top = DenseTopMLP( - self._DENSE_DIM, self._DENSE_OUTPUT, self._EMBEDDING_DIM, self._TOP_OUT_IN, - self._TOP_OUT_OUT) + self._DENSE_DIM, + self._DENSE_OUTPUT, + self._EMBEDDING_DIM, + self._TOP_OUT_IN, + self._TOP_OUT_OUT, + ) def forward( self, @@ -2704,12 +3035,12 @@ class SparseNNModel(nn.Module): sparse_offsets: torch.Tensor, dense: torch.Tensor, ) -> torch.Tensor: - sparse_feature = self.model_sparse(sparse_indices, sparse_offsets) out = self.dense_top(sparse_feature, dense) return out + class TestHelperModules: class ControlFlow(torch.nn.Module): def forward( @@ -2719,7 +3050,6 @@ class TestHelperModules: pred2: torch.Tensor, y: torch.Tensor, ) -> torch.Tensor: - def true_nested(y: torch.Tensor) -> torch.Tensor: y = y + y y = torch.mm(y, y) @@ -2736,7 +3066,10 @@ class TestHelperModules: return x.cos() def map_fn( - x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor + x: torch.Tensor, + pred1: torch.Tensor, + pred2: torch.Tensor, + y: torch.Tensor, ) -> torch.Tensor: x = x.cos() y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2]) @@ -2747,7 +3080,12 @@ class TestHelperModules: return control_flow.map(map_fn, xs, pred1, pred2, y) def example_inputs(self): - return (torch.ones(2, 2), torch.tensor([False]), torch.tensor([False]), torch.ones(2, 2),) + return ( + torch.ones(2, 2), + torch.tensor([False]), + torch.tensor([False]), + 
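            # Editor's note: as exercised in forward above,
            # control_flow.cond(pred, true_fn, false_fn, operands) selects one
            # of two traced branches at runtime, and
            # control_flow.map(fn, xs, *args) applies fn across the leading
            # dimension of xs.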
torch.ones(2, 2), + ) class Conv2dPropAnnotaton(torch.nn.Module): def __init__(self) -> None: @@ -3029,16 +3367,20 @@ class TestHelperModules: x = self.relu(self.fc(x)) return x + def _generate_qdq_quantized_model( mod, inputs, is_qat=False, is_dynamic=False, quantizer=None ): - def get_default_quantizer(is_qat, is_dynamic, inputs): - has_xpu = any(isinstance(input, torch.Tensor) and input.device.type == "xpu" - for input in inputs) + has_xpu = any( + isinstance(input, torch.Tensor) and input.device.type == "xpu" + for input in inputs + ) if has_xpu: quantizer = XPUInductorQuantizer() - assert (not is_qat) and (not is_dynamic), "QAT and dynamic quantization is not supported at XPU backend currently" + assert (not is_qat) and ( + not is_dynamic + ), "QAT and dynamic quantization is not supported at XPU backend currently" quantizer.set_global(xpuiq.get_default_xpu_inductor_quantization_config()) else: quantizer = X86InductorQuantizer() @@ -3051,12 +3393,11 @@ def _generate_qdq_quantized_model( maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad() with maybe_no_grad: - export_model = export_for_training( - mod, - inputs, - ).module() + export_model = export_for_training(mod, inputs, strict=True).module() quantizer = ( - quantizer if quantizer else get_default_quantizer(is_qat, is_dynamic, inputs) + quantizer + if quantizer + else get_default_quantizer(is_qat, is_dynamic, inputs) ) prepare_model = ( prepare_qat_pt2e(export_model, quantizer)