Compare commits

...

3 Commits

21aa086ecc [Dynamo][Hierarchical Compile] Flatten tuple inputs for regions
ghstack-source-id: e99eea21f6c2e02a15b0027ae1cedffbf4003231
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158812
2025-08-15 23:45:18 -07:00
c5f23c5cbf [Dynamo][Hierarchical Compile] Flatten tuple outputs in graph dedupe pass
ghstack-source-id: 9b509d723379eee9e38c7ad61ea0c5620ef0d844
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158811
2025-08-15 18:45:06 -07:00
4b146389a4 [Dynamo][Hierarchical Compile] Refactor for tuple flattening
ghstack-source-id: f168b556bb440ea93f5ed3001baa9b36acf929ff
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158810
2025-08-14 16:14:27 -07:00
2 changed files with 291 additions and 38 deletions
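Taken together, the three patches teach the dedupe pass to flatten tuples at region boundaries, so invoke_subgraph passes and returns only individual tensors. A minimal end-to-end sketch, adapted from the new tests below (the config flag name comes from the diff; enabling deduplication this way, rather than driving the pass manually as the test does, is an assumption for illustration):

import torch

def inner(x, y):
    x0, x1 = torch.split(x, 5)  # tuple-producing op inside a duplicated region
    return x0 + x1 + y

def fn(x, y):
    return inner(x, y).sum() + inner(x, y).sum()

# With deduplication enabled, the repeated region can now be outlined even though
# torch.split returns a tuple; its elements are flattened at the region boundary.
with torch._dynamo.config.patch("use_graph_deduplication", True):
    out = torch.compile(fn, fullgraph=True)(torch.rand(10, 10), torch.rand(5, 10))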

View File

@@ -4,13 +4,16 @@ import contextlib
import torch
import torch.fx
from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import (
    AotEagerAndRecordGraphs,
    extract_graph_and_tracker,
    normalize_gm,
)
from torch.compiler import allow_in_graph
from torch.utils._ordered_set import OrderedSet
@@ -1106,6 +1109,104 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
""",
        )

    def test_tuple_return(self):
        @allow_in_graph
        def tuple_return(x, y):
            return x, y

        def inner_fn(x, y):
            x0 = x + x + 1
            y0 = y + y + 1
            return tuple_return(x0, y0)

        def fn(x0, x1, x2, y0, y1, y2):
            x0 = inner_fn(x0, y0)
            x1 = inner_fn(x1, y1)
            x2 = inner_fn(x2, y2)
            return x0, x1, x2

        fn_opt = torch.compile(fn, fullgraph=True)
        inps = [torch.rand(10, 10) for _ in range(6)]
        result_compiled = fn_opt(*inps)
        result_eager = fn(*inps)
        self.assertEqual(result_compiled, result_eager)
    def test_tuple_inputs(self):
        with (
            torch._dynamo.config.patch("use_graph_deduplication", False),
            torch._dynamo.config.patch("track_nodes_for_deduplication", True),
        ):

            def inner(x, y):
                x0, x1 = torch.split(x, 5)
                return x0 + x1 + y

            def fn(x, y):
                o1 = inner(x, y)
                o2 = inner(x, y)
                o3 = inner(x, y)
                o4 = inner(x, y)
                return o1.sum() + o2.sum() + o3.sum() + o4.sum()

            graph, tracker = extract_graph_and_tracker(
                fn, torch.rand(10, 10), torch.rand(5, 10)
            )

            class MockOutputGraph:
                def __init__(self):
                    self.graph = graph
                    self.region_tracker = tracker
                    self.nn_modules = FakeRootModule({})

                def install_subgraph(self, name, subgraph):
                    return ""

            splits = [
                n
                for n in graph.nodes
                if n.op == "call_function" and n.target == torch.split
            ]
            for split in splits:
                tracker.node_to_duplicates.pop(split)

            apply_graph_deduplication(MockOutputGraph())
            self.assertExpectedInline(
                graph,
                """\
graph():
    %_unnamed : [num_users=4] = get_attr[target=]
    %l_x_ : torch.Tensor [num_users=4] = placeholder[target=L_x_]
    %l_y_ : torch.Tensor [num_users=4] = placeholder[target=L_y_]
    %split : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
    %x0 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
    %x1 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
    %split_1 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
    %x0_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_1, 0), kwargs = {})
    %x1_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_1, 1), kwargs = {})
    %split_2 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
    %x0_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split_2, 0), kwargs = {})
    %x1_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split_2, 1), kwargs = {})
    %split_3 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
    %x0_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split_3, 0), kwargs = {})
    %x1_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split_3, 1), kwargs = {})
    %invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0, %x1, %l_y_), kwargs = {})
    %getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
    %sum_1 : [num_users=1] = call_method[target=sum](args = (%getitem_8,), kwargs = {})
    %invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_1, %x1_1, %l_y_), kwargs = {})
    %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
    %sum_2 : [num_users=1] = call_method[target=sum](args = (%getitem_9,), kwargs = {})
    %add_8 : [num_users=1] = call_function[target=operator.add](args = (%sum_1, %sum_2), kwargs = {})
    %invoke_subgraph_2 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_2, %x1_2, %l_y_), kwargs = {})
    %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_2, 0), kwargs = {})
    %sum_3 : [num_users=1] = call_method[target=sum](args = (%getitem_10,), kwargs = {})
    %add_9 : [num_users=1] = call_function[target=operator.add](args = (%add_8, %sum_3), kwargs = {})
    %invoke_subgraph_3 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_3, %x1_3, %l_y_), kwargs = {})
    %getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_3, 0), kwargs = {})
    %sum_4 : [num_users=1] = call_method[target=sum](args = (%getitem_11,), kwargs = {})
    %add_10 : [num_users=1] = call_function[target=operator.add](args = (%add_9, %sum_4), kwargs = {})
    return (add_10,)""",
            )
    def test_param_transfer_to_submodule(self):
        def inner_fn(x, y):
            return x + y + y + x
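The two new tests divide the coverage: test_tuple_return checks numerical parity when a region returns a tuple, and test_tuple_inputs pins the deduped graph shape. In the expected output above, each duplicated region collapses to a single call of the form

    invoke_subgraph(%_unnamed, <subgraph name>, %x0, %x1, %l_y_)

where %x0/%x1 are the flattened getitem elements of the torch.split result and the split itself stays outside the subgraph; the empty get_attr target and subgraph name come from MockOutputGraph.install_subgraph returning "".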

View File

@@ -9,7 +9,7 @@ structures across different parts of the network.
import logging
import operator
from collections import defaultdict
from collections import defaultdict, deque
from collections.abc import Generator, Iterable
from typing import Optional
@@ -80,6 +80,8 @@ when they are created in output_graph.
        (
            subgraph,
            external_node_usages,
            node_usage_to_tuple_elems,
            ind_to_tuple_spec,
        ) = _create_subgraph(region, inds_with_external_users)
        # Ignore regions with no args for now, could they possibly be evaluated at compile time?
@@ -100,6 +102,8 @@ when they are created in output_graph.
            region,
            get_subgraph_node,
            external_node_usages,
            node_usage_to_tuple_elems,
            ind_to_tuple_spec,
            inds_with_external_users,
            subgraph_name,
            node_to_additional_deps,
@@ -122,14 +126,18 @@ def _replace_region_with_subgraph(
    region: Region,
    get_subgraph_node: Node,
    external_node_usages: Iterable[OrderedSet[UsageIndex]],
    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]],
    ind_to_tuple_spec: dict[int, dict[tuple[int, ...], int]],
    inds_with_external_users: list[int],
    subgraph_name: str,
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
    node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
) -> None:
    sub_args = []
    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
    for usages in external_node_usages:
        node_ind, usage_ind = next(iter(usages))
        usage = next(iter(usages))
        node_ind, usage_ind = usage
        node = region[node_ind]
        flattened_args_kwargs = _get_flat_args(node, {})
        for user_ind, node_usage_ind in usages:
@@ -140,12 +148,19 @@ def _replace_region_with_subgraph(
                    "NYI: Failed to substitute region %s due to mutation", region
                )
                return
        sub_args.append(flattened_args_kwargs[usage_ind])
        if usage in node_usage_to_tuple_elems:
            tuple_elems = [region[i] for i in node_usage_to_tuple_elems[usage]]
            flattened_getitem_nodes.update(tuple_elems)
            sub_args.extend(tuple_elems)
        else:
            sub_args.append(flattened_args_kwargs[usage_ind])

    # Input/Output aliasing not supported in HOPs today
    # Note: we should use the nodes in the original graph (the region here)
    # because we use the original traced example values for this check
    if _has_aliasing(region, sub_args, inds_with_external_users):
    if _has_aliasing(
        region, sub_args, inds_with_external_users, flattened_getitem_nodes
    ):
        return

    invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
@@ -156,16 +171,35 @@ def _replace_region_with_subgraph(
        invoke_args,  # type: ignore[arg-type]
        {},
    )

    for ind, external_user_ind in enumerate(inds_with_external_users):
    ind = 0
    flattened_output_nodes: OrderedSet[Node] = OrderedSet()
    for external_user_ind in inds_with_external_users:
        node = region[external_user_ind]
        subgraph_output = graph.create_node(
            "call_function", operator.getitem, (invoke_subgraph_node, ind), {}
        )
        node.replace_all_uses_with(subgraph_output, propagate_meta=True)
        if _is_tuple_node(node):
            tuple_spec = ind_to_tuple_spec[external_user_ind]
            flattened_output_nodes.update(
                _replace_tuple_outputs(
                    node, ind, tuple_spec, invoke_subgraph_node, graph
                )
            )
            ind += len(tuple_spec)
        else:
            subgraph_output = graph.create_node(
                "call_function", operator.getitem, (invoke_subgraph_node, ind), {}
            )
            node.replace_all_uses_with(subgraph_output, propagate_meta=True)
            ind += 1

    # Erase in reverse topological order
    for node in reversed(region):
        graph.erase_node(node)
        if node in flattened_getitem_nodes:
            # Don't erase these, since they will still be used
            continue
        if node not in flattened_output_nodes:
            graph.erase_node(node)

    # Remove any nodes with additional deps
    # This is safe; we've guaranteed that there is
    # no input mutation, so all additional deps
@@ -220,15 +254,43 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None
            inds_unique.add(ind)

def _copy_nodes_and_remap_inputs(
    subgraph: torch.fx.Graph, region: Region
) -> list[OrderedSet[UsageIndex]]:
def _create_subgraph(
    region: Region,
    inds_with_external_users: list[int],
) -> tuple[
    torch.fx.Graph,
    list[OrderedSet[UsageIndex]],
    dict[UsageIndex, OrderedSet[int]],
    dict[int, dict[tuple[int, ...], int]],
]:
    subgraph: torch.fx.Graph = torch.fx.Graph()
    external_input_to_usages = _get_external_inputs(region)
    external_node_usages = list[OrderedSet[UsageIndex]]()
    region_to_subgraph_node = {}
    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]] = {}
    for node, usage_indices in external_input_to_usages.items():
        placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
        region_to_subgraph_node[node] = placeholder
        # We don't handle tuples as inputs today
        if _is_tuple_node(node):
            # If a node is a tuple, we may create multiple placeholders for it,
            # and we track which nodes we won't copy into the subgraph because
            # they are flattened away. Later, when replacing each region with
            # this subgraph, we will create getitem nodes externally which
            # perform the flattening on the outer nodes.
            flattened_node_indices = _get_flattened_node_indices(node, region)
            for ind in flattened_node_indices:
                placeholder = subgraph.placeholder(
                    f"subgraph_input_{node.name}_flattened_{ind}"
                )
                region_to_subgraph_node[region[ind]] = placeholder
                flattened_getitem_nodes.add(region[ind])
            node_usage_to_tuple_elems[next(iter(usage_indices))] = (
                flattened_node_indices
            )
        else:
            placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
            region_to_subgraph_node[node] = placeholder
        external_node_usages.append(usage_indices)
@@ -237,29 +299,29 @@ def _copy_nodes_and_remap_inputs(
        else:
            return node

    for node in region:
    def copy_to_subgraph(node: Node) -> Node:
        subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
        region_to_subgraph_node[node] = subgraph_node
        return subgraph_node

    return external_node_usages
    output_list = []
    ind_to_tuple_spec = {}
    for ind, node in enumerate(region):
        if node not in flattened_getitem_nodes:
            subgraph_node = copy_to_subgraph(node)
            if ind in inds_with_external_users:
                # flatten tuple outputs by generating a getitem node tree
                if _is_tuple_node(node):
                    getitem_nodes, ind_to_tuple_spec[ind] = _create_getitem_nodes(
                        node, subgraph_node, subgraph
                    )
                    output_list.extend(getitem_nodes)
                else:
                    output_list.append(subgraph_node)

    subgraph.output(tuple(output_list))
def _create_subgraph_outputs(
    subgraph: torch.fx.Graph, inds_to_output: list[int]
) -> None:
    node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")]
    out_tup = tuple(node_list[ind] for ind in inds_to_output)
    subgraph.output(out_tup)

def _create_subgraph(
    region: Region,
    inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
    subgraph: torch.fx.Graph = torch.fx.Graph()
    external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region)
    _create_subgraph_outputs(subgraph, inds_with_external_users)
    return subgraph, external_node_usages
    return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec
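To make the two new return values concrete, a rough sketch for the region exercised by test_tuple_inputs (region contents and indices are illustrative, not taken from a real trace):

# Region, roughly: [x0, x1, add, add_1], where x0 = split[0] and x1 = split[1]
# and `split` is a tuple-valued input produced outside the region.
# _create_subgraph would then return approximately:
#   external_node_usages      -> usage sets for `split` and `l_y_`
#   node_usage_to_tuple_elems -> {usage_of_split: OrderedSet([0, 1])}  # region indices of x0, x1
#   ind_to_tuple_spec         -> {}  # this region has no tuple-valued outputs
# and the subgraph's placeholders become:
#   subgraph_input_split_flattened_0, subgraph_input_split_flattened_1, subgraph_input_l_y_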
def _stable_topological_sort(
@@ -384,11 +446,15 @@ def _add_mutation_dependencies(
def _has_aliasing(
    region: Region, inputs: list[Node], inds_with_external_users: list[int]
    region: Region,
    inputs: list[Node],
    inds_with_external_users: list[int],
    flattened_getitem_nodes: OrderedSet[Node],
) -> bool:
    input_storages: dict[StorageWeakRef, Node] = dict()
    for node in inputs:
        if node in flattened_getitem_nodes:
            continue
        example_value = node.meta["example_value"]
        if isinstance(example_value, torch.Tensor):
            storage = StorageWeakRef(example_value._typed_storage())
@@ -402,10 +468,11 @@ def _has_aliasing(
                )
                return True
            input_storages[storage] = node

    output_storages: dict[StorageWeakRef, Node] = dict()
    for i in inds_with_external_users:
        out_node = region[i]
        if out_node in flattened_getitem_nodes:
            continue
        if out_node:
            example_value = out_node.meta["example_value"]
            assert not isinstance(example_value, list)
@@ -421,7 +488,6 @@ def _has_aliasing(
                )
                return True
            output_storages[storage] = out_node

    intersected_storages = input_storages.keys() & output_storages.keys()
    if len(intersected_storages) > 0:
        # input-output aliasing
@@ -435,5 +501,91 @@ def _has_aliasing(
            aliased,
        )
        return True
    return False
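Flattened getitem nodes are skipped in both loops above because they stand in for tuple elements rather than acting as real region inputs or outputs. The storage comparison itself catches view-style aliasing; a small illustration in plain PyTorch (variable names hypothetical):

# x and its view share one storage, so a region taking x and producing y is rejected:
#   x = torch.rand(4)
#   y = x.view(2, 2)
#   StorageWeakRef(x._typed_storage()) == StorageWeakRef(y._typed_storage())  # True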
def _is_tuple_node(node: Node) -> bool:
    return isinstance(node.meta["example_value"], tuple)

def _get_children_getitems(node: Node) -> Generator[Node, None, None]:
    for user in node.users:
        if user.target == operator.getitem and isinstance(user.args[1], int):
            yield user

def _get_flattened_node_indices(node: Node, region: Region) -> OrderedSet[int]:
    """Returns an ordered set of indices, each representing a node in the region which will be flattened"""
    flattened_node_to_ind = {n: i for i, n in enumerate(region)}
    node_indices: OrderedSet[int] = OrderedSet()
    queue = deque(_get_children_getitems(node))
    while queue:
        cur_node = queue.popleft()
        if any(user in region for user in cur_node.users):
            node_indices.add(flattened_node_to_ind[cur_node])
        for child in _get_children_getitems(cur_node):
            queue.append(child)
    return node_indices
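Note the `any(user in region for user in cur_node.users)` filter above: only tuple elements that are actually consumed inside the region become flattened placeholders, while nested tuples are still reached because every getitem child is queued regardless of whether it was selected itself.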
def _create_getitem_nodes(
    node: Node, subgraph_tuple_node: Node, subgraph: torch.fx.Graph
) -> tuple[list[Node], dict[tuple[int, ...], int]]:
    tup = node.meta["example_value"]
    assert isinstance(tup, tuple), "_create_getitem_nodes expects a tuple"
    getitem_nodes: list[Node] = []
    queue = deque([(e, (i,), subgraph_tuple_node) for i, e in enumerate(tup)])
    path_to_output_index = {}
    while queue:
        cur_elem, path, parent = queue.popleft()
        with subgraph.inserting_after(parent):
            new_getitem_node = subgraph.create_node(
                "call_function", operator.getitem, (parent, path[-1]), {}
            )
            new_getitem_node.meta["example_value"] = cur_elem
        path_to_output_index[path] = len(getitem_nodes)
        getitem_nodes.append(new_getitem_node)
        if isinstance(cur_elem, tuple):
            queue.extend(
                [(e, path + (i,), new_getitem_node) for i, e in enumerate(cur_elem)]  # type: ignore[arg-type,misc]
            )
    return getitem_nodes, path_to_output_index  # type: ignore[return-value]
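The path spec can be modeled without any FX machinery. A self-contained sketch of the same breadth-first ordering (pure Python, illustrative only; the helper name is invented):

from collections import deque

def tuple_paths_to_flat_index(tup):
    # Mirrors _create_getitem_nodes: BFS over tuple elements, recording each
    # element's access path -> its flat position in the subgraph output tuple.
    spec = {}
    queue = deque((e, (i,)) for i, e in enumerate(tup))
    while queue:
        elem, path = queue.popleft()
        spec[path] = len(spec)
        if isinstance(elem, tuple):
            queue.extend((e, path + (i,)) for i, e in enumerate(elem))
    return spec

assert tuple_paths_to_flat_index(((1, 2), 3)) == {(0,): 0, (1,): 1, (0, 0): 2, (0, 1): 3}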
def _replace_tuple_outputs(
    node: Node,
    output_index: int,
    tuple_spec: dict[tuple[int, ...], int],
    invoke_subgraph_node: Node,
    graph: torch.fx.Graph,
) -> OrderedSet[Node]:
    assert _is_tuple_node(node), "_replace_tuple_outputs expects a tuple node"

    queue = deque((c, (c.args[1],)) for c in _get_children_getitems(node))
    erased_nodes: OrderedSet[Node] = OrderedSet()
    while queue:
        cur_node, path = queue.pop()
        for c in _get_children_getitems(cur_node):
            queue.append((c, path + (c.args[1],)))  # type: ignore[return-value, arg-type]
        with graph.inserting_after(invoke_subgraph_node):
            subgraph_output = graph.create_node(
                "call_function",
                operator.getitem,
                (invoke_subgraph_node, output_index + tuple_spec[path]),  # type: ignore[index]
                {},
            )
            cur_node.replace_all_uses_with(subgraph_output, propagate_meta=True)
        graph.erase_node(cur_node)
        erased_nodes.add(cur_node)

    graph.erase_node(node)
    erased_nodes.add(node)
    return erased_nodes
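End to end, the spec ties the two sides together; a worked example with hypothetical numbers: given tuple_spec {(0,): 0, (1,): 1, (0, 0): 2, (0, 1): 3} and output_index 5, an external consumer of node[0][1] is rewired to getitem(invoke_subgraph, 5 + 3), and the dead getitem chain plus the tuple node itself are erased here and returned, so _replace_region_with_subgraph knows not to erase them a second time.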