"""
|
|
This module implements graph deduplication functionality for TorchDynamo's optimization pipeline.
|
|
Graph deduplication identifies identical subgraphs in the computational graph and merges them
|
|
to reduce redundancy and improve performance. The process involves analyzing regions of the graph,
|
|
identifying structurally equivalent regions, and replacing them with a single shared implementation.
|
|
This optimization is particularly effective for models with repeated patterns or similar computational
|
|
structures across different parts of the network.
|
|
"""

import logging
import operator
from collections import defaultdict, deque
from collections.abc import Generator, Iterable
from typing import Optional

import torch
import torch.fx
from torch._dynamo import config
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._ordered_set import OrderedSet

from .graph_region_tracker import Node, Region
from .graph_utils import _detect_cycles, _get_flat_args, _get_flat_args_unique


# Represents an index into the region to select a node, and then an index
# into that node's flattened arguments
UsageIndex = tuple[int, int]
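# Example (illustrative): the UsageIndex (2, 0) selects region[2] and the
# first entry of that node's flattened args/kwargs.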

log = logging.getLogger(__name__)

last_node_to_additional_deps: Optional[dict[Node, OrderedSet[Node]]] = None


def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]:  # type: ignore[no-untyped-def]
    """
    This is the main entry point for applying the graph deduplication pass. \
Deduplication occurs in two phases:
    1. Subgraph creation:
    Subgraph creation works by taking one representative region from each region \
group and creating a subgraph from it, which will then be used to replace all regions \
in the group. This is implemented by first copying all nodes of the region to the new \
subgraph and then finding all inputs which are not within the region and creating placeholders \
for them. For the outputs, all regions in a region group need to be scanned to ensure the \
largest set of outputs is found, and then an output node is created which returns \
a tuple of all outputs.

    2. Graph replacement:
    To replace each region with the extracted subgraph, the node index in the region \
and the argument index within the node's flattened args and kwargs are recorded once during \
subgraph creation. This allows us to determine which nodes (external to the region) are \
passed as inputs, and in which order. For the outputs, getitem nodes are created \
for each output, and all nodes in the region with external outputs are replaced by the proper \
getitem node. Finally, all original nodes are erased (there should be no uses of these \
left in the graph).

    The deduplication mutates the output_graph argument in place.

    Returns a mapping of subgraph name to subgraph GraphModule so that outputs can \
be remapped when they are created in output_graph.
    """

    duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
        output_graph.graph
    )
    node_to_mutated_arg_positions = (
        output_graph.region_tracker.node_to_mutated_arg_positions
    )
    node_to_additional_deps = _populate_additional_deps(
        output_graph.graph, output_graph.region_tracker.node_to_mutated_arg_positions
    )

    sub_gms: dict[str, torch.fx.GraphModule] = {}

    for region_group in duplicated_region_groups:
        inds_with_external_users = _get_all_output_indices(region_group)
        region = region_group[0]
        (
            subgraph,
            external_node_usages,
            node_usage_to_tuple_elems,
            ind_to_tuple_spec,
        ) = _create_subgraph(region, inds_with_external_users)

        # Ignore regions with no args for now; could they possibly be evaluated at compile time?
        if not list(external_node_usages):
            continue

        sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
        subgraph_name = output_graph.install_subgraph("subgraph", sub_gm)
        sub_gms[subgraph_name] = sub_gm
        with output_graph.graph.inserting_before():
            get_subgraph_node = output_graph.graph.create_node(
                "get_attr", subgraph_name, (), {}
            )

        for region in region_group:
            _replace_region_with_subgraph(
                output_graph.graph,
                region,
                get_subgraph_node,
                external_node_usages,
                node_usage_to_tuple_elems,
                ind_to_tuple_spec,
                inds_with_external_users,
                subgraph_name,
                node_to_additional_deps,
                node_to_mutated_arg_positions,
            )

    # This is to expose the updated node_to_additional_deps to tests
    global last_node_to_additional_deps
    last_node_to_additional_deps = node_to_additional_deps

    _stable_topological_sort(
        output_graph.graph,
        node_to_additional_deps,
    )
    return sub_gms


def _replace_region_with_subgraph(
    graph: torch.fx.Graph,
    region: Region,
    get_subgraph_node: Node,
    external_node_usages: Iterable[OrderedSet[UsageIndex]],
    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]],
    ind_to_tuple_spec: dict[int, dict[tuple[int, ...], int]],
    inds_with_external_users: list[int],
    subgraph_name: str,
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
    node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
) -> None:
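    """
    Replace a single region with a call to the shared subgraph.

    Gathers the external inputs recorded in external_node_usages, emits an
    invoke_subgraph node, rewires external users of the region's outputs to
    getitem nodes on its result, and erases the original region nodes. Bails
    out (leaving the region untouched) if arg mutation or input/output
    aliasing is detected, since neither is supported in subgraphs today.
    """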
    sub_args = []
    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
    for usages in external_node_usages:
        usage = next(iter(usages))
        node_ind, usage_ind = usage
        node = region[node_ind]
        flattened_args_kwargs = _get_flat_args(node, {})
        for user_ind, node_usage_ind in usages:
            user = region[user_ind]
            if user in node_to_mutated_arg_positions:
                if node_usage_ind in node_to_mutated_arg_positions[user]:
                    log.debug(
                        "NYI: Failed to substitute region %s due to mutation", region
                    )
                    return
        if usage in node_usage_to_tuple_elems:
            tuple_elems = [region[i] for i in node_usage_to_tuple_elems[usage]]
            flattened_getitem_nodes.update(tuple_elems)
            sub_args.extend(tuple_elems)
        else:
            sub_args.append(flattened_args_kwargs[usage_ind])

    # Input/Output aliasing not supported in HOPs today
    # Note: we should use the nodes in the original graph (the region here)
    # because we use the original traced example values for this check
    if _has_aliasing(
        region, sub_args, inds_with_external_users, flattened_getitem_nodes
    ):
        return

    invoke_args = (get_subgraph_node, subgraph_name, *sub_args)

    invoke_subgraph_node = graph.create_node(
        "call_function",
        torch.ops.higher_order.invoke_subgraph,
        invoke_args,  # type: ignore[arg-type]
        {},
    )

    ind = 0
    flattened_output_nodes: OrderedSet[Node] = OrderedSet()
    for external_user_ind in inds_with_external_users:
        node = region[external_user_ind]
        if _is_tuple_node(node):
            tuple_spec = ind_to_tuple_spec[external_user_ind]
            flattened_output_nodes.update(
                _replace_tuple_outputs(
                    node, ind, tuple_spec, invoke_subgraph_node, graph
                )
            )
            ind += len(tuple_spec)
        else:
            subgraph_output = graph.create_node(
                "call_function", operator.getitem, (invoke_subgraph_node, ind), {}
            )
            node.replace_all_uses_with(subgraph_output, propagate_meta=True)
            ind += 1

    # Erase in reverse topological order
    for node in reversed(region):
        if node in flattened_getitem_nodes:
            # Don't erase these, since they will still be used
            continue

        if node not in flattened_output_nodes:
            graph.erase_node(node)

        # Remove any nodes with additional deps.
        # This is safe; we've guaranteed that there is
        # no input mutation, so all additional deps
        # will be internal to the subgraph
        node_to_additional_deps.pop(node, None)
        for deps in node_to_additional_deps.values():
            try:
                deps.remove(node)
                deps.add(invoke_subgraph_node)
            except KeyError:
                pass

    if config.graph_deduplication_lint:
        print(_detect_cycles(graph, node_to_additional_deps))
        _stable_topological_sort(graph, node_to_additional_deps)
        graph.lint()


def _get_external_inputs(
    region: Region,
) -> dict[Node, OrderedSet[UsageIndex]]:
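    """
    Map each node that is used by the region but defined outside of it to the
    set of UsageIndex positions where it appears, so placeholders can be
    created for it and each usage can later be checked for mutation.
    """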
    external_node_to_usages = defaultdict[Node, OrderedSet[UsageIndex]](OrderedSet)
    region_unique = set(region)
    for node_ind, node in enumerate(region):
        flattened_args_kwargs = _get_flat_args(node, {})
        for arg_ind, in_node in enumerate(flattened_args_kwargs):
            if isinstance(in_node, Node) and in_node not in region_unique:
                # in_node may occur in multiple nodes' flat_args;
                # track this so we can check if the arg is mutated.
                # Previously, we only needed to track one occurrence
                # to be able to map that node to a placeholder.
                external_node_to_usages[in_node].add((node_ind, arg_ind))

    return external_node_to_usages


def _get_all_output_indices(regions: list[Region]) -> list[int]:
    # Scan all regions to get the set of all possible output node indices in the region;
    # perhaps we can record this information during region creation for more efficiency?
    inds_with_external_users: set[int] = set()
    for region in regions:
        _get_inds_with_external_users(region, inds_with_external_users)

    return sorted(inds_with_external_users)


def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None:
    for ind, node in enumerate(region):
        for user in node.users:
            if user not in region:
                inds_unique.add(ind)


def _create_subgraph(
    region: Region,
    inds_with_external_users: list[int],
) -> tuple[
    torch.fx.Graph,
    list[OrderedSet[UsageIndex]],
    dict[UsageIndex, OrderedSet[int]],
    dict[int, dict[tuple[int, ...], int]],
]:
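    """
    Copy a representative region into a fresh subgraph.

    Creates placeholders for external inputs (flattening tuple inputs into one
    placeholder per used element), copies the region's nodes, and emits a tuple
    output covering every index in inds_with_external_users. Returns the
    subgraph along with the bookkeeping needed to replace each region:
    external_node_usages, node_usage_to_tuple_elems, and ind_to_tuple_spec.
    """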
    subgraph: torch.fx.Graph = torch.fx.Graph()
    external_input_to_usages = _get_external_inputs(region)
    external_node_usages = list[OrderedSet[UsageIndex]]()
    region_to_subgraph_node = {}
    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]] = {}

    for node, usage_indices in external_input_to_usages.items():
        # Tuple nodes can't be passed directly as subgraph inputs today
        if _is_tuple_node(node):
            # If a node is a tuple, we will possibly create multiple placeholders for it
            # and track which nodes we won't copy into the subgraph because they are flattened away.
            # Later, when replacing each region with this subgraph, we will create getitem nodes
            # externally which will perform the flattening on the outer nodes.
            flattened_node_indices = _get_flattened_node_indices(node, region)
            for ind in flattened_node_indices:
                placeholder = subgraph.placeholder(
                    f"subgraph_input_{node.name}_flattened_{ind}"
                )
                region_to_subgraph_node[region[ind]] = placeholder
                flattened_getitem_nodes.add(region[ind])
            node_usage_to_tuple_elems[next(iter(usage_indices))] = (
                flattened_node_indices
            )
        else:
            placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
            region_to_subgraph_node[node] = placeholder

        external_node_usages.append(usage_indices)

    def map_arg(node: Node) -> Node:
        if node in region_to_subgraph_node:
            return region_to_subgraph_node[node]
        else:
            return node

    def copy_to_subgraph(node: Node) -> Node:
        subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
        region_to_subgraph_node[node] = subgraph_node
        return subgraph_node

    output_list = []
    ind_to_tuple_spec = {}
    for ind, node in enumerate(region):
        if node not in flattened_getitem_nodes:
            subgraph_node = copy_to_subgraph(node)
            if ind in inds_with_external_users:
                # Flatten tuple outputs by generating a getitem node tree
                if _is_tuple_node(node):
                    getitem_nodes, ind_to_tuple_spec[ind] = _create_getitem_nodes(
                        node, subgraph_node, subgraph
                    )
                    output_list.extend(getitem_nodes)
                else:
                    output_list.append(subgraph_node)

    subgraph.output(tuple(output_list))

    return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec


def _stable_topological_sort_impl(
    graph: torch.fx.Graph,
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
    do_sort: bool = True,
) -> bool:
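    """
    Topologically sort graph.nodes in a stable order, honoring both real data
    dependencies and the extra edges in node_to_additional_deps. Returns True
    if every node could be placed (i.e. there is no cycle); with do_sort=False
    it only checks feasibility without moving any nodes.
    """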
    # Nodes are in exactly one of these four collections:

    # - Nodes in `pending` are waiting to be processed (in reverse order):
    pending = list(reversed(graph.nodes))

    # - Nodes in `ready` have been processed and are already in the correct
    #   order.
    ready = OrderedSet[Node]()

    # - `waiting` is a mapping from a dependency to nodes which depend on that
    #   dependency.
    waiting = defaultdict(list)

    # - `outputs` are always at the end of the graph.
    outputs = OrderedSet[Node]()

    # The cursor indicates the last processed node so we can add new nodes
    # after it.
    cursor = None
    while pending:
        node = pending.pop()

        if node.target == "output":
            outputs.add(node)
            assert not node.users, "output nodes should have no users"
            continue

        waiting_for = [
            x
            for x in _get_flat_args_unique(node, node_to_additional_deps)
            if x not in ready
        ]
        if waiting_for:
            # We have unprocessed input nodes. Might as well wait for the last
            # arg so an already sorted list will only recheck this node once.
            waiting[waiting_for[-1]].append(node)
        else:
            ready.add(node)
            if cursor and cursor.next is not node and do_sort:
                cursor.append(node)
            cursor = node
            # Mark the nodes that have been waiting for this node to finish as
            # ready to check again.
            pending.extend(reversed(waiting.pop(node, ())))

    ready.update(outputs)
    return not waiting and len(ready) == len(graph.nodes)


def _stable_topological_sort(
    graph: torch.fx.Graph,
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
) -> None:
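    """Sort graph.nodes in place, asserting that no cycle prevents a full sort."""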
    assert _stable_topological_sort_impl(graph, node_to_additional_deps)


def _has_cycle(
    graph: torch.fx.Graph,
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
) -> bool:
    return not _stable_topological_sort_impl(
        graph, node_to_additional_deps, do_sort=False
    )


def _populate_additional_deps(
    graph: torch.fx.Graph, node_to_mutated_arg_positions: dict[Node, OrderedSet[int]]
) -> dict[Node, OrderedSet[Node]]:
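    """
    Build the map of implicit dependencies (edges not present as args/kwargs)
    that ordering passes must respect: mutation ordering and global-state
    (autocast) ordering.
    """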
    node_to_additional_deps: dict[Node, OrderedSet[Node]] = defaultdict(OrderedSet)
    _add_mutation_dependencies(node_to_mutated_arg_positions, node_to_additional_deps)
    _add_global_state_dependencies(graph, node_to_additional_deps)
    return node_to_additional_deps


def _add_global_state_dependencies(
    graph: torch.fx.Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> None:
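    """
    Keep nodes that toggle global state (autocast enter/exit) in the same
    relative position: every node after such a node gains a dependency on it,
    and the global-state node itself depends on every node before it.
    """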
    import torch.amp

    all_nodes = list(graph.nodes)

    # These are targets of the nodes which need to stay in the same relative place in the graph
    global_state_targets = {torch.amp._enter_autocast, torch.amp._exit_autocast}
    all_nodes_dep_on: list[Node] = []

    def prev_cur_nodes(
        all_nodes: list[Node],
    ) -> Generator[tuple[list[Node], Node], None, None]:
        prev_nodes: list[Node] = []
        next_nodes = list(reversed(all_nodes))

        while next_nodes:
            cur_node = next_nodes.pop()
            yield prev_nodes, cur_node
            prev_nodes.append(cur_node)

    for prev_nodes, cur_node in prev_cur_nodes(all_nodes):
        args_unique = _get_flat_args_unique(cur_node, {})
        new_deps = [n for n in all_nodes_dep_on if n not in args_unique]

        if new_deps:
            additional_deps = node_to_additional_deps[cur_node]
            additional_deps.update(new_deps)

        if cur_node.target in global_state_targets:
            additional_deps = node_to_additional_deps[cur_node]
            additional_deps.update(n for n in prev_nodes if n not in args_unique)
            all_nodes_dep_on.append(cur_node)


def _add_mutation_dependencies(
    node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
) -> None:
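    """
    For each node that mutates an arg, order it after any earlier users of the
    mutated arg and before any later users, so that passes respecting these
    deps cannot reorder reads and writes of the same value.
    """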
    for node, indices in node_to_mutated_arg_positions.items():
        flat_args_kwargs = _get_flat_args(node, {})

        # For each mutated arg:
        # make usages which occur after node depend on node, to ensure
        # node is always ordered before them;
        # make node depend on usages which occur before node, to ensure
        # node is always ordered after them.
        for index in indices:
            mutated_arg = flat_args_kwargs[index]
            for user in mutated_arg.users:
                if user is node:
                    continue
                elif user < node:
                    node_to_additional_deps[node].add(user)
                elif user > node:
                    node_to_additional_deps[user].add(node)


def _has_aliasing(
    region: Region,
    inputs: list[Node],
    inds_with_external_users: list[int],
    flattened_getitem_nodes: OrderedSet[Node],
) -> bool:
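    """
    Check (via the traced example_value storages) whether any region inputs or
    outputs alias each other. Returns True on input-input, output-output, or
    input-output aliasing, since aliasing is not supported in subgraphs today.
    """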
    input_storages: dict[StorageWeakRef, Node] = dict()
    for node in inputs:
        if node in flattened_getitem_nodes:
            continue
        example_value = node.meta["example_value"]
        if isinstance(example_value, torch.Tensor):
            storage = StorageWeakRef(example_value._typed_storage())
            if storage in input_storages:
                # input-input aliasing
                log.debug(
                    "NYI: Failed to substitute region %s due to input-input aliasing detected at nodes %s, %s",
                    region,
                    input_storages[storage],
                    node,
                )
                return True
            input_storages[storage] = node

    output_storages: dict[StorageWeakRef, Node] = dict()
    for i in inds_with_external_users:
        out_node = region[i]
        if out_node in flattened_getitem_nodes:
            continue
        if out_node:
            example_value = out_node.meta["example_value"]
            assert not isinstance(example_value, list)
            if isinstance(example_value, torch.Tensor):
                storage = StorageWeakRef(example_value._typed_storage())
                if storage in output_storages:
                    # output-output aliasing
                    log.debug(
                        "NYI: Failed to substitute region %s due to output-output aliasing detected at nodes %s, %s",
                        region,
                        output_storages[storage],
                        out_node,
                    )
                    return True
                output_storages[storage] = out_node

    intersected_storages = input_storages.keys() & output_storages.keys()
    if len(intersected_storages) > 0:
        # input-output aliasing
        aliased = [
            (input_storages[s], output_storages[s]) for s in intersected_storages
        ]
        aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
        log.debug(
            "NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s",
            region,
            aliased,
        )
        return True
    return False


def _is_tuple_node(node: Node) -> bool:
    return isinstance(node.meta["example_value"], tuple)


def _get_children_getitems(node: Node) -> Generator[Node, None, None]:
    for user in node.users:
        if user.target == operator.getitem and isinstance(user.args[1], int):
            yield user


def _get_flattened_node_indices(node: Node, region: Region) -> OrderedSet[int]:
    """Returns an ordered set of indices, each representing a node in the region which will be flattened."""
    flattened_node_to_ind = {n: i for i, n in enumerate(region)}
    node_indices: OrderedSet[int] = OrderedSet()
    queue = deque(_get_children_getitems(node))
    while queue:
        cur_node = queue.popleft()
        if any(user in region for user in cur_node.users):
            node_indices.add(flattened_node_to_ind[cur_node])
        for child in _get_children_getitems(cur_node):
            queue.append(child)
    return node_indices


def _create_getitem_nodes(
    node: Node, subgraph_tuple_node: Node, subgraph: torch.fx.Graph
) -> tuple[list[Node], dict[tuple[int, ...], int]]:
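    """
    Build a getitem node tree in the subgraph that flattens a (possibly
    nested) tuple node into individual outputs. Returns the flat list of
    getitem nodes and a spec mapping each tuple path to its flat output index.
    """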
    tup = node.meta["example_value"]
    assert isinstance(tup, tuple), "_create_getitem_nodes expects a tuple"

    getitem_nodes: list[Node] = []
    queue = deque([(e, (i,), subgraph_tuple_node) for i, e in enumerate(tup)])
    path_to_output_index = {}

    while queue:
        cur_elem, path, parent = queue.popleft()

        with subgraph.inserting_after(parent):
            new_getitem_node = subgraph.create_node(
                "call_function", operator.getitem, (parent, path[-1]), {}
            )
            new_getitem_node.meta["example_value"] = cur_elem

        path_to_output_index[path] = len(getitem_nodes)
        getitem_nodes.append(new_getitem_node)

        if isinstance(cur_elem, tuple):
            queue.extend(
                [(e, path + (i,), new_getitem_node) for i, e in enumerate(cur_elem)]  # type: ignore[arg-type,misc]
            )

    return getitem_nodes, path_to_output_index  # type: ignore[return-value]


def _replace_tuple_outputs(
    node: Node,
    output_index: int,
    tuple_spec: dict[tuple[int, ...], int],
    invoke_subgraph_node: Node,
    graph: torch.fx.Graph,
) -> OrderedSet[Node]:
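    """
    Rewire external getitem users of a tuple-producing region node to getitem
    nodes on the invoke_subgraph result (using tuple_spec to find each flat
    output index), erasing the replaced nodes. Returns the set of erased nodes.
    """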
    assert _is_tuple_node(node), "_replace_tuple_outputs expects a tuple node"

    queue = deque((c, (c.args[1],)) for c in _get_children_getitems(node))
    erased_nodes: OrderedSet[Node] = OrderedSet()
    while queue:
        cur_node, path = queue.pop()

        for c in _get_children_getitems(cur_node):
            queue.append((c, path + (c.args[1],)))  # type: ignore[return-value, arg-type]

        with graph.inserting_after(invoke_subgraph_node):
            subgraph_output = graph.create_node(
                "call_function",
                operator.getitem,
                (invoke_subgraph_node, output_index + tuple_spec[path]),  # type: ignore[index]
                {},
            )
        cur_node.replace_all_uses_with(subgraph_output, propagate_meta=True)
        graph.erase_node(cur_node)
        erased_nodes.add(cur_node)

    graph.erase_node(node)
    erased_nodes.add(node)
    return erased_nodes