Summary: Previously the numeric debugger only supported torch.Tensor outputs; this PR adds support for list, tuple and dict outputs as well.

Test Plan: python test/test_quantization.py -k test_extract_results_from_loggers_list_output

Differential Revision: [D67660049](https://our.internmc.facebook.com/intern/diff/D67660049)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143882
Approved by: https://github.com/dulinriley
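For orientation, here is a minimal sketch of the debugging flow that the new test exercises. The model, its shapes, and the variable names are made-up stand-ins for illustration; the torch.ao.quantization entry points are the ones imported by the test file that follows.

import torch
from torch.ao.quantization import (
    compare_results,
    extract_results_from_loggers,
    generate_numeric_debug_handle,
    prepare_for_propagation_comparison,
)
from torch.export import export_for_training


class ToySplitModel(torch.nn.Module):
    # Hypothetical model: its forward returns a list of tensors, the case this PR adds.
    def forward(self, x):
        return list(torch.split(x, 2, dim=1))


example_inputs = (torch.randn(1, 4, 8, 8),)
ep = export_for_training(ToySplitModel(), example_inputs)
generate_numeric_debug_handle(ep)  # tag every node with a numeric debug handle

# Attach output loggers and run the model so each handle records its outputs
# (which may now be a list/tuple/dict, not just a single torch.Tensor).
m_logger = prepare_for_propagation_comparison(ep.module())
m_logger(*example_inputs)
ref_results = extract_results_from_loggers(m_logger)

# A second set of results (e.g. from a quantized copy of the model) would then be
# compared node by node with: compare_results(ref_results, quant_results)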
# Owner(s): ["oncall: quantization"]

import copy
import unittest
from collections import Counter
from typing import Dict

import torch
from torch.ao.quantization import (
    compare_results,
    CUSTOM_KEY,
    extract_results_from_loggers,
    generate_numeric_debug_handle,
    NUMERIC_DEBUG_HANDLE_KEY,
    prepare_for_propagation_comparison,
)
from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase


@unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile")
class TestNumericDebugger(TestCase):
    def _assert_each_node_has_debug_handle(self, model) -> None:
        def _assert_node_has_debug_handle(node):
            self.assertTrue(
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY],
                f"Node {node} doesn't have debug handle",
            )

        bfs_trace_with_node_process(model, _assert_node_has_debug_handle)

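    # Helper: traverse the graph with bfs_trace_with_node_process and record each
    # node's numeric debug handle, keyed by node name.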
    def _extract_debug_handles(self, model) -> Dict[str, int]:
        debug_handle_map: Dict[str, int] = {}

        def _extract_debug_handles_from_node(node):
            nonlocal debug_handle_map
            if (
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
            ):
                debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
                    NUMERIC_DEBUG_HANDLE_KEY
                ]

        bfs_trace_with_node_process(model, _extract_debug_handles_from_node)

        return debug_handle_map

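    # Helper: key each debug handle by the node's nn_module_stack string (used here
    # as a stand-in for the pre-decomposition op) and assert that all nodes sharing
    # it also share a single debug handle.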
    def _extract_debug_handles_with_prev_decomp_op(self, model) -> Dict[str, int]:
        prev_decomp_op_to_debug_handle_map: Dict[str, int] = {}

        def _extract_debug_handles_with_prev_decomp_op_from_node(node):
            nonlocal prev_decomp_op_to_debug_handle_map
            if (
                CUSTOM_KEY in node.meta
                and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
            ):
                prev_decomp_op = str(node.meta.get("nn_module_stack"))
                debug_handle = node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
                if prev_decomp_op not in prev_decomp_op_to_debug_handle_map:
                    prev_decomp_op_to_debug_handle_map[prev_decomp_op] = debug_handle
                else:
                    assert (
                        prev_decomp_op_to_debug_handle_map[prev_decomp_op]
                        == debug_handle
                    ), (
                        f"Node {node} has different debug handle {debug_handle} "
                        f"than previous node sharing the same decomp op {prev_decomp_op}"
                    )

        bfs_trace_with_node_process(
            model, _extract_debug_handles_with_prev_decomp_op_from_node
        )
        return prev_decomp_op_to_debug_handle_map

    def test_simple(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs)
        generate_numeric_debug_handle(ep)
        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map = self._extract_debug_handles(ep)

        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

    def test_control_flow(self):
        m = TestHelperModules.ControlFlow()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs)
        generate_numeric_debug_handle(ep)

        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map = self._extract_debug_handles(ep)

        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

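    # Debug handles should survive prepare_pt2e/convert_pt2e: observers (and later
    # the dequantize nodes) copy the handle of the node they follow, so a few ids
    # are expected to appear exactly twice.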
    def test_quantize_pt2e_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs)
        generate_numeric_debug_handle(ep)
        m = ep.module()

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=False)
        )
        m = prepare_pt2e(m, quantizer)
        debug_handle_map = self._extract_debug_handles(m)
        res_counter = Counter(debug_handle_map.values())
        repeated_debug_handle_ids = [1, 2, 3]
        # 3 ids are repeated because we copy over the id from a node to its output observer:
        # torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
        for dh_id in repeated_debug_handle_ids:
            self.assertEqual(res_counter[dh_id], 2)

        m(*example_inputs)
        m = convert_pt2e(m)
        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map = self._extract_debug_handles(m)
        res_counter = Counter(debug_handle_map.values())
        # the same set of ids is repeated, because we copy over the id from the
        # observer/fake_quant to the dequantize node
        repeated_debug_handle_ids = [1, 2, 3]
        for dh_id in repeated_debug_handle_ids:
            self.assertEqual(res_counter[dh_id], 2)

    def test_copy_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = torch.export.export(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)

        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map_ref = self._extract_debug_handles(ep)

        ep_copy = copy.copy(ep)
        debug_handle_map = self._extract_debug_handles(ep_copy)

        self._assert_each_node_has_debug_handle(ep)
        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    def test_deepcopy_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = torch.export.export(m, example_inputs, strict=True)
        generate_numeric_debug_handle(ep)

        debug_handle_map_ref = self._extract_debug_handles(ep)
        ep_copy = copy.deepcopy(ep)
        debug_handle_map = self._extract_debug_handles(ep_copy)

        self._assert_each_node_has_debug_handle(ep)
        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    @skipIfCrossRef  # mlazos: retracing an FX graph with torch function mode doesn't propagate metadata,
    # because the stack trace of the mode's torch function impl doesn't match the lineno stored in the traced graph.
    def test_re_export_preserve_handle(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs)
        generate_numeric_debug_handle(ep)
        m = ep.module()

        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map_ref = self._extract_debug_handles(ep)

        ep_reexport = export_for_training(m, example_inputs)

        self._assert_each_node_has_debug_handle(ep_reexport)
        debug_handle_map = self._extract_debug_handles(ep_reexport)

        self.assertEqual(debug_handle_map, debug_handle_map_ref)

    def test_run_decompositions_same_handle_id(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs)
        generate_numeric_debug_handle(ep)

        self._assert_each_node_has_debug_handle(ep)
        debug_handle_map_ref = self._extract_debug_handles(ep)

        ep_copy = copy.copy(ep)
        ep_copy = ep_copy.run_decompositions()

        self._assert_each_node_has_debug_handle(ep_copy)
        debug_handle_map = self._extract_debug_handles(ep_copy)

        # check that the map still has the same ids; the nodes may change
        self.assertEqual(
            set(debug_handle_map.values()), set(debug_handle_map_ref.values())
        )

    def test_run_decompositions_map_handle_to_new_nodes(self):
        test_models = [
            TestHelperModules.TwoLinearModule(),
            TestHelperModules.Conv2dThenConv1d(),
        ]

        for m in test_models:
            example_inputs = m.example_inputs()
            ep = export_for_training(m, example_inputs)
            generate_numeric_debug_handle(ep)

            self._assert_each_node_has_debug_handle(ep)
            pre_decomp_to_debug_handle_map_ref = (
                self._extract_debug_handles_with_prev_decomp_op(ep)
            )

            ep_copy = copy.copy(ep)
            ep_copy = ep_copy.run_decompositions()
            self._assert_each_node_has_debug_handle(ep_copy)
            pre_decomp_to_debug_handle_map = (
                self._extract_debug_handles_with_prev_decomp_op(ep_copy)
            )

            # check that the map still has the same ids; the nodes may change
            self.assertEqual(
                pre_decomp_to_debug_handle_map, pre_decomp_to_debug_handle_map_ref
            )

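    # prepare_for_propagation_comparison attaches OutputLogger modules to nodes that
    # carry debug handles; it must not change the model's numerics, so the logged run
    # and the plain run are expected to match exactly.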
    def test_prepare_for_propagation_comparison(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs)
        generate_numeric_debug_handle(ep)
        m = ep.module()
        m_logger = prepare_for_propagation_comparison(m)
        ref = m(*example_inputs)
        res = m_logger(*example_inputs)

        from torch.ao.quantization.pt2e._numeric_debugger import OutputLogger

        loggers = [m for m in m_logger.modules() if isinstance(m, OutputLogger)]
        self.assertEqual(len(loggers), 3)
        self.assertTrue("conv2d" in [logger.node_name for logger in loggers])
        self.assertEqual(res, ref)

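    # End-to-end comparison: run a float reference and a quantized copy of the model
    # under loggers, extract the per-node results, and expect an SQNR of at least 35
    # for every logged node.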
    def test_extract_results_from_loggers(self):
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs)
        generate_numeric_debug_handle(ep)
        m = ep.module()
        m_ref_logger = prepare_for_propagation_comparison(m)

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=False)
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m_quant_logger = prepare_for_propagation_comparison(m)

        m_ref_logger(*example_inputs)
        m_quant_logger(*example_inputs)
        ref_results = extract_results_from_loggers(m_ref_logger)
        quant_results = extract_results_from_loggers(m_quant_logger)
        comparison_results = compare_results(ref_results, quant_results)
        for node_summary in comparison_results.values():
            if len(node_summary.results) > 0:
                self.assertGreaterEqual(node_summary.results[0].sqnr, 35)

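    # Same flow as above, but Conv2dWithSplit produces list outputs, so a node's sqnr
    # may itself be a list and each element is checked. This covers the list/tuple/dict
    # support described in the summary above.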
    def test_extract_results_from_loggers_list_output(self):
        m = TestHelperModules.Conv2dWithSplit()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs)
        generate_numeric_debug_handle(ep)
        m = ep.module()
        m_ref_logger = prepare_for_propagation_comparison(m)

        quantizer = XNNPACKQuantizer().set_global(
            get_symmetric_quantization_config(is_per_channel=False)
        )
        m = prepare_pt2e(m, quantizer)
        m(*example_inputs)
        m = convert_pt2e(m)
        m_quant_logger = prepare_for_propagation_comparison(m)

        m_ref_logger(*example_inputs)
        m_quant_logger(*example_inputs)
        ref_results = extract_results_from_loggers(m_ref_logger)
        quant_results = extract_results_from_loggers(m_quant_logger)
        comparison_results = compare_results(ref_results, quant_results)
        for node_summary in comparison_results.values():
            if len(node_summary.results) > 0:
                sqnr = node_summary.results[0].sqnr
                if isinstance(sqnr, list):
                    for sqnr_i in sqnr:
                        self.assertGreaterEqual(sqnr_i, 35)
                else:
                    self.assertGreaterEqual(sqnr, 35)

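    # After handles have been generated once, inserting a new node and re-running
    # generate_numeric_debug_handle should give the new node a fresh id without
    # disturbing the ids already assigned to existing nodes.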
    def test_added_node_gets_unique_id(self) -> None:
        m = TestHelperModules.Conv2dThenConv1d()
        example_inputs = m.example_inputs()
        ep = export_for_training(m, example_inputs)
        generate_numeric_debug_handle(ep)
        ref_handles = self._extract_debug_handles(ep)
        ref_counter = Counter(ref_handles.values())
        for k, v in ref_counter.items():
            self.assertEqual(
                v,
                1,
                msg=f"For handle {k}, there were {v} nodes with that handle, but expected only 1",
            )

        # Now that we have unique ids, add a new node into the graph and re-generate
        # to make sure that the new node gets a unique id.
        last_node = next(iter(reversed(ep.graph.nodes)))
        with ep.graph.inserting_before(last_node):
            arg = last_node.args[0]
            self.assertIsInstance(arg, (list, tuple))
            arg = arg[0]
            # Add a function that only requires a single tensor input.
            n = ep.graph.call_function(torch.ops.aten.relu.default, args=(arg,))
            arg.replace_all_uses_with(n, lambda x: x != n)
        ep.graph_module.recompile()

        # Regenerate handles, make sure only the new relu node has a new id, and
        # it doesn't clash with any of the existing ids.
        generate_numeric_debug_handle(ep)

        self._assert_each_node_has_debug_handle(ep)
        handles_after_modification = self._extract_debug_handles(ep)
        handles_counter = Counter(handles_after_modification.values())
        for name, handle in ref_handles.items():
            self.assertIn(name, handles_after_modification)
            # Check that handle was unchanged.
            self.assertEqual(handles_after_modification[name], handle)
            # Check that total count was unchanged.
            ref_count = ref_counter[handle]
            after_count = handles_counter[handle]
            self.assertEqual(
                after_count,
                ref_count,
                msg=f"For handle {handle}, there were {after_count} nodes with that handle, but expected only {ref_count}",
            )

        # Check for relu specifically. Avoid hardcoding the handle id since it
        # may change with future node ordering changes.
        self.assertNotEqual(handles_after_modification["relu_default"], 0)
        self.assertEqual(handles_counter[handles_after_modification["relu_default"]], 1)