mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add numeric_debugger top level APIs (#130643)
Summary: Add three top-level APIs for the numeric debugger in the pt2e flow. They log intermediate outputs in the model and compute summary metrics for comparing corresponding nodes in two graphs: * `prepare_for_propagation_comparison` * `extract_results_from_loggers` * `compare_results` Test Plan: python test/test_quantization.py -k test_prepare_for_propagation_comparison python test/test_quantization.py -k test_extract_results_from_loggers Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/130643 Approved by: https://github.com/dulinriley, https://github.com/tarun292
This commit is contained in:
committed by
PyTorch MergeBot
parent
726b9268d2
commit
793b17ebcb
@ -49,6 +49,7 @@ Utility functions
|
|||||||
propagate_qconfig_
|
propagate_qconfig_
|
||||||
default_eval_fn
|
default_eval_fn
|
||||||
|
|
||||||
|
|
||||||
torch.ao.quantization.quantize_fx
|
torch.ao.quantization.quantize_fx
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@ -143,6 +144,21 @@ torch.ao.quantization.pt2e.export_utils
|
|||||||
|
|
||||||
model_is_exported
|
model_is_exported
|
||||||
|
|
||||||
|
.. currentmodule:: torch.ao.quantization
|
||||||
|
|
||||||
|
PT2 Export (pt2e) Numeric Debugger
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
.. autosummary::
|
||||||
|
:toctree: generated
|
||||||
|
:nosignatures:
|
||||||
|
:template: classtemplate.rst
|
||||||
|
|
||||||
|
generate_numeric_debug_handle
|
||||||
|
NUMERIC_DEBUG_HANDLE_KEY
|
||||||
|
prepare_for_propagation_comparison
|
||||||
|
extract_results_from_loggers
|
||||||
|
compare_results
|
||||||
|
|
||||||
torch (quantization related functions)
|
torch (quantization related functions)
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
@ -635,4 +651,3 @@ the `custom operator mechanism <https://pytorch.org/tutorials/advanced/torch_scr
|
|||||||
.. automodule:: torch.nn.quantized.dynamic.modules
|
.. automodule:: torch.nn.quantized.dynamic.modules
|
||||||
.. automodule:: torch.quantization
|
.. automodule:: torch.quantization
|
||||||
.. automodule:: torch.nn.intrinsic.modules
|
.. automodule:: torch.nn.intrinsic.modules
|
||||||
.. automodule:: torch.ao.quantization.pt2e.numeric_debugger
|
|
||||||
|
@ -8,8 +8,11 @@ from typing import Dict
|
|||||||
import torch
|
import torch
|
||||||
from torch._export import capture_pre_autograd_graph
|
from torch._export import capture_pre_autograd_graph
|
||||||
from torch.ao.quantization import (
|
from torch.ao.quantization import (
|
||||||
|
compare_results,
|
||||||
|
extract_results_from_loggers,
|
||||||
generate_numeric_debug_handle,
|
generate_numeric_debug_handle,
|
||||||
NUMERIC_DEBUG_HANDLE_KEY,
|
NUMERIC_DEBUG_HANDLE_KEY,
|
||||||
|
prepare_for_propagation_comparison,
|
||||||
)
|
)
|
||||||
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
|
||||||
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
|
||||||
@ -121,3 +124,43 @@ class TestNumericDebugger(TestCase):
|
|||||||
debug_handle_map = _extract_debug_handles(m_export)
|
debug_handle_map = _extract_debug_handles(m_export)
|
||||||
|
|
||||||
self.assertEqual(debug_handle_map, debug_handle_map_ref)
|
self.assertEqual(debug_handle_map, debug_handle_map_ref)
|
||||||
|
|
||||||
|
def test_prepare_for_propagation_comparison(self):
    """Loggers are inserted for annotated nodes and do not change the output."""
    eager_model = TestHelperModules.Conv2dThenConv1d()
    example_inputs = eager_model.example_inputs()
    exported = capture_pre_autograd_graph(eager_model, example_inputs)
    generate_numeric_debug_handle(exported)
    m_logger = prepare_for_propagation_comparison(exported)

    # Run the un-instrumented graph first, then the copy with loggers.
    ref = exported(*example_inputs)
    res = m_logger(*example_inputs)

    from torch.ao.quantization.pt2e._numeric_debugger import OutputLogger

    loggers = [
        submodule
        for submodule in m_logger.modules()
        if isinstance(submodule, OutputLogger)
    ]
    self.assertEqual(len(loggers), 8)
    logged_node_names = [logger.node_name for logger in loggers]
    self.assertTrue("conv2d" in logged_node_names)
    # The loggers pass values through, so outputs must match.
    self.assertEqual(res, ref)
|
||||||
|
|
||||||
|
def test_extract_results_from_loggers(self):
    """Compare logged outputs of the reference graph and its quantized version."""
    eager_model = TestHelperModules.Conv2dThenConv1d()
    example_inputs = eager_model.example_inputs()
    exported = capture_pre_autograd_graph(eager_model, example_inputs)
    generate_numeric_debug_handle(exported)
    m_ref_logger = prepare_for_propagation_comparison(exported)

    quantization_config = get_symmetric_quantization_config(is_per_channel=False)
    quantizer = XNNPACKQuantizer().set_global(quantization_config)
    prepared = prepare_pt2e(exported, quantizer)
    # Run the prepared model once on the example inputs before converting.
    prepared(*example_inputs)
    quantized = convert_pt2e(prepared)
    m_quant_logger = prepare_for_propagation_comparison(quantized)

    m_ref_logger(*example_inputs)
    m_quant_logger(*example_inputs)
    ref_results = extract_results_from_loggers(m_ref_logger)
    quant_results = extract_results_from_loggers(m_quant_logger)
    comparison_results = compare_results(ref_results, quant_results)
    for node_summary in comparison_results.values():
        if len(node_summary.results) == 0:
            continue
        self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
|
||||||
|
@ -16,8 +16,11 @@ from .stubs import * # noqa: F403
|
|||||||
from .pt2e.export_utils import _move_exported_model_to_eval as move_exported_model_to_eval
|
from .pt2e.export_utils import _move_exported_model_to_eval as move_exported_model_to_eval
|
||||||
from .pt2e.export_utils import _move_exported_model_to_train as move_exported_model_to_train
|
from .pt2e.export_utils import _move_exported_model_to_train as move_exported_model_to_train
|
||||||
from .pt2e.export_utils import _allow_exported_model_train_eval as allow_exported_model_train_eval
|
from .pt2e.export_utils import _allow_exported_model_train_eval as allow_exported_model_train_eval
|
||||||
from .pt2e.numeric_debugger import generate_numeric_debug_handle # noqa: F401
|
from .pt2e._numeric_debugger import generate_numeric_debug_handle # noqa: F401
|
||||||
from .pt2e.numeric_debugger import NUMERIC_DEBUG_HANDLE_KEY # noqa: F401
|
from .pt2e._numeric_debugger import NUMERIC_DEBUG_HANDLE_KEY # noqa: F401
|
||||||
|
from .pt2e._numeric_debugger import prepare_for_propagation_comparison # noqa: F401
|
||||||
|
from .pt2e._numeric_debugger import extract_results_from_loggers # noqa: F401
|
||||||
|
from .pt2e._numeric_debugger import compare_results # noqa: F401
|
||||||
from typing import Union, List, Callable, Tuple, Optional
|
from typing import Union, List, Callable, Tuple, Optional
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import torch
|
import torch
|
||||||
@ -147,6 +150,9 @@ __all__ = [
|
|||||||
"weight_observer_range_neg_127_to_127",
|
"weight_observer_range_neg_127_to_127",
|
||||||
"generate_numeric_debug_handle",
|
"generate_numeric_debug_handle",
|
||||||
"NUMERIC_DEBUG_HANDLE_KEY",
|
"NUMERIC_DEBUG_HANDLE_KEY",
|
||||||
|
"prepare_for_propagation_comparison",
|
||||||
|
"extract_results_from_loggers",
|
||||||
|
"compare_results",
|
||||||
]
|
]
|
||||||
|
|
||||||
def default_eval_fn(model, calib_data):
|
def default_eval_fn(model, calib_data):
|
||||||
|
@ -75,7 +75,7 @@ from .custom_config import (
|
|||||||
from .lower_to_fbgemm import lower_to_fbgemm
|
from .lower_to_fbgemm import lower_to_fbgemm
|
||||||
# importing the lib so that the quantized_decomposed ops are registered
|
# importing the lib so that the quantized_decomposed ops are registered
|
||||||
from ._decomposed import quantized_decomposed_lib # noqa: F401
|
from ._decomposed import quantized_decomposed_lib # noqa: F401
|
||||||
from torch.ao.quantization.pt2e.numeric_debugger import NUMERIC_DEBUG_HANDLE_KEY
|
from torch.ao.quantization import NUMERIC_DEBUG_HANDLE_KEY
|
||||||
import operator
|
import operator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
225
torch/ao/quantization/pt2e/_numeric_debugger.py
Normal file
225
torch/ao/quantization/pt2e/_numeric_debugger.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
import copy
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.ao.ns.fx.utils import compute_sqnr
|
||||||
|
from torch.fx import GraphModule, Node
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
|
NUMERIC_DEBUG_HANDLE_KEY = "_numeric_debug_handle"
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
    """Attach a numeric debug handle to every non-placeholder node that lacks one.

    The handle is stored in ``node.meta`` under ``NUMERIC_DEBUG_HANDLE_KEY``.
    Nodes that already carry a handle keep it. The graph nodes of the input
    model are modified in place.

    New handles are numbered starting just above the largest integer handle
    already present in the graph, so annotating a partially annotated graph
    never reuses an id that is already taken. For a graph without any handle
    the ids start at 0 and increase in node order.

    Args:
        graph_module (GraphModule): model whose graph nodes are annotated
    """
    nodes = list(graph_module.graph.nodes)

    # Skip past every id that is already in use to avoid handing out duplicates.
    unique_id = 0
    for node in nodes:
        existing_handle = node.meta.get(NUMERIC_DEBUG_HANDLE_KEY)
        if isinstance(existing_handle, int):
            unique_id = max(unique_id, existing_handle + 1)

    for node in nodes:
        if node.op != "placeholder" and NUMERIC_DEBUG_HANDLE_KEY not in node.meta:
            node.meta[NUMERIC_DEBUG_HANDLE_KEY] = unique_id
            unique_id += 1
|
||||||
|
|
||||||
|
|
||||||
|
class OutputLogger(torch.nn.Module):
    """
    Base class for capturing output values for nodes in a GraphModule, it only captures
    Tensor output currently, but we can extend it to work for other types of inputs later if needed

    Args:
        debug_handle (int): numeric debug handle of the observed node
        node_name (Optional[str]): name of the observed node
        nn_module_stack (Optional[object]): module stack metadata of the observed node
    """

    # Mark as impure so that calls to it will not be removed during DCE.
    _is_impure = True

    def __init__(
        self,
        debug_handle: int,
        node_name: Optional[str] = None,
        nn_module_stack: Optional[object] = None,
    ) -> None:
        super().__init__()
        self.node_name = node_name
        self.nn_module_stack = nn_module_stack
        self.debug_handle = debug_handle
        # One entry per forward call that received a Tensor.
        self.stats: List[torch.Tensor] = []

    def forward(self, x: object) -> object:
        """Record ``x`` (detached) if it is a Tensor and return ``x`` unchanged."""
        if isinstance(x, torch.Tensor):
            self.stats.append(x.detach())
        return x

    def extra_repr(self) -> str:
        """Return the logger details shown inside ``repr(module)``.

        ``torch.nn.Module.__repr__`` calls ``extra_repr``; a method named
        ``__extra_repr__`` is never picked up by it.
        """
        return (
            f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
            f"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)}"
        )

    def __extra_repr__(self) -> str:
        # Kept so existing callers of the old method name keep working.
        return self.extra_repr()
|
||||||
|
|
||||||
|
|
||||||
|
def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
    """Insert an OutputLogger right after ``node`` and reroute its users to it.

    The logger module is registered on ``model`` under a fresh attribute name
    derived from ``node.name`` and stores ``debug_handle``, which can be used to
    compare graphs after transforms.

    Returns:
        the newly created ``call_module`` node for the logger
    """
    # Imported here to avoid a circular dependency.
    from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix

    graph = model.graph
    with graph.inserting_after(node):
        attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger")(model)
        output_logger = OutputLogger(
            debug_handle, node.name, node.meta.get("nn_module_stack")
        )
        setattr(model, attr_name, output_logger)
        logger_node = graph.call_module(attr_name, (node,), {})

    # Iterate over a snapshot: rerouting inputs mutates ``node.users``.
    for consumer in list(node.users.keys()):
        if consumer is not logger_node:
            consumer.replace_input_with(node, logger_node)

    return logger_node
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
    """Return a copy of ``model`` with an output logger after each annotated node.

    A node is annotated when its ``meta`` contains ``NUMERIC_DEBUG_HANDLE_KEY``.

    Args:
        model (GraphModule): original model, left unmodified

    Returns:
        a deep copy of the model with output loggers for all nodes that have a
        numeric debug handle
    """
    # Work on a deep copy so the caller's model is not changed.
    instrumented = copy.deepcopy(model)
    for graph_node in instrumented.graph.nodes:
        node_meta = graph_node.meta
        if NUMERIC_DEBUG_HANDLE_KEY in node_meta:
            _insert_logger(
                instrumented, graph_node, node_meta[NUMERIC_DEBUG_HANDLE_KEY]
            )

    instrumented.recompile()
    return instrumented
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class QuantizationComparisonResult:
    """Pair of tensors (``actual`` vs. ``ref``) with comparison metrics.

    Both metrics cast the two tensors to float32 before comparing them.

    Raises:
        ValueError: if ``actual`` or ``ref`` is not a Tensor
    """

    actual: torch.Tensor
    ref: torch.Tensor

    def __post_init__(self) -> None:
        if not isinstance(self.actual, torch.Tensor):
            raise ValueError(
                f"`self.actual` value must be a Tensor, got: {self.actual}"
            )

        if not isinstance(self.ref, torch.Tensor):
            raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")

    @property
    def mse_loss(self) -> torch.Tensor:
        """Result of ``F.mse_loss`` on the float32 casts of ``actual`` and ``ref``."""
        actual_fp32 = self.actual.to(dtype=torch.float32)
        ref_fp32 = self.ref.to(dtype=torch.float32)
        return F.mse_loss(actual_fp32, ref_fp32)

    @property
    def sqnr(self) -> torch.Tensor:
        """Result of ``compute_sqnr`` on the float32 casts of ``actual`` and ``ref``."""
        actual_fp32 = self.actual.to(dtype=torch.float32)
        ref_fp32 = self.ref.to(dtype=torch.float32)
        return compute_sqnr(actual_fp32, ref_fp32)

    def __repr__(self) -> str:
        # Only the metrics are shown; the tensors themselves are quite large
        # to print out.
        return (
            f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class NodeAccuracySummary:
    """Comparison record for one debug handle, as built by ``compare_results``."""

    # Debug handle shared by the reference node and the actual node.
    handle: int
    # Node name and stringified module stack on the "actual" side.
    actual_node_name: str
    actual_module_stack: str
    # Node name and stringified module stack on the reference side.
    ref_node_name: str
    ref_module_stack: str
    # One comparison per pair of tensors recorded for this handle.
    results: Sequence[QuantizationComparisonResult]
|
||||||
|
|
||||||
|
|
||||||
|
def _module_stack_to_str(module_stack: object) -> str:
|
||||||
|
"""Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear")
|
||||||
|
to "mod.foo.0.linear"
|
||||||
|
"""
|
||||||
|
if not isinstance(module_stack, dict):
|
||||||
|
return str(module_stack)
|
||||||
|
module_values_list = list(module_stack.values())
|
||||||
|
if len(module_values_list) > 0:
|
||||||
|
owning_module = module_values_list[-1][0]
|
||||||
|
return str(owning_module)
|
||||||
|
else:
|
||||||
|
return str(module_stack)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_results_from_loggers(
    model: GraphModule,
) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
    """Collect what the OutputLogger children of ``model`` have recorded.

    Only direct children of ``model`` are inspected, and loggers that have not
    recorded any tensor are skipped.

    Returns:
        A dict keyed by the debug_handle id; each value is a tuple of
        (node name, nn_module_stack, list of Tensors recorded in the logger)
    """
    return {
        child.debug_handle: (child.node_name, child.nn_module_stack, child.stats)
        for _name, child in model.named_children()
        if isinstance(child, OutputLogger) and len(child.stats) > 0
    }
|
||||||
|
|
||||||
|
|
||||||
|
def compare_results(
    ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
    actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
) -> Dict[int, NodeAccuracySummary]:
    """Build a per-handle accuracy summary from two sets of logger results.

    Handles present in ``ref_results`` but missing from ``actual_results`` are
    skipped with a debug log message. For the remaining handles, the recorded
    tensors are paired up in order with ``zip`` and each pair becomes a
    ``QuantizationComparisonResult`` (which exposes SQNR, MSE etc.).

    Args:
        ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id
        actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id

    Returns:
        Dict[int, NodeAccuracySummary]
    """
    summaries = {}
    for handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
        if handle in actual_results:
            actual_name, actual_stack, actual_stats = actual_results[handle]
            actual_stack_str = _module_stack_to_str(actual_stack)
            ref_stack_str = _module_stack_to_str(ref_stack)
            paired_results = [
                QuantizationComparisonResult(actual=actual_tensor, ref=ref_tensor)
                for actual_tensor, ref_tensor in zip(actual_stats, ref_stats)
            ]
            summaries[handle] = NodeAccuracySummary(
                handle=handle,
                actual_node_name=actual_name,
                actual_module_stack=actual_stack_str,
                ref_node_name=ref_name,
                ref_module_stack=ref_stack_str,
                results=paired_results,
            )
        else:
            log.debug(
                "Cannot compare for handle %s because it wasn't found in the transformed model",
                handle,
            )

    return summaries
|
@ -1,13 +0,0 @@
|
|||||||
from torch.fx import GraphModule
|
|
||||||
|
|
||||||
__all__ = ["generate_numeric_debug_handle", "NUMERIC_DEBUG_HANDLE_KEY"]
|
|
||||||
|
|
||||||
NUMERIC_DEBUG_HANDLE_KEY = "_numeric_debug_handle"
|
|
||||||
|
|
||||||
|
|
||||||
def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
|
|
||||||
unique_id = 0
|
|
||||||
for node in graph_module.graph.nodes:
|
|
||||||
if node.op != "placeholder" and NUMERIC_DEBUG_HANDLE_KEY not in node.meta:
|
|
||||||
node.meta[NUMERIC_DEBUG_HANDLE_KEY] = unique_id
|
|
||||||
unique_id += 1
|
|
@ -24,7 +24,7 @@ from torch.ao.quantization.quantizer import (
|
|||||||
QuantizationSpecBase,
|
QuantizationSpecBase,
|
||||||
)
|
)
|
||||||
from torch.ao.quantization import ObserverOrFakeQuantize
|
from torch.ao.quantization import ObserverOrFakeQuantize
|
||||||
from torch.ao.quantization.pt2e.numeric_debugger import NUMERIC_DEBUG_HANDLE_KEY
|
from torch.ao.quantization import NUMERIC_DEBUG_HANDLE_KEY
|
||||||
|
|
||||||
# TODO: make pt2e folder private?
|
# TODO: make pt2e folder private?
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
Reference in New Issue
Block a user