mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add support for list, tuple and dict in numeric debugger (#143882)
Summary: Previously numeric debugger only supports torch.Tensor, this PR adds support for list, tuple and dict as well Test Plan: python test/test_quantization.py -k test_extract_results_from_loggers_list_output Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D67660049](https://our.internmc.facebook.com/intern/diff/D67660049) Pull Request resolved: https://github.com/pytorch/pytorch/pull/143882 Approved by: https://github.com/dulinriley
This commit is contained in:
committed by
PyTorch MergeBot
parent
c3c27aef34
commit
ad78edee8e
@ -267,6 +267,36 @@ class TestNumericDebugger(TestCase):
|
||||
if len(node_summary.results) > 0:
|
||||
self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
|
||||
|
||||
def test_extract_results_from_loggers_list_output(self):
|
||||
m = TestHelperModules.Conv2dWithSplit()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = export_for_training(m, example_inputs)
|
||||
generate_numeric_debug_handle(ep)
|
||||
m = ep.module()
|
||||
m_ref_logger = prepare_for_propagation_comparison(m)
|
||||
|
||||
quantizer = XNNPACKQuantizer().set_global(
|
||||
get_symmetric_quantization_config(is_per_channel=False)
|
||||
)
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
m(*example_inputs)
|
||||
m = convert_pt2e(m)
|
||||
m_quant_logger = prepare_for_propagation_comparison(m)
|
||||
|
||||
m_ref_logger(*example_inputs)
|
||||
m_quant_logger(*example_inputs)
|
||||
ref_results = extract_results_from_loggers(m_ref_logger)
|
||||
quant_results = extract_results_from_loggers(m_quant_logger)
|
||||
comparison_results = compare_results(ref_results, quant_results)
|
||||
for node_summary in comparison_results.values():
|
||||
if len(node_summary.results) > 0:
|
||||
sqnr = node_summary.results[0].sqnr
|
||||
if isinstance(sqnr, list):
|
||||
for sqnr_i in sqnr:
|
||||
self.assertGreaterEqual(sqnr_i, 35)
|
||||
else:
|
||||
self.assertGreaterEqual(sqnr, 35)
|
||||
|
||||
def test_added_node_gets_unique_id(self) -> None:
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
|
Reference in New Issue
Block a user