diff --git a/docs/source/quantization-support.rst b/docs/source/quantization-support.rst
index 2f4827f32793..44597d867b49 100644
--- a/docs/source/quantization-support.rst
+++ b/docs/source/quantization-support.rst
@@ -49,6 +49,7 @@ Utility functions
     propagate_qconfig_
     default_eval_fn
 
+
 torch.ao.quantization.quantize_fx
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -143,6 +144,21 @@ torch.ao.quantization.pt2e.export_utils
 
     model_is_exported
 
+.. currentmodule:: torch.ao.quantization
+
+PT2 Export (pt2e) Numeric Debugger
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    generate_numeric_debug_handle
+    NUMERIC_DEBUG_HANDLE_KEY
+    prepare_for_propagation_comparison
+    extract_results_from_loggers
+    compare_results
+
 torch (quantization related functions)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
@@ -635,4 +651,3 @@ the `custom operator mechanism 0:
+ self.assertGreaterEqual(node_summary.results[0].sqnr, 35)
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index c9e3f6afeaa1..629159b5b2f3 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -16,8 +16,11 @@ from .stubs import * # noqa: F403
 from .pt2e.export_utils import _move_exported_model_to_eval as move_exported_model_to_eval
 from .pt2e.export_utils import _move_exported_model_to_train as move_exported_model_to_train
 from .pt2e.export_utils import _allow_exported_model_train_eval as allow_exported_model_train_eval
-from .pt2e.numeric_debugger import generate_numeric_debug_handle  # noqa: F401
-from .pt2e.numeric_debugger import NUMERIC_DEBUG_HANDLE_KEY  # noqa: F401
+from .pt2e._numeric_debugger import generate_numeric_debug_handle  # noqa: F401
+from .pt2e._numeric_debugger import NUMERIC_DEBUG_HANDLE_KEY  # noqa: F401
+from .pt2e._numeric_debugger import prepare_for_propagation_comparison  # noqa: F401
+from .pt2e._numeric_debugger import extract_results_from_loggers  # noqa: F401
+from .pt2e._numeric_debugger import compare_results  # noqa: F401
 from typing import Union, List, Callable, Tuple, Optional
 from torch import Tensor
 import torch
@@ -147,6 +150,9 @@ __all__ = [
     "weight_observer_range_neg_127_to_127",
     "generate_numeric_debug_handle",
     "NUMERIC_DEBUG_HANDLE_KEY",
+    "prepare_for_propagation_comparison",
+    "extract_results_from_loggers",
+    "compare_results",
 ]
 
 def default_eval_fn(model, calib_data):
diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py
index 68a64c18de03..3f65b6a652e6 100644
--- a/torch/ao/quantization/fx/convert.py
+++ b/torch/ao/quantization/fx/convert.py
@@ -75,7 +75,7 @@ from .custom_config import (
 from .lower_to_fbgemm import lower_to_fbgemm
 # importing the lib so that the quantized_decomposed ops are registered
 from ._decomposed import quantized_decomposed_lib  # noqa: F401
-from torch.ao.quantization.pt2e.numeric_debugger import NUMERIC_DEBUG_HANDLE_KEY
+from torch.ao.quantization import NUMERIC_DEBUG_HANDLE_KEY
 import operator
 
 __all__ = [
diff --git a/torch/ao/quantization/pt2e/_numeric_debugger.py b/torch/ao/quantization/pt2e/_numeric_debugger.py
new file mode 100644
index 000000000000..1bcc442ca441
--- /dev/null
+++ b/torch/ao/quantization/pt2e/_numeric_debugger.py
@@ -0,0 +1,225 @@
+import copy
+import logging
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Sequence, Tuple
+
+import torch
+from torch.ao.ns.fx.utils import compute_sqnr
+from torch.fx import GraphModule, Node
+from torch.nn import functional as F
+
+
+NUMERIC_DEBUG_HANDLE_KEY = "_numeric_debug_handle"
+
+log = logging.getLogger(__name__)
+
+
+def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
+    """Attach numeric_debug_handle_id for all nodes in the model except for placeholder node
+    The graph nodes of input model is modified inplace.
+    """
+    unique_id = 0
+    for node in graph_module.graph.nodes:
+        if node.op != "placeholder" and NUMERIC_DEBUG_HANDLE_KEY not in node.meta:
+            node.meta[NUMERIC_DEBUG_HANDLE_KEY] = unique_id
+            unique_id += 1
+
+
+class OutputLogger(torch.nn.Module):
+    """
+    Base class for capturing output values for nodes in a GraphModule, it only captures
+    Tensor output currently, but we can extend it to work for other types of inputs later if needed
+    """
+
+    # Mark as impure so that calls to it will not be removed during DCE.
+    _is_impure = True
+
+    def __init__(
+        self,
+        debug_handle: int,
+        node_name: Optional[str] = None,
+        nn_module_stack: Optional[object] = None,
+    ) -> None:
+        super().__init__()
+        self.node_name = node_name
+        self.nn_module_stack = nn_module_stack
+        self.debug_handle = debug_handle
+        self.stats: List[torch.Tensor] = []
+
+    def forward(self, x: object) -> object:
+        if isinstance(x, torch.Tensor):
+            self.stats.append(x.detach())
+        return x
+
+    def extra_repr(self) -> str:
+        return (
+            f"debug_handle={self.debug_handle}, node_name={self.node_name}, "
+            f"nn_module_stack={self.nn_module_stack}, num_stats={len(self.stats)}"
+        )
+
+
+def _insert_logger(model: GraphModule, node: Node, debug_handle: int) -> Node:
+    """For a given node, adds an OutputLogger that observes the output of that node,
+    and all its users use the OutputLogger output instead.
+    The OutputLogger will contain the debug_handle which can be used to compare
+    graphs after transforms"""
+
+    # to avoid circular dep
+    from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
+
+    # add a logger after the node
+    with model.graph.inserting_after(node):
+        get_new_attr_name = get_new_attr_name_with_prefix(f"{node.name}_logger")
+        logger_name = get_new_attr_name(model)
+        setattr(
+            model,
+            logger_name,
+            OutputLogger(debug_handle, node.name, node.meta.get("nn_module_stack")),
+        )
+        logger_node = model.graph.call_module(logger_name, (node,), {})
+
+    orig_users = list(node.users.keys())
+    for user_node in orig_users:
+        if user_node is logger_node:
+            continue
+        user_node.replace_input_with(node, logger_node)
+
+    return logger_node
+
+
+def prepare_for_propagation_comparison(model: GraphModule) -> GraphModule:
+    """Add output loggers to node that has numeric_debug_handle
+
+    Args:
+        model (GraphModule): original model
+    Returns:
+        a model with output loggers for all nodes that has numeric_debug_handle_id
+    """
+    # don't change the original model
+    model = copy.deepcopy(model)
+    for n in model.graph.nodes:
+        if NUMERIC_DEBUG_HANDLE_KEY not in n.meta:
+            continue
+        numeric_debug_handle = n.meta[NUMERIC_DEBUG_HANDLE_KEY]
+        _insert_logger(model, n, numeric_debug_handle)
+
+    model.recompile()
+    return model
+
+
+@dataclass(frozen=True)
+class QuantizationComparisonResult:
+    actual: torch.Tensor
+    ref: torch.Tensor
+
+    @property
+    def mse_loss(self) -> torch.Tensor:
+        return F.mse_loss(
+            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
+        )
+
+    @property
+    def sqnr(self) -> torch.Tensor:
+        return compute_sqnr(
+            self.actual.to(dtype=torch.float32), self.ref.to(dtype=torch.float32)
+        )
+
+    def __repr__(self) -> str:
+        # Don't include the tensors themselves as they are quite large to print
+        # out.
+        return (
+            f"QuantizationComparisonResult(mse_loss={self.mse_loss}, sqnr={self.sqnr})"
+        )
+
+    def __post_init__(self) -> None:
+        if not isinstance(self.actual, torch.Tensor):
+            raise ValueError(
+                f"`self.actual` value must be a Tensor, got: {self.actual}"
+            )
+
+        if not isinstance(self.ref, torch.Tensor):
+            raise ValueError(f"`self.ref` value must be a Tensor, got: {self.ref}")
+
+
+@dataclass(frozen=True)
+class NodeAccuracySummary:
+    handle: int
+    actual_node_name: str
+    actual_module_stack: str
+    ref_node_name: str
+    ref_module_stack: str
+    results: Sequence[QuantizationComparisonResult]
+
+
+def _module_stack_to_str(module_stack: object) -> str:
+    """Simplifies the stack from ("mod", "mod.foo", "mod.foo.0", "mod.foo.0.linear")
+    to "mod.foo.0.linear"
+    """
+    if not isinstance(module_stack, dict):
+        return str(module_stack)
+    module_values_list = list(module_stack.values())
+    if len(module_values_list) > 0:
+        owning_module = module_values_list[-1][0]
+        return str(owning_module)
+    else:
+        return str(module_stack)
+
+
+def extract_results_from_loggers(
+    model: GraphModule,
+) -> Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]]:
+    """For a given model, extract the tensors stats and related information for each debug handle.
+
+    Returns:
+        A dict keyed by debug_handle id; each value is a tuple of
+        (node_name, nn_module_stack, list of Tensors recorded by that logger)"""
+    # Results maps debug handle to a tensor list for each model being compared.
+    handles: Dict[int, Tuple[Optional[str], object, List[torch.Tensor]]] = {}
+    for _name, module in model.named_children():
+        if isinstance(module, OutputLogger) and len(module.stats) > 0:
+            handles[module.debug_handle] = (
+                module.node_name,
+                module.nn_module_stack,
+                module.stats,
+            )
+
+    return handles
+
+
+def compare_results(
+    ref_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
+    actual_results: Dict[int, Tuple[str, object, List[torch.Tensor]]],
+) -> Dict[int, NodeAccuracySummary]:
+    """Given two dict mapping from `debug_handle_id` (int) to list of tensors
+    return a map from `debug_handle_id` to `NodeAccuracySummary` that contains
+    comparison information like SQNR, MSE etc.
+
+    Args:
+        ref_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): reference results for each debug_handle_id
+        actual_results (Dict[int, Tuple[str, object, List[torch.Tensor]]]): actual results for each debug_handle_id
+
+    Returns:
+        Dict[int, NodeAccuracySummary]
+    """
+    comparisons = {}
+    for debug_handle, (ref_name, ref_stack, ref_stats) in ref_results.items():
+        if debug_handle not in actual_results:
+            log.debug(
+                "Cannot compare for handle %s because it wasn't found in the transformed model",
+                debug_handle,
+            )
+            continue
+        actual_name, actual_stack, actual_stats = actual_results[debug_handle]
+        comparisons[debug_handle] = NodeAccuracySummary(
+            handle=debug_handle,
+            actual_node_name=actual_name,
+            actual_module_stack=_module_stack_to_str(actual_stack),
+            ref_node_name=ref_name,
+            ref_module_stack=_module_stack_to_str(ref_stack),
+            results=[
+                QuantizationComparisonResult(actual=a, ref=b)
+                for a, b in zip(actual_stats, ref_stats)
+            ],
+        )
+
+    return comparisons
diff --git a/torch/ao/quantization/pt2e/numeric_debugger.py b/torch/ao/quantization/pt2e/numeric_debugger.py
deleted file mode 100644
index 06a17673d2af..000000000000
--- a/torch/ao/quantization/pt2e/numeric_debugger.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from torch.fx import GraphModule
-
-__all__ = ["generate_numeric_debug_handle", "NUMERIC_DEBUG_HANDLE_KEY"]
-
-NUMERIC_DEBUG_HANDLE_KEY = "_numeric_debug_handle"
-
-
-def generate_numeric_debug_handle(graph_module: GraphModule) -> None:
-    unique_id = 0
-    for node in graph_module.graph.nodes:
-        if node.op != "placeholder" and NUMERIC_DEBUG_HANDLE_KEY not in node.meta:
-            node.meta[NUMERIC_DEBUG_HANDLE_KEY] = unique_id
-            unique_id += 1
diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py
index 26acb24f798f..34c6757589f7 100644
--- a/torch/ao/quantization/pt2e/prepare.py
+++ b/torch/ao/quantization/pt2e/prepare.py
@@ -24,7 +24,7 @@ from torch.ao.quantization.quantizer import (
     QuantizationSpecBase,
 )
 from torch.ao.quantization import ObserverOrFakeQuantize
-from torch.ao.quantization.pt2e.numeric_debugger import NUMERIC_DEBUG_HANDLE_KEY
+from torch.ao.quantization import NUMERIC_DEBUG_HANDLE_KEY
 
 # TODO: make pt2e folder private?
 __all__ = [