mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This diff introduces support for conditional statements in ao debug handler generation. Most of the code is borrowed from ExecuTorch to avoid a circular dependency issue. Differential Revision: [D66270691](https://our.internmc.facebook.com/intern/diff/D66270691/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141516 Approved by: https://github.com/jerryzh168
154 lines
5.3 KiB
Python
154 lines
5.3 KiB
Python
# mypy: allow-untyped-defs
|
|
import itertools
|
|
import operator
|
|
from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set, Tuple
|
|
|
|
import torch
|
|
from torch.fx import Node
|
|
from torch.fx.passes.utils.source_matcher_utils import (
|
|
check_subgraphs_connected,
|
|
get_source_partitions,
|
|
SourcePartition,
|
|
)
|
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "find_sequential_partitions",
    "get_control_flow_submodules",
    "get_equivalent_types",
    "update_equivalent_types_dict",
]
|
|
|
|
# Groups of partition types that are treated as interchangeable when looking
# up matching types: each set mixes nn.Module classes, their functional
# counterparts, Python operators and method-name strings.
_EQUIVALENT_TYPES: List[Set] = [
    # Convolutions
    {torch.nn.Conv1d, torch.nn.functional.conv1d},
    {torch.nn.Conv2d, torch.nn.functional.conv2d},
    # Pooling
    {torch.nn.AdaptiveAvgPool2d, torch.nn.functional.adaptive_avg_pool2d},
    # Activations and normalization
    {torch.nn.ReLU, torch.nn.functional.relu, torch.nn.functional.relu_},
    {torch.nn.BatchNorm2d, torch.nn.functional.batch_norm},
    {torch.nn.Hardtanh, torch.nn.functional.hardtanh, torch.nn.functional.hardtanh_},
    # Elementwise arithmetic (function, operator and method-name spellings)
    {torch.add, operator.add, operator.iadd, "add", "add_"},
    {torch.mul, operator.mul, operator.imul, "mul", "mul_"},
]
|
|
|
|
|
|
def _create_equivalent_types_dict():
    """Build a lookup table from each type to the members of its equivalence group.

    Every member of every set in ``_EQUIVALENT_TYPES`` becomes a key whose value
    is a fresh list holding all members of that set (the key itself included).
    If a member appears in more than one set, the last set wins.
    """
    return {
        member: list(group)
        for group in _EQUIVALENT_TYPES
        for member in group
    }


# Lookup table derived from ``_EQUIVALENT_TYPES`` at import time.
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
|
|
|
|
|
|
def get_equivalent_types() -> List[Set]:
    """Return the module-level list of equivalence groups.

    The list object itself is returned, not a copy.
    """
    return _EQUIVALENT_TYPES
|
|
|
|
|
|
def update_equivalent_types_dict(customized_equivalent_types=None):
    """Install user-supplied equivalence groups.

    Replaces ``_EQUIVALENT_TYPES`` with ``customized_equivalent_types`` and then
    regenerates ``_EQUIVALENT_TYPES_DICT`` from the new value.

    Args:
        customized_equivalent_types: the new equivalence groups; must not be None.

    Raises:
        ValueError: if ``customized_equivalent_types`` is None.
    """
    global _EQUIVALENT_TYPES, _EQUIVALENT_TYPES_DICT

    if customized_equivalent_types is None:
        raise ValueError("customized_equivalent_types should not be None")

    # Order matters: the dict builder reads the freshly assigned global.
    _EQUIVALENT_TYPES = customized_equivalent_types
    _EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
|
|
|
|
|
|
def _partitions_sequential(partitions: Sequence[SourcePartition]):
|
|
prev_partition = None
|
|
for partition in partitions:
|
|
if prev_partition is not None and not check_subgraphs_connected(
|
|
prev_partition, partition
|
|
):
|
|
return False
|
|
prev_partition = partition
|
|
return True
|
|
|
|
|
|
def _get_matching_types(partition_type):
    """Return ``partition_type`` followed by the types registered as equivalent to it.

    Equivalents are read from ``_EQUIVALENT_TYPES_DICT``. A type without an
    entry yields a single-element list. No de-duplication is performed, so a
    type may appear more than once in the result.
    """
    equivalents = _EQUIVALENT_TYPES_DICT.get(partition_type, ())
    return [partition_type, *equivalents]
|
|
|
|
|
|
def _valid_type_sequence(partition_types: List[Any]):
    """Check that no two requested types share a matching type.

    Each entry is expanded with ``_get_matching_types``; the sequence is
    rejected as soon as an expansion overlaps with the types collected from
    earlier entries.

    Returns:
        True if all expansions are pairwise disjoint, False otherwise.
    """
    seen = set()  # type: ignore[var-annotated]
    for requested in partition_types:
        family = set(_get_matching_types(requested))
        if not seen.isdisjoint(family):
            return False
        seen.update(family)
    return True
|
|
|
|
|
|
def find_sequential_partitions(
    gm: torch.fx.GraphModule,
    partition_types: List[Any],
    include_functional_equivalent=True,
    filter_fn: Optional[Callable[[Node], bool]] = None,
):
    """Find chains of source partitions that follow ``partition_types`` in order.

    For each requested type, the source partitions of ``gm.graph`` matching that
    type (or any type returned for it by ``_get_matching_types``) are collected.
    Every combination taking one partition per requested type is then kept if
    ``_partitions_sequential`` accepts it.

    Args:
        gm: graph module whose graph is searched.
        partition_types: ordered types describing the chain to look for.
        include_functional_equivalent: accepted for API compatibility; this
            function does not read it.
        filter_fn: optional node predicate forwarded to ``get_source_partitions``.

    Returns:
        A list of tuples, each holding one partition per requested type, in the
        order given by ``partition_types``.

    Raises:
        ValueError: if ``_valid_type_sequence`` rejects ``partition_types``.
    """
    if not _valid_type_sequence(partition_types):
        raise ValueError(
            f"Invalid partition types: {partition_types}. Each type in the sequence must be unique"
        )

    # One list of candidate partitions per requested type, in request order.
    candidates_per_type: List[List[SourcePartition]] = []
    for requested in partition_types:
        matches = get_source_partitions(
            gm.graph, _get_matching_types(requested), filter_fn
        )
        candidates_per_type.append(
            [partition for group in matches.values() for partition in group]
        )

    return [
        combo
        for combo in itertools.product(*candidates_per_type)
        if _partitions_sequential(combo)
    ]
|
|
|
|
|
|
def _get_submodule(
    graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg_index: int
) -> Tuple[str, torch.nn.Module, torch.fx.Node]:
    """Resolve the submodule referenced by one positional argument of ``node``.

    ``node.args[arg_index]`` must be a ``get_attr`` node with a string target;
    that target is looked up on ``graph_module`` with ``get_submodule``.

    Returns:
        A tuple of (submodule name, the submodule, ``node``).
    """
    attr_node = node.args[arg_index]
    assert isinstance(attr_node, torch.fx.Node)
    assert attr_node.op == "get_attr"
    assert isinstance(attr_node.target, str)

    submodule_name = attr_node.target
    # pyre-ignore
    return submodule_name, graph_module.get_submodule(submodule_name), node
|
|
|
|
|
|
def get_control_flow_submodules(
    graph_module: torch.fx.GraphModule,
) -> List[Tuple[str, torch.nn.Module, torch.fx.Node]]:
    """
    Collect the submodules consumed by control flow operations
    (torch.ops.higher_order.cond / torch.ops.higher_order.map_impl) in the
    top-level graph of ``graph_module``; nested submodules are not searched.

    For a ``cond`` node the submodules at argument positions 1 and 2 are
    reported (in that order); for a ``map_impl`` node the submodule at
    argument position 0 is reported.

    Returns:
        A list of tuples (name under which the submodule is stored in the graph
        module, the submodule itself, the fx node that uses the submodule).
    """
    found: List[Tuple[str, torch.nn.Module, torch.fx.Node]] = []
    for node in graph_module.graph.nodes:
        if node.op == "call_function":
            is_cond = node.target is torch.ops.higher_order.cond
            is_map = node.target is torch.ops.higher_order.map_impl
            # Argument positions holding the get_attr nodes of the submodules.
            arg_slots = ((1, 2) if is_cond else ()) + ((0,) if is_map else ())
            found.extend(
                _get_submodule(graph_module, node, slot) for slot in arg_slots
            )
    return found
|