Files
pytorch/torch/ao/quantization/pt2e/graph_utils.py
gasoonjia ff059587c6 support condition branch in ao debug handler (#141516)
This diff introduces support for conditional branches in ao debug handler generation.

Most of the code is borrowed from ExecuTorch to avoid a circular dependency issue.

Differential Revision: [D66270691](https://our.internmc.facebook.com/intern/diff/D66270691/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141516
Approved by: https://github.com/jerryzh168
2024-12-10 14:05:12 +00:00

154 lines
5.3 KiB
Python

# mypy: allow-untyped-defs
import itertools
import operator
from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set, Tuple
import torch
from torch.fx import Node
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    get_source_partitions,
    SourcePartition,
)
__all__ = [
"find_sequential_partitions",
"get_control_flow_submodules",
"get_equivalent_types",
"update_equivalent_types_dict",
]
_EQUIVALENT_TYPES: List[Set] = [
    {torch.nn.Conv1d, torch.nn.functional.conv1d},
    {torch.nn.Conv2d, torch.nn.functional.conv2d},
    {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
    {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
    {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
    {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
    {torch.add, operator.add, operator.iadd, "add", "add_"},
    {torch.mul, operator.mul, operator.imul, "mul", "mul_"},
]


def _create_equivalent_types_dict():
    _DICT = {}
    for values in _EQUIVALENT_TYPES:
        for v in values:
            _DICT[v] = list(values)
    return _DICT


_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
def get_equivalent_types() -> List[Set]:
    return _EQUIVALENT_TYPES


def update_equivalent_types_dict(customized_equivalent_types=None):
"""Help function for user who wants to customize the _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
When customized_equivalent_types passes in,
re-generate _EQUIVALENT_TYPES and _EQUIVALENT_TYPES_DICT.
"""
if customized_equivalent_types is None:
raise ValueError("customized_equivalent_types should not be None")
global _EQUIVALENT_TYPES
global _EQUIVALENT_TYPES_DICT
_EQUIVALENT_TYPES = customized_equivalent_types
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
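

# A minimal usage sketch for update_equivalent_types_dict (illustrative only;
# `custom_equivalent_types` below is a hypothetical replacement list, not part
# of this module):
#
#     custom_equivalent_types = [
#         {torch.nn.Conv2d, torch.nn.functional.conv2d},
#         {torch.nn.Linear, torch.nn.functional.linear},
#     ]
#     update_equivalent_types_dict(custom_equivalent_types)
#
# After this call, only the sets listed above are considered equivalent, so
# find_sequential_partitions would treat nn.Linear and F.linear as the same
# partition type.
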
def _partitions_sequential(partitions: Sequence[SourcePartition]):
    prev_partition = None
    for partition in partitions:
        if prev_partition is not None and not check_subgraphs_connected(
            prev_partition, partition
        ):
            return False
        prev_partition = partition
    return True


def _get_matching_types(partition_type):
    matching_types = [partition_type]
    if partition_type in _EQUIVALENT_TYPES_DICT:
        matching_types.extend(_EQUIVALENT_TYPES_DICT[partition_type])
    return matching_types


def _valid_type_sequence(partition_types: List[Any]):
    partition_types_set = set()  # type: ignore[var-annotated]
    for partition_type in partition_types:
        matching_types = _get_matching_types(partition_type)
        matching_types_set = set(matching_types)
        if len(partition_types_set & matching_types_set) > 0:
            return False
        partition_types_set |= matching_types_set
    return True


def find_sequential_partitions(
    gm: torch.fx.GraphModule,
    partition_types: List[Any],
    include_functional_equivalent=True,
    filter_fn: Optional[Callable[[Node], bool]] = None,
):
    if not _valid_type_sequence(partition_types):
        raise ValueError(
            f"Invalid partition types: {partition_types}. Each type in the sequence must be unique"
        )

    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
    for partition_type in partition_types:
        types_to_match = _get_matching_types(partition_type)
        partitions = get_source_partitions(gm.graph, types_to_match, filter_fn)
        typed_partitions[partition_type] = list(
            itertools.chain.from_iterable(partitions.values())
        )

    typed_partitions_list = list(typed_partitions.values())
    fusion_candidates = itertools.product(*typed_partitions_list)
    fused_partitions = [
        candidate
        for candidate in fusion_candidates
        if _partitions_sequential(candidate)
    ]
    return fused_partitions
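

# A minimal usage sketch for find_sequential_partitions (illustrative only;
# `model` and `example_inputs` are assumed to exist, the model is assumed to
# contain a Conv2d followed by a ReLU, and the graph is captured via
# torch.export; the exact capture API may differ across PyTorch versions):
#
#     exported = torch.export.export(model, example_inputs)
#     conv_relu_pairs = find_sequential_partitions(
#         exported.module(), [torch.nn.Conv2d, torch.nn.ReLU]
#     )
#     # Each entry is a (conv_partition, relu_partition) tuple of
#     # SourcePartition objects whose subgraphs are directly connected.
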
def _get_submodule(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
    submod_node = node.args[arg_index]
    assert isinstance(submod_node, torch.fx.Node)
    assert submod_node.op == "get_attr"
    assert isinstance(submod_node.target, str)
    submodule = graph_module.get_submodule(submod_node.target)
    # pyre-ignore
    return submod_node.target, submodule, node


def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.nn.Module, torch.fx.Node]]:
"""
Returns a list of submodules used for control flow operations
(torch.ops.higher_order.cond/map) that are in the given toplevel graph (does not look
into submodules). Specifically, the returned value is a list containing a
tuple of (name of the submodule that's stored in the graph module, the
submodule itself, and the fx node that uses this submodule).
"""
control_flow_submodules = []
for node in graph_module.graph.nodes:
if node.op != "call_function":
continue
if node.target is torch.ops.higher_order.cond:
control_flow_submodules.append(_get_submodule(graph_module, node, 1))
control_flow_submodules.append(_get_submodule(graph_module, node, 2))
if node.target is torch.ops.higher_order.map_impl:
control_flow_submodules.append(_get_submodule(graph_module, node, 0))
return control_flow_submodules
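

# A minimal sketch of enumerating cond/map branch submodules with the helper
# above (illustrative only; `cond_model` and `example_inputs` are assumed to be
# a module that uses torch.cond and its example inputs, captured via
# torch.export):
#
#     exported = torch.export.export(cond_model, example_inputs)
#     for name, submodule, node in get_control_flow_submodules(exported.module()):
#         # `submodule` is a branch (or map body) graph module stored on the
#         # parent graph module, and `node` is the call_function node for the
#         # higher-order op that uses it.
#         print(name, submodule)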