Support condition branch in ao debug handler (#141516)

This diff introduces support for condition statements (control flow) in ao debug handler generation.

Most of the code is borrowed from ExecuTorch to avoid a circular dependency issue.
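For orientation, here is a hedged sketch of how the new path is exercised end to end, mirroring the test_control_flow test added below. It assumes generate_numeric_debug_handle is importable from torch.ao.quantization (as in the test file's import block) and that export_for_training comes from torch.export:

    import torch
    from torch.ao.quantization import generate_numeric_debug_handle
    from torch.export import export_for_training
    from torch.testing._internal.common_quantization import TestHelperModules

    # ControlFlow (added in this diff) nests cond inside map.
    m = TestHelperModules.ControlFlow()
    ep = export_for_training(m, m.example_inputs())

    # With this change, debug handles are assigned to nodes in the top-level
    # graph and, breadth-first, in every cond/map submodule.
    generate_numeric_debug_handle(ep)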

Differential Revision: [D66270691](https://our.internmc.facebook.com/intern/diff/D66270691/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141516
Approved by: https://github.com/jerryzh168
Author: gasoonjia
Date: 2024-12-09 20:44:43 -08:00
Committed by: PyTorch MergeBot
Parent: 75530885ba
Commit: ff059587c6

4 changed files with 139 additions and 26 deletions


@@ -14,6 +14,7 @@ from torch.ao.quantization import (
     NUMERIC_DEBUG_HANDLE_KEY,
     prepare_for_propagation_comparison,
 )
+from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
@@ -27,14 +28,18 @@ from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase

 def _extract_debug_handles(model) -> Dict[str, int]:
     debug_handle_map: Dict[str, int] = {}
-    for node in model.graph.nodes:
-        if (
-            CUSTOM_KEY in node.meta
-            and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
-        ):
-            debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
-                NUMERIC_DEBUG_HANDLE_KEY
-            ]
+    m_queue = [model]
+    while m_queue:
+        cur_m = m_queue.pop(0)
+        for n in cur_m.graph.nodes:
+            if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]:
+                debug_handle_map[str(n)] = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
+
+        control_flow_submodules = [
+            submodule for _, submodule, _ in get_control_flow_submodules(cur_m)
+        ]
+        m_queue.extend(control_flow_submodules)
     return debug_handle_map
@@ -44,15 +49,21 @@ class TestNumericDebugger(TestCase):
     def test_simple(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = torch.export.export(m, example_inputs)
+        ep = export_for_training(m, example_inputs)
         generate_numeric_debug_handle(ep)
-        unique_ids = set()
-        count = 0
-        for n in ep.graph.nodes:
-            if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]:
-                unique_ids.add(n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY])
-                count += 1
-        self.assertEqual(len(unique_ids), count)
+        debug_handle_map = _extract_debug_handles(ep.module())
+
+        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))
+
+    def test_control_flow(self):
+        m = TestHelperModules.ControlFlow()
+        example_inputs = m.example_inputs()
+        ep = export_for_training(m, example_inputs)
+        generate_numeric_debug_handle(ep)
+
+        debug_handle_map = _extract_debug_handles(ep.module())
+
+        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

     def test_quantize_pt2e_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()


@@ -5,6 +5,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Tuple

 import torch
 from torch.ao.ns.fx.utils import compute_sqnr
+from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
 from torch.export import ExportedProgram
 from torch.fx import GraphModule, Node
 from torch.nn import functional as F
@@ -42,19 +43,32 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
     )

     unique_id = 0
-    # Find the max ID that exists in the graph first, in case part of the graph
-    # has already been annotated. This way we guarantee there are no duplicate
-    # handle IDs.
-    for node in ep.graph.nodes:
+
+    def _bfs_trace_graph_with_node_process(node_op: Callable) -> None:
+        nonlocal ep
+        queue = [ep.graph_module]
+        while queue:
+            current_graph_module = queue.pop(0)
+            for node in current_graph_module.graph.nodes:
+                if node.op in ["output", "placeholder"]:
+                    continue
+
+                node_op(node)
+
+            control_flow_submodules = [
+                submodule
+                for _, submodule, _ in get_control_flow_submodules(current_graph_module)
+            ]
+            queue.extend(control_flow_submodules)
+
+    def _find_max_id(node: torch.fx.Node) -> None:
+        nonlocal unique_id
         unique_id = max(
             unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
         )
-    unique_id += 1

-    for node in ep.graph.nodes:
-        if node.op in ["output", "placeholder"]:
-            continue
-
+    def _assign_debug_handle(node: torch.fx.Node) -> None:
+        nonlocal unique_id
         if CUSTOM_KEY not in node.meta:
             node.meta[CUSTOM_KEY] = {}
@@ -62,6 +76,17 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
         node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
         unique_id += 1

+    # Find the max ID that exists in the graph first, in case part of the graph
+    # has already been annotated. This way we guarantee there are no duplicate
+    # handle IDs.
+    _bfs_trace_graph_with_node_process(_find_max_id)
+
+    unique_id += 1
+
+    # Assign debug handles to all nodes in the graph that don't have one, based
+    # on the max ID found in the previous step.
+    _bfs_trace_graph_with_node_process(_assign_debug_handle)
+

 class OutputLogger(torch.nn.Module):
     """


@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import itertools
 import operator
-from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set
+from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set, Tuple

 import torch
 from torch.fx import Node
@@ -14,6 +14,7 @@ from torch.fx.passes.utils.source_matcher_utils import (

 __all__ = [
     "find_sequential_partitions",
+    "get_control_flow_submodules",
     "get_equivalent_types",
     "update_equivalent_types_dict",
 ]
@@ -114,3 +115,39 @@ def find_sequential_partitions(
         if _partitions_sequential(candidate)
     ]
     return fused_partitions
+
+
+def _get_submodule(
+    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
+) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
+    submod_node = node.args[arg_index]
+    assert isinstance(submod_node, torch.fx.Node)
+    assert submod_node.op == "get_attr"
+    assert isinstance(submod_node.target, str)
+    submodule = graph_module.get_submodule(submod_node.target)
+    # pyre-ignore
+    return submod_node.target, submodule, node
+
+
+def get_control_flow_submodules(
+    graph_module: torch.fx.GraphModule,
+) -> List[Tuple[str, torch.nn.Module, torch.fx.Node]]:
+    """
+    Returns a list of submodules used for control flow operations
+    (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not
+    look into submodules). Specifically, the returned value is a list containing a
+    tuple of (name of the submodule that's stored in the graph module, the
+    submodule itself, and the fx node that uses this submodule).
+    """
+    control_flow_submodules = []
+    for node in graph_module.graph.nodes:
+        if node.op != "call_function":
+            continue
+
+        if node.target is torch.ops.higher_order.cond:
+            control_flow_submodules.append(_get_submodule(graph_module, node, 1))
+            control_flow_submodules.append(_get_submodule(graph_module, node, 2))
+        if node.target is torch.ops.higher_order.map_impl:
+            control_flow_submodules.append(_get_submodule(graph_module, node, 0))
+
+    return control_flow_submodules
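A hedged usage sketch of the new helper (gm stands for any torch.fx.GraphModule produced by export that contains a cond or map call; the submodule name in the comment is illustrative):

    # Enumerates control-flow submodules one level deep; recursing into
    # nested control flow is the caller's job, as in the BFS loops above.
    for name, submodule, node in get_control_flow_submodules(gm):
        print(name, node.target)  # e.g. "true_graph_0", torch.ops.higher_order.cond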


@@ -4,6 +4,8 @@ r"""Importing this file includes common utility methods and base clases for
 checking quantization api and properties of resulting modules.
 """

+from functorch.experimental import control_flow
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -2677,6 +2679,44 @@ class SparseNNModel(nn.Module):
         return out


 class TestHelperModules:
+    class ControlFlow(torch.nn.Module):
+        def forward(
+            self,
+            xs: torch.Tensor,
+            pred1: torch.Tensor,
+            pred2: torch.Tensor,
+            y: torch.Tensor,
+        ) -> torch.Tensor:
+            def true_nested(y: torch.Tensor) -> torch.Tensor:
+                y = y + y
+                y = torch.mm(y, y)
+                return y
+
+            def false_nested(y: torch.Tensor) -> torch.Tensor:
+                return torch.mm(y, y)
+
+            def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
+                z = control_flow.cond(pred2, true_nested, false_nested, [x])
+                return x + z
+
+            def false_fn(x: torch.Tensor, _) -> torch.Tensor:
+                return x.cos()
+
+            def map_fn(
+                x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
+            ) -> torch.Tensor:
+                x = x.cos()
+                y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
+                x = x + y
+                return x.sin()
+
+            y = torch.mm(y, y)
+            return control_flow.map(map_fn, xs, pred1, pred2, y)
+
+        def example_inputs(self):
+            return (
+                torch.ones(2, 2),
+                torch.tensor([False]),
+                torch.tensor([False]),
+                torch.ones(2, 2),
+            )
+
     class Conv2dPropAnnotaton(torch.nn.Module):
         def __init__(self) -> None:
             super().__init__()