pytorch/torch/ao/quantization/pt2e/_numeric_debugger.py
Jerry Zhang 793b17ebcb Add numeric_debugger top level APIs (#130643)
Summary:
Add three top-level APIs for the numeric debugger in the pt2e flow. They log
intermediate outputs in a model and compute summary metrics for comparing
corresponding nodes between two graphs (a usage sketch follows the list):

* `prepare_for_propagation_comparison`
* `extract_results_from_loggers`
* `compare_results`
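
A typical end-to-end flow (a minimal sketch; `m`, `quantizer`, and
`example_inputs` are illustrative names, and `prepare_pt2e`/`convert_pt2e`
are the existing APIs from torch.ao.quantization.quantize_pt2e):

    generate_numeric_debug_handle(m)              # tag nodes with shared handles
    p = prepare_pt2e(copy.deepcopy(m), quantizer)
    p(*example_inputs)                            # calibrate
    q = convert_pt2e(p)
    ref = prepare_for_propagation_comparison(m)
    actual = prepare_for_propagation_comparison(q)
    ref(*example_inputs)
    actual(*example_inputs)
    comparisons = compare_results(
        extract_results_from_loggers(ref),
        extract_results_from_loggers(actual),
    )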

Test Plan:
python test/test_quantization.py -k test_prepare_for_propagation_comparison
python test/test_quantization.py -k test_extract_results_from_loggers

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130643
Approved by: https://github.com/dulinriley, https://github.com/tarun292
2024-07-18 20:54:18 +00:00


import copy
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
import torch
from torch.ao.ns.fx.utils import compute_sqnr
from torch.fx import GraphModule, Node
from torch.nn import functional as F
NUMERIC_DEBUG_HANDLE_KEY = "_numeric_debug_handle"
log = logging.getLogger(__name__)
def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
"""Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
The graph nodes of input model is modified inplace.
"""
unique_id = 0
for node in graph_module.graph.nodes:
if node.op != "placeholder" and NUMERIC_DEBUG_HANDLE_KEY not in node.meta:
node.meta[NUMERIC_DEBUG_HANDLE_KEY] = unique_id
unique_id += 1
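
# Example (illustrative sketch; `MyModel` and `example_input` are placeholder
# names, and the exact capture API may differ across releases):
#
#     m = torch.export.export(MyModel(), (example_input,)).module()
#     generate_numeric_debug_handle(m)
#     # every non-placeholder node now carries node.meta["_numeric_debug_handle"]
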
class OutputLogger(torch.nn.Module):
"""
Base class for capturing output values for nodes in a GraphModule, it only captures
Tensor output currently, but we can extend it to work for other types of inputs later if needed
"""
# Mark as impure so that calls to it will not be removed during DCE.
_is_impure = True
def __init__(
self,
debug_handle: int,
node_name: Optional[str] = None,
nn_module_stack: Optional[object] = None,
) -> None:
super().__init__()
self.node_name = node_name
self.nn_module_stack = nn_module_stack
self.debug_handle = debug_handle
self.stats: List[torch.Tensor] = []
def forward(self, x: object) -> object:
if isinstance(x, torch.Tensor):
self.stats.append(x.detach())
return x
    def extra_repr(self) -> str:
        return (
            f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
            f"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)}"
        )
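
# Usage note (hedged sketch): an OutputLogger acts as an identity module that
# records every Tensor passed through it:
#
#     logger = OutputLogger(debug_handle=0, node_name="conv1")
#     y = logger(torch.randn(2, 3))  # returns its input unchanged
#     assert len(logger.stats) == 1  # the detached tensor was recorded
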
def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
"""For a given node, adds an OutputLogger that observes the output of that node,
and all its users use the OutputLogger output instead.
The OutputLogger will contain the debug_handle which can be used to compare
graphs after transforms"""
    # imported here to avoid a circular dependency
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
# add a logger after the node
with model.graph.inserting_after(node):
get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger")
logger_name = get_new_attr_name(model)
setattr(
model,
logger_name,
OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")),
)
logger_node = model.graph.call_module(logger_name, (node,), {})
orig_users = list(node.users.keys())
for user_node in orig_users:
if user_node is logger_node:
continue
user_node.replace_input_with(node, logger_node)
return logger_node
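
# After insertion, dataflow changes from `node -> user` to
# `node -> logger -> user`, so every consumer observes the logged value:
#
#     before: conv1 ───────────────────► relu
#     after:  conv1 ──► conv1_logger ──► relu
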
def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
"""Add output loggers to node that has numeric_debug_handle
Args:
model (GraphModule): original model
Returns:
a model with output loggers for all nodes that has numeric_debug_handle_id
"""
# don't change the original model
model = copy.deepcopy(model)
for n in model.graph.nodes:
if NUMERIC_DEBUG_HANDLE_KEY not in n.meta:
continue
numeric_debug_handle = n.meta[NUMERIC_DEBUG_HANDLE_KEY]
_insert_logger(model, n, numeric_debug_handle)
model.recompile()
return model
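
# Example (minimal sketch; `m_float`, `m_quant`, and `calibration_data` are
# illustrative names): prepare a logged copy of each model, then run the same
# data through both so their loggers record comparable stats.
#
#     ref = prepare_for_propagation_comparison(m_float)
#     actual = prepare_for_propagation_comparison(m_quant)
#     for x in calibration_data:
#         ref(x)
#         actual(x)
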
@dataclass(frozen=True)
class QuantizationComparisonResult:
actual: torch.Tensor
ref: torch.Tensor
@property
def mse_loss(self) -> torch.Tensor:
return F.mse_loss(
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
)
@property
def sqnr(self) -> torch.Tensor:
return compute_sqnr(
self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
)
def __repr__(self) -> str:
# Don't include the tensors themselves as they are quite large to print
# out.
return (
f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
)
def __post_init__(self) -> None:
if not isinstance(self.actual, torch.Tensor):
raise ValueError(
f"`self.actual` value must be a Tensor, got: {self.actual}"
)
if not isinstance(self.ref, torch.Tensor):
raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
@dataclass(frozen=True)
class NodeAccuracySummary:
handle: int
actual_node_name: str
actual_module_stack: str
ref_node_name: str
ref_module_stack: str
results: Sequence[QuantizationComparisonResult]
def _module_stack_to_str(module_stack: object) -> str:
"""Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear")
to "mod.foo.0.linear"
"""
if not isinstance(module_stack, dict):
return str(module_stack)
module_values_list = list(module_stack.values())
if len(module_values_list) > 0:
owning_module = module_values_list[-1][0]
return str(owning_module)
else:
return str(module_stack)
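
# Example: for an nn_module_stack dict whose values are (path, type) pairs
# (keys and types here are illustrative), the innermost entry wins:
#
#     stack = {"a": ("mod", M), "b": ("mod.foo", Seq), "c": ("mod.foo.0.linear", Linear)}
#     _module_stack_to_str(stack)  # -> "mod.foo.0.linear"
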
def extract_results_from_loggers(
model: GraphModule,
) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
"""For a given model, extract the tensors stats and related information for each debug handle.
Returns:
A dict is keyed by the debug_handle id and the values are a list of Tensors recorded
in loggers"""
    # Maps each debug handle to (node name, nn_module_stack, recorded tensor stats).
handles: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]] = {}
for _name, module in model.named_children():
if isinstance(module, OutputLogger) and len(module.stats) > 0:
handles[module.debug_handle] = (
module.node_name,
module.nn_module_stack,
module.stats,
)
return handles
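
# Example (sketch, continuing from a prepared model `ref` that has already
# seen data): pull the recorded stats out of its loggers, keyed by handle.
#
#     ref_results = extract_results_from_loggers(ref)
#     # e.g. {0: ("conv1", <nn_module_stack>, [t_batch0, t_batch1]), ...}
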
def compare_results(
ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
) -> Dict[int, NodeAccuracySummary]:
"""Given two dict mapping from `debug_handle_id` (int) to list of tensors
return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
comparison information like SQNR, MSE etc.
Args:
ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id
actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id
Returns:
Dict[int, NodeAccuracySummary]
"""
comparisons = {}
for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
if debug_handle not in actual_results:
log.debug(
"Cannot compare for handle %s because it wasn't found in the transformed model",
debug_handle,
)
continue
actual_name, actual_stack, actual_stats = actual_results[debug_handle]
comparisons[debug_handle] = NodeAccuracySummary(
handle=debug_handle,
actual_node_name=actual_name,
actual_module_stack=_module_stack_to_str(actual_stack),
ref_node_name=ref_name,
ref_module_stack=_module_stack_to_str(ref_stack),
results=[
QuantizationComparisonResult(actual=a, ref=b)
for a, b in zip(actual_stats, ref_stats)
],
)
return comparisons
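
# Example (sketch): printing a per-handle accuracy summary.
#
#     for handle, summary in compare_results(ref_results, actual_results).items():
#         for r in summary.results:
#             print(handle, summary.actual_node_name, float(r.sqnr), float(r.mse_loss))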