support condition branch in ao debug handler (#141516)
This diff introduces support for condition statements in AO debug handle generation. Most of the code is borrowed from ExecuTorch to avoid a circular dependency issue.

Differential Revision: [D66270691](https://our.internmc.facebook.com/intern/diff/D66270691/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141516
Approved by: https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: 75530885ba
Commit: ff059587c6
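For context, a minimal sketch of the workflow this change enables, pieced together from the tests in the diff below. The import locations are assumptions (the diff does not show where generate_numeric_debug_handle and export_for_training are imported from), and _extract_debug_handles / TestHelperModules.ControlFlow are the test helpers rewritten or added in this PR:

import torch
from torch.export import export_for_training  # assumed import path, matches test usage

m = TestHelperModules.ControlFlow()  # helper module added in this PR
ep = export_for_training(m, m.example_inputs())

# After this PR, annotation also reaches nodes inside cond/map submodules.
generate_numeric_debug_handle(ep)

debug_handle_map = _extract_debug_handles(ep.module())
# Every annotated node, including those in control-flow branches, gets a unique handle.
assert len(set(debug_handle_map.values())) == len(debug_handle_map)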
@@ -14,6 +14,7 @@ from torch.ao.quantization import (
     NUMERIC_DEBUG_HANDLE_KEY,
     prepare_for_propagation_comparison,
 )
+from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
@@ -27,14 +28,18 @@ from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, TestCase
 def _extract_debug_handles(model) -> Dict[str, int]:
     debug_handle_map: Dict[str, int] = {}
 
-    for node in model.graph.nodes:
-        if (
-            CUSTOM_KEY in node.meta
-            and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
-        ):
-            debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
-                NUMERIC_DEBUG_HANDLE_KEY
-            ]
+    m_queue = [model]
+
+    while m_queue:
+        cur_m = m_queue.pop(0)
+        for n in cur_m.graph.nodes:
+            if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]:
+                debug_handle_map[str(n)] = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
+
+        control_flow_submodules = [
+            submodule for _, submodule, _ in get_control_flow_submodules(cur_m)
+        ]
+        m_queue.extend(control_flow_submodules)
 
     return debug_handle_map
 
@@ -44,15 +49,21 @@ class TestNumericDebugger(TestCase):
     def test_simple(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = torch.export.export(m, example_inputs)
+        ep = export_for_training(m, example_inputs)
         generate_numeric_debug_handle(ep)
-        unique_ids = set()
-        count = 0
-        for n in ep.graph.nodes:
-            if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]:
-                unique_ids.add(n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY])
-                count += 1
-        self.assertEqual(len(unique_ids), count)
+        debug_handle_map = _extract_debug_handles(ep.module())
+
+        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))
+
+    def test_control_flow(self):
+        m = TestHelperModules.ControlFlow()
+        example_inputs = m.example_inputs()
+        ep = export_for_training(m, example_inputs)
+        generate_numeric_debug_handle(ep)
+
+        debug_handle_map = _extract_debug_handles(ep.module())
+
+        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))
 
     def test_quantize_pt2e_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
@@ -5,6 +5,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Tuple
 
 import torch
 from torch.ao.ns.fx.utils import compute_sqnr
+from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
 from torch.export import ExportedProgram
 from torch.fx import GraphModule, Node
 from torch.nn import functional as F
@@ -42,19 +43,32 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
     )
 
     unique_id = 0
-    # Find the max ID that exists in the graph first, in case part of the graph
-    # has already been annotated. This way we guarantee there are no duplicate
-    # handle IDs.
-    for node in ep.graph.nodes:
-        unique_id = max(
-            unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
-        )
-    unique_id += 1
 
-    for node in ep.graph.nodes:
-        if node.op in ["output", "placeholder"]:
-            continue
-
-        if CUSTOM_KEY not in node.meta:
-            node.meta[CUSTOM_KEY] = {}
+    def _bfs_trace_graph_with_node_process(node_op: Callable) -> None:
+        nonlocal ep
+        queue = [ep.graph_module]
+        while queue:
+            current_graph_module = queue.pop(0)
+            for node in current_graph_module.graph.nodes:
+                if node.op in ["output", "placeholder"]:
+                    continue
+
+                node_op(node)
+
+            control_flow_submodules = [
+                submodule
+                for _, submodule, _ in get_control_flow_submodules(current_graph_module)
+            ]
+            queue.extend(control_flow_submodules)
+
+    def _find_max_id(node: torch.fx.Node) -> None:
+        nonlocal unique_id
+        unique_id = max(
+            unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
+        )
+
+    def _assign_debug_handle(node: torch.fx.Node) -> None:
+        nonlocal unique_id
+        if CUSTOM_KEY not in node.meta:
+            node.meta[CUSTOM_KEY] = {}
 
@@ -62,6 +76,17 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
             node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
             unique_id += 1
 
+    # Find the max ID that exists in the graph first, in case part of the graph
+    # has already been annotated. This way we guarantee there are no duplicate
+    # handle IDs.
+    _bfs_trace_graph_with_node_process(_find_max_id)
+
+    unique_id += 1
+
+    # Assign debug handles to all nodes in the graph that don't have one based on the
+    # max ID found in the previous step.
+    _bfs_trace_graph_with_node_process(_assign_debug_handle)
+
 
 class OutputLogger(torch.nn.Module):
     """
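As a side note, the two-pass scheme above (first find the current max handle, then hand out fresh ones) can be illustrated on a plain fx graph with no control flow. This is a minimal standalone sketch, not the shipped implementation; the two key constants are stand-ins, since their real values are defined elsewhere in torch.ao.quantization:

import torch
from torch.fx import symbolic_trace

# Stand-in key names: the real CUSTOM_KEY / NUMERIC_DEBUG_HANDLE_KEY constants
# live in torch.ao.quantization and their values are not shown in this diff.
CUSTOM_KEY = "custom"
NUMERIC_DEBUG_HANDLE_KEY = "numeric_debug_handle"

def toy(x):
    return torch.relu(x) + 1

gm = symbolic_trace(toy)

# Pass 1: find the max handle already present (0 if nothing is annotated),
# so re-running the pass can never mint a duplicate ID.
unique_id = 0
for node in gm.graph.nodes:
    unique_id = max(
        unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
    )
unique_id += 1

# Pass 2: hand out fresh handles to unannotated nodes, skipping placeholders
# and outputs exactly like _assign_debug_handle above.
for node in gm.graph.nodes:
    if node.op in ["output", "placeholder"]:
        continue
    meta = node.meta.setdefault(CUSTOM_KEY, {})
    if NUMERIC_DEBUG_HANDLE_KEY not in meta:
        meta[NUMERIC_DEBUG_HANDLE_KEY] = unique_id
        unique_id += 1

The version in this diff differs only in that both passes run through _bfs_trace_graph_with_node_process, so the same callbacks also visit every cond/map submodule found by get_control_flow_submodules.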
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import itertools
 import operator
-from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set
+from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set, Tuple
 
 import torch
 from torch.fx import Node
@@ -14,6 +14,7 @@ from torch.fx.passes.utils.source_matcher_utils import (
 
 __all__ = [
     "find_sequential_partitions",
+    "get_control_flow_submodules",
     "get_equivalent_types",
     "update_equivalent_types_dict",
 ]
@@ -114,3 +115,39 @@ def find_sequential_partitions(
         if _partitions_sequential(candidate)
     ]
     return fused_partitions
+
+
+def _get_submodule(
+    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
+) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
+    submod_node = node.args[arg_index]
+    assert isinstance(submod_node, torch.fx.Node)
+    assert submod_node.op == "get_attr"
+    assert isinstance(submod_node.target, str)
+    submodule = graph_module.get_submodule(submod_node.target)
+    # pyre-ignore
+    return submod_node.target, submodule, node
+
+
+def get_control_flow_submodules(
+    graph_module: torch.fx.GraphModule,
+) -> List[Tuple[str, torch.nn.Module, torch.fx.Node]]:
+    """
+    Returns a list of submodules used for control flow operations
+    (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
+    into submodules). Specifically, the returned value is a list containing a
+    tuple of (name of the submodule that's stored in the graph module, the
+    submodule itself, and the fx node that uses this submodule).
+    """
+    control_flow_submodules = []
+    for node in graph_module.graph.nodes:
+        if node.op != "call_function":
+            continue
+
+        if node.target is torch.ops.higher_order.cond:
+            control_flow_submodules.append(_get_submodule(graph_module, node, 1))
+            control_flow_submodules.append(_get_submodule(graph_module, node, 2))
+        if node.target is torch.ops.higher_order.map_impl:
+            control_flow_submodules.append(_get_submodule(graph_module, node, 0))
+
+    return control_flow_submodules
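A hedged usage sketch for the new helper: exporting a module that branches with torch.cond produces a graph whose cond node references its two branches as get_attr submodules, and get_control_flow_submodules surfaces them. The branch names in the comment are typical export-generated names, not guaranteed:

import torch
from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules

class M(torch.nn.Module):
    def forward(self, pred: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return torch.cond(pred, lambda x: x.sin(), lambda x: x.cos(), [x])

ep = torch.export.export(M(), (torch.tensor(True), torch.ones(3)))
for name, submodule, node in get_control_flow_submodules(ep.graph_module):
    # Typically ("true_graph_0", <GraphModule>, cond_node) and
    # ("false_graph_0", <GraphModule>, cond_node); names are export-generated.
    print(name, node.target)  # node.target is torch.ops.higher_order.cond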
@@ -4,6 +4,8 @@ r"""Importing this file includes common utility methods and base clases for
 checking quantization api and properties of resulting modules.
 """
 
+from functorch.experimental import control_flow
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -2677,6 +2679,44 @@ class SparseNNModel(nn.Module):
         return out
 
 class TestHelperModules:
+    class ControlFlow(torch.nn.Module):
+        def forward(
+            self,
+            xs: torch.Tensor,
+            pred1: torch.Tensor,
+            pred2: torch.Tensor,
+            y: torch.Tensor,
+        ) -> torch.Tensor:
+
+            def true_nested(y: torch.Tensor) -> torch.Tensor:
+                y = y + y
+                y = torch.mm(y, y)
+                return y
+
+            def false_nested(y: torch.Tensor) -> torch.Tensor:
+                return torch.mm(y, y)
+
+            def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
+                z = control_flow.cond(pred2, true_nested, false_nested, [x])
+                return x + z
+
+            def false_fn(x: torch.Tensor, _) -> torch.Tensor:
+                return x.cos()
+
+            def map_fn(
+                x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
+            ) -> torch.Tensor:
+                x = x.cos()
+                y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
+                x = x + y
+                return x.sin()
+
+            y = torch.mm(y, y)
+            return control_flow.map(map_fn, xs, pred1, pred2, y)
+
+        def example_inputs(self):
+            return (torch.ones(2, 2), torch.tensor([False]), torch.tensor([False]), torch.ones(2, 2),)
+
     class Conv2dPropAnnotaton(torch.nn.Module):
         def __init__(self) -> None:
             super().__init__()
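Finally, a quick eager sanity check of the new helper module (a sketch; it simply runs the same module that test_control_flow exports and annotates above, assuming the eager control_flow ops accept these inputs):

m = TestHelperModules.ControlFlow()
out = m(*m.example_inputs())  # exercises control_flow.map -> cond -> nested cond eagerly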