Add support for list, tuple and dict in numeric debugger (#143882)

Summary:
Previously numeric debugger only supports torch.Tensor, this PR adds support for list, tuple and dict as well

Test Plan:
python test/test_quantization.py -k test_extract_results_from_loggers_list_output

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D67660049](https://our.internmc.facebook.com/intern/diff/D67660049)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143882
Approved by: https://github.com/dulinriley
This commit is contained in:
Jerry Zhang
2024-12-27 10:12:34 -08:00
committed by PyTorch MergeBot
parent c3c27aef34
commit ad78edee8e
3 changed files with 118 additions and 29 deletions

View File

@ -267,6 +267,36 @@ class TestNumericDebugger(TestCase):
if len(node_summary.results) > 0:
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
def test_extract_results_from_loggers_list_output(self):
m = TestHelperModules.Conv2dWithSplit()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs)
generate_numeric_debug_handle(ep)
m = ep.module()
m_ref_logger = prepare_for_propagation_comparison(m)
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_per_channel=False)
)
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
m_quant_logger = prepare_for_propagation_comparison(m)
m_ref_logger(*example_inputs)
m_quant_logger(*example_inputs)
ref_results = extract_results_from_loggers(m_ref_logger)
quant_results = extract_results_from_loggers(m_quant_logger)
comparison_results = compare_results(ref_results, quant_results)
for node_summary in comparison_results.values():
if len(node_summary.results) > 0:
sqnr = node_summary.results[0].sqnr
if isinstance(sqnr, list):
for sqnr_i in sqnr:
self.assertGreaterEqual(sqnr_i, 35)
else:
self.assertGreaterEqual(sqnr, 35)
def test_added_node_gets_unique_id(self) -> None:
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()