mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[quant][pt2e] fix placeholder typo and related quantization tests (#135379)
A previous typo on "placeholder" and related tests in quantization are fixed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/135379 Approved by: https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
e6a0221fc6
commit
c92227c41a
@ -75,7 +75,7 @@ class TestNumericDebugger(TestCase):
|
||||
m = prepare_pt2e(m, quantizer)
|
||||
debug_handle_map = _extract_debug_handles(m)
|
||||
res_counter = Counter(debug_handle_map.values())
|
||||
repeated_debug_handle_ids = [3, 4, 7]
|
||||
repeated_debug_handle_ids = [2, 3, 6]
|
||||
# 3 ids were repeated because we copy over the id from node to its output observer
|
||||
# torch.ops.aten.conv2d.default, torch.ops.aten.squeeze.dim and torch.ops.aten.conv1d.default
|
||||
for dh_id in repeated_debug_handle_ids:
|
||||
@ -87,7 +87,7 @@ class TestNumericDebugger(TestCase):
|
||||
res_counter = Counter(debug_handle_map.values())
|
||||
# same set of ids where repeated, because we copy over the id from observer/fake_quant to
|
||||
# dequantize node
|
||||
repeated_debug_handle_ids = [3, 4, 7]
|
||||
repeated_debug_handle_ids = [2, 3, 6]
|
||||
for dh_id in repeated_debug_handle_ids:
|
||||
self.assertEqual(res_counter[dh_id], 2)
|
||||
|
||||
@ -161,7 +161,7 @@ class TestNumericDebugger(TestCase):
|
||||
from torch.ao.quantization.pt2e._numeric_debugger import OutputLogger
|
||||
|
||||
loggers = [m for m in m_logger.modules() if isinstance(m, OutputLogger)]
|
||||
self.assertEqual(len(loggers), 8)
|
||||
self.assertEqual(len(loggers), 7)
|
||||
self.assertTrue("conv2d" in [logger.node_name for logger in loggers])
|
||||
self.assertEqual(res, ref)
|
||||
|
||||
|
Reference in New Issue
Block a user