[Dynamo][Hierarchical Compile] Flatten tuple inputs for regions (#158812)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158812
Approved by: https://github.com/anijain2305
ghstack dependencies: #158810, #158811
This commit is contained in:
Michael Lazos
2025-08-15 23:45:18 -07:00
committed by PyTorch MergeBot
parent 664005662a
commit 450517f346
2 changed files with 140 additions and 12 deletions

View File

@ -4,7 +4,9 @@ import contextlib
import torch
import torch.fx
from torch._dynamo.graph_deduplication import apply_graph_deduplication
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.output_graph import FakeRootModule
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
@ -1129,6 +1131,82 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
result_eager = fn(*inps)
self.assertEqual(result_compiled, result_eager)
def test_tuple_inputs(self):
with (
torch._dynamo.config.patch("use_graph_deduplication", False),
torch._dynamo.config.patch("track_nodes_for_deduplication", True),
):
def inner(x, y):
x0, x1 = torch.split(x, 5)
return x0 + x1 + y
def fn(x, y):
o1 = inner(x, y)
o2 = inner(x, y)
o3 = inner(x, y)
o4 = inner(x, y)
return o1.sum() + o2.sum() + o3.sum() + o4.sum()
graph, tracker = extract_graph_and_tracker(
fn, torch.rand(10, 10), torch.rand(5, 10)
)
class MockOutputGraph:
def __init__(self):
self.graph = graph
self.region_tracker = tracker
self.nn_modules = FakeRootModule({})
def install_subgraph(self, name, subgraph):
return ""
splits = [
n
for n in graph.nodes
if n.op == "call_function" and n.target == torch.split
]
for split in splits:
tracker.node_to_duplicates.pop(split)
apply_graph_deduplication(MockOutputGraph())
self.assertExpectedInline(
graph,
"""\
graph():
%_unnamed : [num_users=4] = get_attr[target=]
%l_x_ : torch.Tensor [num_users=4] = placeholder[target=L_x_]
%l_y_ : torch.Tensor [num_users=4] = placeholder[target=L_y_]
%split : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
%x0 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
%x1 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
%split_1 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
%x0_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_1, 0), kwargs = {})
%x1_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_1, 1), kwargs = {})
%split_2 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
%x0_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split_2, 0), kwargs = {})
%x1_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split_2, 1), kwargs = {})
%split_3 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
%x0_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split_3, 0), kwargs = {})
%x1_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split_3, 1), kwargs = {})
%invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0, %x1, %l_y_), kwargs = {})
%getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
%sum_1 : [num_users=1] = call_method[target=sum](args = (%getitem_8,), kwargs = {})
%invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_1, %x1_1, %l_y_), kwargs = {})
%getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
%sum_2 : [num_users=1] = call_method[target=sum](args = (%getitem_9,), kwargs = {})
%add_8 : [num_users=1] = call_function[target=operator.add](args = (%sum_1, %sum_2), kwargs = {})
%invoke_subgraph_2 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_2, %x1_2, %l_y_), kwargs = {})
%getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_2, 0), kwargs = {})
%sum_3 : [num_users=1] = call_method[target=sum](args = (%getitem_10,), kwargs = {})
%add_9 : [num_users=1] = call_function[target=operator.add](args = (%add_8, %sum_3), kwargs = {})
%invoke_subgraph_3 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_3, %x1_3, %l_y_), kwargs = {})
%getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_3, 0), kwargs = {})
%sum_4 : [num_users=1] = call_method[target=sum](args = (%getitem_11,), kwargs = {})
%add_10 : [num_users=1] = call_function[target=operator.add](args = (%add_9, %sum_4), kwargs = {})
return (add_10,)""",
)
def test_param_transfer_to_submodule(self):
def inner_fn(x, y):
return x + y + y + x

View File

@ -80,6 +80,7 @@ when they are created in output_graph.
(
subgraph,
external_node_usages,
node_usage_to_tuple_elems,
ind_to_tuple_spec,
) = _create_subgraph(region, inds_with_external_users)
@ -101,6 +102,7 @@ when they are created in output_graph.
region,
get_subgraph_node,
external_node_usages,
node_usage_to_tuple_elems,
ind_to_tuple_spec,
inds_with_external_users,
subgraph_name,
@ -124,6 +126,7 @@ def _replace_region_with_subgraph(
region: Region,
get_subgraph_node: Node,
external_node_usages: Iterable[OrderedSet[UsageIndex]],
node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]],
ind_to_tuple_spec: dict[int, dict[tuple[int, ...], int]],
inds_with_external_users: list[int],
subgraph_name: str,
@ -131,6 +134,7 @@ def _replace_region_with_subgraph(
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
) -> None:
sub_args = []
flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
for usages in external_node_usages:
usage = next(iter(usages))
node_ind, usage_ind = usage
@ -144,13 +148,19 @@ def _replace_region_with_subgraph(
"NYI: Failed to substitute region %s due to mutation", region
)
return
sub_args.append(flattened_args_kwargs[usage_ind])
if usage in node_usage_to_tuple_elems:
tuple_elems = [region[i] for i in node_usage_to_tuple_elems[usage]]
flattened_getitem_nodes.update(tuple_elems)
sub_args.extend(tuple_elems)
else:
sub_args.append(flattened_args_kwargs[usage_ind])
# Input/Output aliasing not supported in HOPs today
# Note: we should use the nodes in the original graph (the region here)
# because we use the original traced example values for this check
if _has_aliasing(region, sub_args, inds_with_external_users):
if _has_aliasing(
region, sub_args, inds_with_external_users, flattened_getitem_nodes
):
return
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
@ -183,6 +193,10 @@ def _replace_region_with_subgraph(
# Erase in reverse topological order
for node in reversed(region):
if node in flattened_getitem_nodes:
# Don't erase these, since they will still be used
continue
if node not in flattened_output_nodes:
graph.erase_node(node)
@ -244,17 +258,39 @@ def _create_subgraph(
region: Region,
inds_with_external_users: list[int],
) -> tuple[
torch.fx.Graph, list[OrderedSet[UsageIndex]], dict[int, dict[tuple[int, ...], int]]
torch.fx.Graph,
list[OrderedSet[UsageIndex]],
dict[UsageIndex, OrderedSet[int]],
dict[int, dict[tuple[int, ...], int]],
]:
subgraph: torch.fx.Graph = torch.fx.Graph()
external_input_to_usages = _get_external_inputs(region)
external_node_usages = list[OrderedSet[UsageIndex]]()
region_to_subgraph_node = {}
flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]] = {}
for node, usage_indices in external_input_to_usages.items():
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
region_to_subgraph_node[node] = placeholder
# We don't handle tuples as inputs today
if _is_tuple_node(node):
# If a node is a tuple we will possibly create multiple placeholders for them
# and track which nodes we won't copy into the subgraph because they are flattened away
# Later, when replacing each region with this subgraph, we will create a getitem node
# externally which will perform the flattening on the outer nodes.
flattened_node_indices = _get_flattened_node_indices(node, region)
for ind in flattened_node_indices:
placeholder = subgraph.placeholder(
f"supgraph_input_{node.name}_flattened_{ind}"
)
region_to_subgraph_node[region[ind]] = placeholder
flattened_getitem_nodes.add(region[ind])
node_usage_to_tuple_elems[next(iter(usage_indices))] = (
flattened_node_indices
)
else:
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
region_to_subgraph_node[node] = placeholder
external_node_usages.append(usage_indices)
def map_arg(node: Node) -> Node:
@ -285,7 +321,7 @@ def _create_subgraph(
subgraph.output(tuple(output_list))
return subgraph, external_node_usages, ind_to_tuple_spec
return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec
def _stable_topological_sort(
@ -413,10 +449,12 @@ def _has_aliasing(
region: Region,
inputs: list[Node],
inds_with_external_users: list[int],
flattened_getitem_nodes: OrderedSet[Node],
) -> bool:
input_storages: dict[StorageWeakRef, Node] = dict()
for node in inputs:
if node in flattened_getitem_nodes:
continue
example_value = node.meta["example_value"]
if isinstance(example_value, torch.Tensor):
storage = StorageWeakRef(example_value._typed_storage())
@ -430,11 +468,11 @@ def _has_aliasing(
)
return True
input_storages[storage] = node
output_storages: dict[StorageWeakRef, Node] = dict()
for i in inds_with_external_users:
out_node = region[i]
if out_node in flattened_getitem_nodes:
continue
if out_node:
example_value = out_node.meta["example_value"]
assert not isinstance(example_value, list)
@ -450,7 +488,6 @@ def _has_aliasing(
)
return True
output_storages[storage] = out_node
intersected_storages = input_storages.keys() & output_storages.keys()
if len(intersected_storages) > 0:
# input-output aliasing
@ -464,7 +501,6 @@ def _has_aliasing(
aliased,
)
return True
return False
@ -478,6 +514,20 @@ def _get_children_getitems(node: Node) -> Generator[Node, None, None]:
yield user
def _get_flattened_node_indices(node: Node, region: Region) -> OrderedSet[int]:
"""Returns an ordered set of indices, each representing a node in the region which will be flattened"""
flattened_node_to_ind = {n: i for i, n in enumerate(region)}
node_indices: OrderedSet[int] = OrderedSet()
queue = deque(_get_children_getitems(node))
while queue:
cur_node = queue.popleft()
if any(user in region for user in cur_node.users):
node_indices.add(flattened_node_to_ind[cur_node])
for child in _get_children_getitems(cur_node):
queue.append(child)
return node_indices
def _create_getitem_nodes(
node: Node, subgraph_tuple_node: Node, subgraph: torch.fx.Graph
) -> tuple[list[Node], dict[tuple[int, ...], int]]: