support condition branch in ao debug handler (#141516)
This diff introduces support for conditional control flow (cond/map submodules) in ao debug handle generation. Most of the code is borrowed from ExecuTorch to avoid a circular dependency issue.

Differential Revision: [D66270691](https://our.internmc.facebook.com/intern/diff/D66270691/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141516
Approved by: https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: 75530885ba
Commit: ff059587c6
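For orientation before the hunks below (which touch the numeric debugger, its tests, the pt2e graph_utils, and the common_quantization test helpers), here is a minimal sketch of the workflow this change enables. It assumes generate_numeric_debug_handle is re-exported from torch.ao.quantization, and the toy model is illustrative rather than code from this PR:

import torch
from torch.ao.quantization import generate_numeric_debug_handle
from torch.export import export_for_training


class ToyCond(torch.nn.Module):  # illustrative model, not from this PR
    def forward(self, x: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        def true_fn(x: torch.Tensor) -> torch.Tensor:
            return x.sin()

        def false_fn(x: torch.Tensor) -> torch.Tensor:
            return x.cos()

        # torch.cond lowers to torch.ops.higher_order.cond; its branches become
        # submodules of the exported graph module.
        return torch.cond(pred, true_fn, false_fn, [x])


ep = export_for_training(ToyCond(), (torch.randn(3), torch.tensor(True)))
# Before this change only top-level graph nodes were annotated; now nodes
# inside the cond branch submodules receive numeric debug handles as well.
generate_numeric_debug_handle(ep)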
@@ -14,6 +14,7 @@ from torch.ao.quantization import (
     NUMERIC_DEBUG_HANDLE_KEY,
     prepare_for_propagation_comparison,
 )
+from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
@@ -27,14 +28,18 @@ from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef, Tes
 def _extract_debug_handles(model) -> Dict[str, int]:
     debug_handle_map: Dict[str, int] = {}

-    for node in model.graph.nodes:
-        if (
-            CUSTOM_KEY in node.meta
-            and NUMERIC_DEBUG_HANDLE_KEY in node.meta[CUSTOM_KEY]
-        ):
-            debug_handle_map[str(node)] = node.meta[CUSTOM_KEY][
-                NUMERIC_DEBUG_HANDLE_KEY
-            ]
+    m_queue = [model]
+
+    while m_queue:
+        cur_m = m_queue.pop(0)
+        for n in cur_m.graph.nodes:
+            if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]:
+                debug_handle_map[str(n)] = n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY]
+
+        control_flow_submodules = [
+            submodule for _, submodule, _ in get_control_flow_submodules(cur_m)
+        ]
+        m_queue.extend(control_flow_submodules)

     return debug_handle_map
@@ -44,15 +49,21 @@ class TestNumericDebugger(TestCase):
     def test_simple(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = torch.export.export(m, example_inputs)
+        ep = export_for_training(m, example_inputs)
         generate_numeric_debug_handle(ep)
-        unique_ids = set()
-        count = 0
-        for n in ep.graph.nodes:
-            if CUSTOM_KEY in n.meta and NUMERIC_DEBUG_HANDLE_KEY in n.meta[CUSTOM_KEY]:
-                unique_ids.add(n.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY])
-                count += 1
-        self.assertEqual(len(unique_ids), count)
+        debug_handle_map = _extract_debug_handles(ep.module())
+
+        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))
+
+    def test_control_flow(self):
+        m = TestHelperModules.ControlFlow()
+        example_inputs = m.example_inputs()
+        ep = export_for_training(m, example_inputs)
+        generate_numeric_debug_handle(ep)
+
+        debug_handle_map = _extract_debug_handles(ep.module())
+
+        self.assertEqual(len(set(debug_handle_map.values())), len(debug_handle_map))

     def test_quantize_pt2e_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
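The new test can be run on its own with an invocation along these lines (the file path is an assumption based on where TestNumericDebugger lives in the PyTorch tree):

python test/quantization/pt2e/test_numeric_debugger.py -k test_control_flow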
@@ -5,6 +5,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Tuple

 import torch
 from torch.ao.ns.fx.utils import compute_sqnr
+from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
 from torch.export import ExportedProgram
 from torch.fx import GraphModule, Node
 from torch.nn import functional as F
@@ -42,19 +43,32 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
        )

     unique_id = 0
-    # Find the max ID that exists in the graph first, in case part of the graph
-    # has already been annotated. This way we guarantee there are no duplicate
-    # handle IDs.
-    for node in ep.graph.nodes:
+
+    def _bfs_trace_graph_with_node_process(node_op: Callable) -> None:
+        nonlocal ep
+        queue = [ep.graph_module]
+        while queue:
+            current_graph_module = queue.pop(0)
+            for node in current_graph_module.graph.nodes:
+                if node.op in ["output", "placeholder"]:
+                    continue
+
+                node_op(node)
+
+            control_flow_submodules = [
+                submodule
+                for _, submodule, _ in get_control_flow_submodules(current_graph_module)
+            ]
+            queue.extend(control_flow_submodules)
+
+    def _find_max_id(node: torch.fx.Node) -> None:
+        nonlocal unique_id
         unique_id = max(
             unique_id, node.meta.get(CUSTOM_KEY, {}).get(NUMERIC_DEBUG_HANDLE_KEY, 0)
         )
-    unique_id += 1

-    for node in ep.graph.nodes:
-        if node.op in ["output", "placeholder"]:
-            continue
+    def _assign_debug_handle(node: torch.fx.Node) -> None:
+        nonlocal unique_id
         if CUSTOM_KEY not in node.meta:
             node.meta[CUSTOM_KEY] = {}
@@ -62,6 +76,17 @@ def generate_numeric_debug_handle(ep: ExportedProgram) -> None:
         node.meta[CUSTOM_KEY][NUMERIC_DEBUG_HANDLE_KEY] = unique_id
         unique_id += 1

+    # Find the max ID that exists in the graph first, in case part of the graph
+    # has already been annotated. This way we guarantee there are no duplicate
+    # handle IDs.
+    _bfs_trace_graph_with_node_process(_find_max_id)
+
+    unique_id += 1
+
+    # Assign debug handles to all nodes in the graph that don't have one based on the
+    # max ID found in the previous step.
+    _bfs_trace_graph_with_node_process(_assign_debug_handle)
+

 class OutputLogger(torch.nn.Module):
     """
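The two hunks above (likely torch/ao/quantization/pt2e/_numeric_debugger.py, judging by the symbols) separate traversal from per-node work: _bfs_trace_graph_with_node_process walks the top-level graph plus every control-flow submodule (cond branches, map bodies) in BFS order, skips placeholder/output nodes, and applies a callback to each remaining node. Running _find_max_id before _assign_debug_handle means regenerating handles on a partially annotated graph never produces duplicate IDs. A self-contained toy sketch of that two-pass scheme, with plain dicts standing in for fx node metadata (everything here is illustrative):

# Pass 1 finds the largest existing handle; pass 2 numbers the rest from max + 1.
nodes = [{"handle": 3}, {}, {"handle": 7}, {}]

unique_id = 0


def find_max_id(node: dict) -> None:
    global unique_id
    unique_id = max(unique_id, node.get("handle", 0))


def assign_handle(node: dict) -> None:
    global unique_id
    if "handle" not in node:
        node["handle"] = unique_id
        unique_id += 1


for n in nodes:  # pass 1: the max existing ID is 7
    find_max_id(n)
unique_id += 1  # fresh IDs start at 8

for n in nodes:  # pass 2: the two unannotated nodes get 8 and 9
    assign_handle(n)

assert [n["handle"] for n in nodes] == [3, 8, 7, 9]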
@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 import itertools
 import operator
-from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set
+from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set, Tuple

 import torch
 from torch.fx import Node
@@ -14,6 +14,7 @@ from torch.fx.passes.utils.source_matcher_utils import (

 __all__ = [
     "find_sequential_partitions",
+    "get_control_flow_submodules",
     "get_equivalent_types",
     "update_equivalent_types_dict",
 ]
@@ -114,3 +115,39 @@ def find_sequential_partitions(
         if _partitions_sequential(candidate)
     ]
     return fused_partitions
+
+
+def _get_submodule(
+    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
+) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
+    submod_node = node.args[arg_index]
+    assert isinstance(submod_node, torch.fx.Node)
+    assert submod_node.op == "get_attr"
+    assert isinstance(submod_node.target, str)
+    submodule = graph_module.get_submodule(submod_node.target)
+    # pyre-ignore
+    return submod_node.target, submodule, node
+
+
+def get_control_flow_submodules(
+    graph_module: torch.fx.GraphModule,
+) -> List[Tuple[str, torch.nn.Module, torch.fx.Node]]:
+    """
+    Returns a list of submodules used for control flow operations
+    (torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
+    into submodules). Specifically, the returned value is a list containing a
+    tuple of (name of the submodule that's stored in the graph module, the
+    submodule itself, and the fx node that uses this submodule).
+    """
+    control_flow_submodules = []
+    for node in graph_module.graph.nodes:
+        if node.op != "call_function":
+            continue
+
+        if node.target is torch.ops.higher_order.cond:
+            control_flow_submodules.append(_get_submodule(graph_module, node, 1))
+            control_flow_submodules.append(_get_submodule(graph_module, node, 2))
+        if node.target is torch.ops.higher_order.map_impl:
+            control_flow_submodules.append(_get_submodule(graph_module, node, 0))
+
+    return control_flow_submodules
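The three hunks above (likely torch/ao/quantization/pt2e/graph_utils.py, given find_sequential_partitions and the __all__ list) port get_control_flow_submodules from ExecuTorch so it can be used without a cross-repo import. A short usage sketch, reusing the illustrative ToyCond model from the sketch near the top of this page:

import torch
from torch.ao.quantization.pt2e.graph_utils import get_control_flow_submodules
from torch.export import export_for_training

ep = export_for_training(ToyCond(), (torch.randn(3), torch.tensor(True)))
# Each entry is (submodule attr name, submodule, the node that consumes it);
# a single cond contributes its true-branch and false-branch graph modules.
for name, submodule, node in get_control_flow_submodules(ep.module()):
    print(name, node.target)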
@@ -4,6 +4,8 @@ r"""Importing this file includes common utility methods and base clases for
 checking quantization api and properties of resulting modules.
 """

+from functorch.experimental import control_flow
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -2677,6 +2679,44 @@ class SparseNNModel(nn.Module):
         return out

 class TestHelperModules:
+    class ControlFlow(torch.nn.Module):
+        def forward(
+            self,
+            xs: torch.Tensor,
+            pred1: torch.Tensor,
+            pred2: torch.Tensor,
+            y: torch.Tensor,
+        ) -> torch.Tensor:
+
+            def true_nested(y: torch.Tensor) -> torch.Tensor:
+                y = y + y
+                y = torch.mm(y, y)
+                return y
+
+            def false_nested(y: torch.Tensor) -> torch.Tensor:
+                return torch.mm(y, y)
+
+            def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
+                z = control_flow.cond(pred2, true_nested, false_nested, [x])
+                return x + z
+
+            def false_fn(x: torch.Tensor, _) -> torch.Tensor:
+                return x.cos()
+
+            def map_fn(
+                x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
+            ) -> torch.Tensor:
+                x = x.cos()
+                y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
+                x = x + y
+                return x.sin()
+
+            y = torch.mm(y, y)
+            return control_flow.map(map_fn, xs, pred1, pred2, y)
+
+        def example_inputs(self):
+            return (torch.ones(2, 2), torch.tensor([False]), torch.tensor([False]), torch.ones(2, 2),)
+
     class Conv2dPropAnnotaton(torch.nn.Module):
         def __init__(self) -> None:
             super().__init__()
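This last hunk adds the ControlFlow helper (likely torch/testing/_internal/common_quantization.py, where TestHelperModules is defined): a cond nested inside a cond, all under a map, so handle generation is exercised across nested control-flow submodules. The helper should also run eagerly, since functorch.experimental.control_flow's cond/map execute outside of export; a quick smoke test (illustrative, and the import path is an assumption):

from torch.testing._internal.common_quantization import TestHelperModules

m = TestHelperModules.ControlFlow()
out = m(*m.example_inputs())
# control_flow.map stacks map_fn's (2, 2) outputs over the leading dim of xs.
print(out.shape)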