support condition branch in ao debug handler (#141516)

This diff adds support for conditional (control-flow) statements to AO debug handle generation.

Most of the code is borrowed from ExecuTorch to avoid a circular dependency issue.

Differential Revision: [D66270691](https://our.internmc.facebook.com/intern/diff/D66270691/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141516
Approved by: https://github.com/jerryzh168
This commit is contained in:
gasoonjia
2024-12-09 20:44:43 -08:00
committed by PyTorch MergeBot
parent 75530885ba
commit ff059587c6
4 changed files with 139 additions and 26 deletions

View File

@ -14,6 +14,7 @@ from torch.ao.quantization import (
NUMERIC_DEBUG_HANDLE_KEY,
prepare_for_propagation_comparison,
)
from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
@ -27,14 +28,18 @@ from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, Tes
def _extract_debug_handles(model) -> Dict[str, int]:
debug_handle_map: Dict[str, int] = {}
for node in model.graph.nodes:
if (
CUSTOM_KEY in node.meta
and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
):
debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
NUMERIC_DEBUG_HANDLE_KEY
m_queue = [model]
while m_queue:
cur_m = m_queue.pop(0)
for n in cur_m.graph.nodes:
if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]:
debug_handle_map[str(n)] = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
control_flow_submodules = [
submodule for _, submodule, _ in get_control_flow_submodules(cur_m)
]
m_queue.extend(control_flow_submodules)
return debug_handle_map
@ -44,15 +49,21 @@ class TestNumericDebugger(TestCase):
def test_simple(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = torch.export.export(m, example_inputs)
ep = export_for_training(m, example_inputs)
generate_numeric_debug_handle(ep)
unique_ids = set()
count = 0
for n in ep.graph.nodes:
if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]:
unique_ids.add(n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY])
count += 1
self.assertEqual(len(unique_ids), count)
debug_handle_map = _extract_debug_handles(ep.module())
self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))
def test_control_flow(self):
    """Check that debug handles stay unique for a model with control flow.

    Exports ``TestHelperModules.ControlFlow``, generates numeric debug
    handles on the exported program, and asserts that no handle value
    collected by ``_extract_debug_handles`` is repeated.
    """
    model = TestHelperModules.ControlFlow()
    exported = export_for_training(model, model.example_inputs())
    generate_numeric_debug_handle(exported)
    handles = list(_extract_debug_handles(exported.module()).values())
    # Every collected handle must be distinct, so de-duplicating the values
    # must not shrink the collection.
    self.assertEqual(len(set(handles)), len(handles))
def test_quantize_pt2e_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()

View File

@ -5,6 +5,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Tuple
import torch
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
from torch.export import ExportedProgram
from torch.fx import GraphModule, Node
from torch.nn import functional as F
@ -42,19 +43,32 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
)
unique_id = 0
# Find the max ID that exists in the graph first, in case part of the graph
# has already been annotated. This way we guarantee there are no duplicate
# handle IDs.
for node in ep.graph.nodes:
unique_id = max(
unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
)
unique_id += 1
for node in ep.graph.nodes:
def _bfs_trace_graph_with_node_process(node_op: Callable) -> None:
nonlocal ep
queue = [ep.graph_module]
while queue:
current_graph_module = queue.pop(0)
for node in current_graph_module.graph.nodes:
if node.op in ["output", "placeholder"]:
continue
node_op(node)
control_flow_submodules = [
submodule
for _, submodule, _ in get_control_flow_submodules(current_graph_module)
]
queue.extend(control_flow_submodules)
def _find_max_id(node: torch.fx.Node) -> None:
nonlocal unique_id
unique_id = max(
unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
)
def _assign_debug_handle(node: torch.fx.Node) -> None:
nonlocal unique_id
if CUSTOM_KEY not in node.meta:
node.meta[CUSTOM_KEY] = {}
@ -62,6 +76,17 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
unique_id += 1
# Find the max ID that exists in the graph first, in case part of the graph
# has already been annotated. This way we guarantee there are no duplicate
# handle IDs.
_bfs_trace_graph_with_node_process(_find_max_id)
unique_id += 1
# Assign debug handles to all nodes in the graph that don't have one based on the
# max ID found in the previous step.
_bfs_trace_graph_with_node_process(_assign_debug_handle)
class OutputLogger(torch.nn.Module):
"""

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import itertools
import operator
from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set
from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set, Tuple
import torch
from torch.fx import Node
@ -14,6 +14,7 @@ from torch.fx.passes.utils.source_matcher_utils import (
__all__ = [
"find_sequential_partitions",
"get_control_flow_submodules",
"get_equivalent_types",
"update_equivalent_types_dict",
]
@ -114,3 +115,39 @@ def find_sequential_partitions(
if _partitions_sequential(candidate)
]
return fused_partitions
def _get_submodule(
graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
submod_node = node.args[arg_index]
assert isinstance(submod_node, torch.fx.Node)
assert submod_node.op == "get_attr"
assert isinstance(submod_node.target, str)
submodule = graph_module.get_submodule(submod_node.target)
# pyre-ignore
return submod_node.target, submodule, node
def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.nn.Module, torch.fx.Node]]:
    """
    Collect the submodules consumed by control flow nodes of ``graph_module``.

    Only the nodes of ``graph_module.graph`` itself are scanned; the graphs of
    the returned submodules are not visited. Two ``call_function`` targets are
    recognized:

    * ``torch.ops.higher_order.cond``: the submodules at ``args[1]`` and
      ``args[2]`` are collected, in that order.
    * ``torch.ops.higher_order.map_impl``: the submodule at ``args[0]`` is
      collected.

    Returns:
        A list of ``(name, submodule, node)`` tuples in graph order, where
        ``name`` is the attribute name under which the submodule is stored on
        ``graph_module``, ``submodule`` is the module itself, and ``node`` is
        the fx node that uses it.
    """
    found = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            target = node.target
            if target is torch.ops.higher_order.cond:
                # args[1] and args[2] of a cond node hold its two branches.
                for branch_index in (1, 2):
                    found.append(_get_submodule(graph_module, node, branch_index))
            if target is torch.ops.higher_order.map_impl:
                # args[0] of a map_impl node holds the mapped submodule.
                found.append(_get_submodule(graph_module, node, 0))
    return found

View File

@ -4,6 +4,8 @@ r"""Importing this file includes common utility methods and base classes for
checking quantization api and properties of resulting modules.
"""
from functorch.experimental import control_flow
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -2677,6 +2679,44 @@ class SparseNNModel(nn.Module):
return out
class TestHelperModules:
class ControlFlow(torch.nn.Module):
    """Helper module whose forward nests ``cond`` in ``cond`` in ``map``.

    ``forward`` calls ``control_flow.map`` with a body that calls
    ``control_flow.cond`` on ``pred1``; the true branch of that call in turn
    calls ``control_flow.cond`` on ``pred2``.
    """

    def forward(
        self,
        xs: torch.Tensor,
        pred1: torch.Tensor,
        pred2: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``control_flow.map(map_fn, xs, pred1, pred2, mm(y, y))``."""

        # Innermost branches, selected by ``pred2``.
        def true_nested(y: torch.Tensor) -> torch.Tensor:
            doubled = y + y
            return torch.mm(doubled, doubled)

        def false_nested(y: torch.Tensor) -> torch.Tensor:
            return torch.mm(y, y)

        # Middle-level branches, selected by ``pred1``.
        def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
            nested_out = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + nested_out

        def false_fn(x: torch.Tensor, _) -> torch.Tensor:
            return x.cos()

        # Function handed to ``control_flow.map`` together with ``xs``.
        def map_fn(
            x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
        ) -> torch.Tensor:
            cos_x = x.cos()
            branch_out = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            return (cos_x + branch_out).sin()

        squared = torch.mm(y, y)
        return control_flow.map(map_fn, xs, pred1, pred2, squared)

    def example_inputs(self):
        """Return a ``(xs, pred1, pred2, y)`` tuple of example inputs.

        ``xs`` and ``y`` are separate 2x2 tensors of ones; both predicates
        are one-element tensors holding ``False``.
        """
        xs = torch.ones(2, 2)
        pred1 = torch.tensor([False])
        pred2 = torch.tensor([False])
        y = torch.ones(2, 2)
        return (xs, pred1, pred2, y)
class Conv2dPropAnnotaton(torch.nn.Module):
def __init__(self) -> None:
super().__init__()