Compare commits

...

3 Commits

Author SHA1 Message Date
21aa086ecc [Dynamo][Hierarchical Compile] Flatten tuple inputs for regions
ghstack-source-id: e99eea21f6c2e02a15b0027ae1cedffbf4003231
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158812
2025-08-15 23:45:18 -07:00
c5f23c5cbf [Dynamo][Hierarchical Compile] Flatten tuple outputs in graph dedupe pass
ghstack-source-id: 9b509d723379eee9e38c7ad61ea0c5620ef0d844
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158811
2025-08-15 18:45:06 -07:00
4b146389a4 [Dynamo][Hierarchical Compile] Refactor for tuple flattening
ghstack-source-id: f168b556bb440ea93f5ed3001baa9b36acf929ff
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158810
2025-08-14 16:14:27 -07:00
2 changed files with 291 additions and 38 deletions
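Taken together, the three commits teach Dynamo's graph deduplication (hierarchical compile) pass to handle tuples at region boundaries: tuple outputs of a region are flattened into a getitem tree inside the shared subgraph, and tuple inputs are flattened into per-element placeholders, since the invoke_subgraph HOP exchanges flat lists of values. A minimal sketch of the kind of code this affects (shapes and the config flag mirror the tests below; enabling deduplication this way is illustrative, not part of this diff):

import torch

def inner(x, y):
    x0, x1 = torch.split(x, 5)  # torch.split produces a tuple at the region boundary
    return x0 + x1 + y

def fn(x, y):
    # Four structurally identical regions; deduplication can outline one
    # shared subgraph and invoke it four times via invoke_subgraph.
    o1, o2, o3, o4 = inner(x, y), inner(x, y), inner(x, y), inner(x, y)
    return o1.sum() + o2.sum() + o3.sum() + o4.sum()

with torch._dynamo.config.patch("use_graph_deduplication", True):
    compiled = torch.compile(fn, fullgraph=True)
    out = compiled(torch.rand(10, 10), torch.rand(5, 10))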

View File

@@ -4,13 +4,16 @@ import contextlib
 import torch
 import torch.fx
+from torch._dynamo.graph_deduplication import apply_graph_deduplication
 from torch._dynamo.graph_utils import _detect_cycles
+from torch._dynamo.output_graph import FakeRootModule
 from torch._dynamo.test_case import TestCase
 from torch._dynamo.testing import (
     AotEagerAndRecordGraphs,
     extract_graph_and_tracker,
     normalize_gm,
 )
+from torch.compiler import allow_in_graph
 from torch.utils._ordered_set import OrderedSet
@@ -1106,6 +1109,104 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
 """,
         )
 
+    def test_tuple_return(self):
+        @allow_in_graph
+        def tuple_return(x, y):
+            return x, y
+
+        def inner_fn(x, y):
+            x0 = x + x + 1
+            y0 = y + y + 1
+            return tuple_return(x0, y0)
+
+        def fn(x0, x1, x2, y0, y1, y2):
+            x0 = inner_fn(x0, y0)
+            x1 = inner_fn(x1, y1)
+            x2 = inner_fn(x2, y2)
+            return x0, x1, x2
+
+        fn_opt = torch.compile(fn, fullgraph=True)
+        inps = [torch.rand(10, 10) for _ in range(6)]
+        result_compiled = fn_opt(*inps)
+        result_eager = fn(*inps)
+        self.assertEqual(result_compiled, result_eager)
+
+    def test_tuple_inputs(self):
+        with (
+            torch._dynamo.config.patch("use_graph_deduplication", False),
+            torch._dynamo.config.patch("track_nodes_for_deduplication", True),
+        ):
+
+            def inner(x, y):
+                x0, x1 = torch.split(x, 5)
+                return x0 + x1 + y
+
+            def fn(x, y):
+                o1 = inner(x, y)
+                o2 = inner(x, y)
+                o3 = inner(x, y)
+                o4 = inner(x, y)
+                return o1.sum() + o2.sum() + o3.sum() + o4.sum()
+
+            graph, tracker = extract_graph_and_tracker(
+                fn, torch.rand(10, 10), torch.rand(5, 10)
+            )
+
+            class MockOutputGraph:
+                def __init__(self):
+                    self.graph = graph
+                    self.region_tracker = tracker
+                    self.nn_modules = FakeRootModule({})
+
+                def install_subgraph(self, name, subgraph):
+                    return ""
+
+            splits = [
+                n
+                for n in graph.nodes
+                if n.op == "call_function" and n.target == torch.split
+            ]
+            for split in splits:
+                tracker.node_to_duplicates.pop(split)
+
+            apply_graph_deduplication(MockOutputGraph())
+
+            self.assertExpectedInline(
+                graph,
+                """\
+graph():
+    %_unnamed : [num_users=4] = get_attr[target=]
+    %l_x_ : torch.Tensor [num_users=4] = placeholder[target=L_x_]
+    %l_y_ : torch.Tensor [num_users=4] = placeholder[target=L_y_]
+    %split : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
+    %x0 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
+    %x1 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
+    %split_1 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
+    %x0_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_1, 0), kwargs = {})
+    %x1_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_1, 1), kwargs = {})
+    %split_2 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
+    %x0_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split_2, 0), kwargs = {})
+    %x1_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split_2, 1), kwargs = {})
+    %split_3 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
+    %x0_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split_3, 0), kwargs = {})
+    %x1_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split_3, 1), kwargs = {})
+    %invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0, %x1, %l_y_), kwargs = {})
+    %getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
+    %sum_1 : [num_users=1] = call_method[target=sum](args = (%getitem_8,), kwargs = {})
+    %invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_1, %x1_1, %l_y_), kwargs = {})
+    %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
+    %sum_2 : [num_users=1] = call_method[target=sum](args = (%getitem_9,), kwargs = {})
+    %add_8 : [num_users=1] = call_function[target=operator.add](args = (%sum_1, %sum_2), kwargs = {})
+    %invoke_subgraph_2 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_2, %x1_2, %l_y_), kwargs = {})
+    %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_2, 0), kwargs = {})
+    %sum_3 : [num_users=1] = call_method[target=sum](args = (%getitem_10,), kwargs = {})
+    %add_9 : [num_users=1] = call_function[target=operator.add](args = (%add_8, %sum_3), kwargs = {})
+    %invoke_subgraph_3 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_3, %x1_3, %l_y_), kwargs = {})
+    %getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_3, 0), kwargs = {})
+    %sum_4 : [num_users=1] = call_method[target=sum](args = (%getitem_11,), kwargs = {})
+    %add_10 : [num_users=1] = call_function[target=operator.add](args = (%add_9, %sum_4), kwargs = {})
+    return (add_10,)""",
+            )
+
     def test_param_transfer_to_submodule(self):
         def inner_fn(x, y):
             return x + y + y + x
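To inspect the deduplicated graph outside of an expecttest, a sketch along the following lines should work; AotEagerAndRecordGraphs and normalize_gm are the recording backend and normalizer this test file already imports, and fn/inps refer to the definitions in test_tuple_return above:

backend = AotEagerAndRecordGraphs()
fn_opt = torch.compile(fn, backend=backend, fullgraph=True)
fn_opt(*inps)
# The recorded Dynamo graph should contain one shared subgraph invoked
# once per duplicated region.
print(normalize_gm(backend.graphs[0].print_readable(print_output=False)))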

View File

@@ -9,7 +9,7 @@ structures across different parts of the network.
 import logging
 import operator
-from collections import defaultdict
+from collections import defaultdict, deque
 from collections.abc import Generator, Iterable
 from typing import Optional
@@ -80,6 +80,8 @@ when they are created in output_graph.
         (
             subgraph,
             external_node_usages,
+            node_usage_to_tuple_elems,
+            ind_to_tuple_spec,
         ) = _create_subgraph(region, inds_with_external_users)
 
         # Ignore regions with no args for now, could they possibly be evaluated at compile time?
@@ -100,6 +102,8 @@ when they are created in output_graph.
                 region,
                 get_subgraph_node,
                 external_node_usages,
+                node_usage_to_tuple_elems,
+                ind_to_tuple_spec,
                 inds_with_external_users,
                 subgraph_name,
                 node_to_additional_deps,
@@ -122,14 +126,18 @@ def _replace_region_with_subgraph(
     region: Region,
     get_subgraph_node: Node,
     external_node_usages: Iterable[OrderedSet[UsageIndex]],
+    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]],
+    ind_to_tuple_spec: dict[int, dict[tuple[int, ...], int]],
    inds_with_external_users: list[int],
     subgraph_name: str,
     node_to_additional_deps: dict[Node, OrderedSet[Node]],
     node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
 ) -> None:
     sub_args = []
+    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
     for usages in external_node_usages:
-        node_ind, usage_ind = next(iter(usages))
+        usage = next(iter(usages))
+        node_ind, usage_ind = usage
         node = region[node_ind]
         flattened_args_kwargs = _get_flat_args(node, {})
         for user_ind, node_usage_ind in usages:
@@ -140,12 +148,19 @@ def _replace_region_with_subgraph(
                     "NYI: Failed to substitute region %s due to mutation", region
                 )
                 return
-        sub_args.append(flattened_args_kwargs[usage_ind])
+        if usage in node_usage_to_tuple_elems:
+            tuple_elems = [region[i] for i in node_usage_to_tuple_elems[usage]]
+            flattened_getitem_nodes.update(tuple_elems)
+            sub_args.extend(tuple_elems)
+        else:
+            sub_args.append(flattened_args_kwargs[usage_ind])
 
     # Input/Output aliasing not supported in HOPs today
     # Note: we should use the nodes in the original graph (the region here)
     # because we use the original traced example values for this check
-    if _has_aliasing(region, sub_args, inds_with_external_users):
+    if _has_aliasing(
+        region, sub_args, inds_with_external_users, flattened_getitem_nodes
+    ):
         return
 
     invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
@@ -156,16 +171,35 @@ def _replace_region_with_subgraph(
         invoke_args,  # type: ignore[arg-type]
         {},
     )
-    for ind, external_user_ind in enumerate(inds_with_external_users):
+
+    ind = 0
+    flattened_output_nodes: OrderedSet[Node] = OrderedSet()
+    for external_user_ind in inds_with_external_users:
         node = region[external_user_ind]
-        subgraph_output = graph.create_node(
-            "call_function", operator.getitem, (invoke_subgraph_node, ind), {}
-        )
-        node.replace_all_uses_with(subgraph_output, propagate_meta=True)
+        if _is_tuple_node(node):
+            tuple_spec = ind_to_tuple_spec[external_user_ind]
+            flattened_output_nodes.update(
+                _replace_tuple_outputs(
+                    node, ind, tuple_spec, invoke_subgraph_node, graph
+                )
+            )
+            ind += len(tuple_spec)
+        else:
+            subgraph_output = graph.create_node(
+                "call_function", operator.getitem, (invoke_subgraph_node, ind), {}
+            )
+            node.replace_all_uses_with(subgraph_output, propagate_meta=True)
+            ind += 1
 
     # Erase in reverse topological order
     for node in reversed(region):
-        graph.erase_node(node)
+        if node in flattened_getitem_nodes:
+            # Don't erase these, since they will still be used
+            continue
+
+        if node not in flattened_output_nodes:
+            graph.erase_node(node)
 
     # Remove any nodes with additional deps
     # This is safe; we've guaranteed that there is
     # no input mutation, so all additional deps
@@ -220,15 +254,43 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None
             inds_unique.add(ind)
 
 
-def _copy_nodes_and_remap_inputs(
-    subgraph: torch.fx.Graph, region: Region
-) -> list[OrderedSet[UsageIndex]]:
+def _create_subgraph(
+    region: Region,
+    inds_with_external_users: list[int],
+) -> tuple[
+    torch.fx.Graph,
+    list[OrderedSet[UsageIndex]],
+    dict[UsageIndex, OrderedSet[int]],
+    dict[int, dict[tuple[int, ...], int]],
+]:
+    subgraph: torch.fx.Graph = torch.fx.Graph()
     external_input_to_usages = _get_external_inputs(region)
     external_node_usages = list[OrderedSet[UsageIndex]]()
     region_to_subgraph_node = {}
+    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
+    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]] = {}
     for node, usage_indices in external_input_to_usages.items():
-        placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
-        region_to_subgraph_node[node] = placeholder
+        # We don't handle tuples as direct subgraph inputs today, so flatten them
+        if _is_tuple_node(node):
+            # If a node is a tuple, we may create multiple placeholders for it
+            # and track which nodes we won't copy into the subgraph because they are flattened away.
+            # Later, when replacing each region with this subgraph, we will create getitem nodes
+            # externally which will perform the flattening on the outer nodes.
+            flattened_node_indices = _get_flattened_node_indices(node, region)
+            for ind in flattened_node_indices:
+                placeholder = subgraph.placeholder(
+                    f"subgraph_input_{node.name}_flattened_{ind}"
+                )
+                region_to_subgraph_node[region[ind]] = placeholder
+                flattened_getitem_nodes.add(region[ind])
+            node_usage_to_tuple_elems[next(iter(usage_indices))] = (
+                flattened_node_indices
+            )
+        else:
+            placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
+            region_to_subgraph_node[node] = placeholder
         external_node_usages.append(usage_indices)
 
     def map_arg(node: Node) -> Node:
@@ -237,29 +299,29 @@ def _copy_nodes_and_remap_inputs(
         else:
             return node
 
-    for node in region:
+    def copy_to_subgraph(node: Node) -> Node:
         subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
         region_to_subgraph_node[node] = subgraph_node
+        return subgraph_node
 
-    return external_node_usages
+    output_list = []
+    ind_to_tuple_spec = {}
+    for ind, node in enumerate(region):
+        if node not in flattened_getitem_nodes:
+            subgraph_node = copy_to_subgraph(node)
+            if ind in inds_with_external_users:
+                # flatten tuple outputs by generating a getitem node tree
+                if _is_tuple_node(node):
+                    getitem_nodes, ind_to_tuple_spec[ind] = _create_getitem_nodes(
+                        node, subgraph_node, subgraph
+                    )
+                    output_list.extend(getitem_nodes)
+                else:
+                    output_list.append(subgraph_node)
 
+    subgraph.output(tuple(output_list))
 
-def _create_subgraph_outputs(
-    subgraph: torch.fx.Graph, inds_to_output: list[int]
-) -> None:
-    node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")]
-    out_tup = tuple(node_list[ind] for ind in inds_to_output)
-    subgraph.output(out_tup)
-
-
-def _create_subgraph(
-    region: Region,
-    inds_with_external_users: list[int],
-) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
-    subgraph: torch.fx.Graph = torch.fx.Graph()
-    external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region)
-    _create_subgraph_outputs(subgraph, inds_with_external_users)
-    return subgraph, external_node_usages
+    return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec
 
 
 def _stable_topological_sort(
@@ -384,11 +446,15 @@ def _add_mutation_dependencies(
 
 def _has_aliasing(
-    region: Region, inputs: list[Node], inds_with_external_users: list[int]
+    region: Region,
+    inputs: list[Node],
+    inds_with_external_users: list[int],
+    flattened_getitem_nodes: OrderedSet[Node],
 ) -> bool:
     input_storages: dict[StorageWeakRef, Node] = dict()
     for node in inputs:
+        if node in flattened_getitem_nodes:
+            continue
         example_value = node.meta["example_value"]
         if isinstance(example_value, torch.Tensor):
             storage = StorageWeakRef(example_value._typed_storage())
@@ -402,10 +468,11 @@ def _has_aliasing(
                )
                return True
            input_storages[storage] = node
-
     output_storages: dict[StorageWeakRef, Node] = dict()
     for i in inds_with_external_users:
         out_node = region[i]
+        if out_node in flattened_getitem_nodes:
+            continue
         if out_node:
             example_value = out_node.meta["example_value"]
             assert not isinstance(example_value, list)
@@ -421,7 +488,6 @@ def _has_aliasing(
                )
                return True
            output_storages[storage] = out_node
-
     intersected_storages = input_storages.keys() & output_storages.keys()
     if len(intersected_storages) > 0:
         # input-output aliasing
@@ -435,5 +501,91 @@ def _has_aliasing(
             aliased,
         )
         return True
 
     return False
+
+
+def _is_tuple_node(node: Node) -> bool:
+    return isinstance(node.meta["example_value"], tuple)
+
+
+def _get_children_getitems(node: Node) -> Generator[Node, None, None]:
+    for user in node.users:
+        if user.target == operator.getitem and isinstance(user.args[1], int):
+            yield user
+
+
+def _get_flattened_node_indices(node: Node, region: Region) -> OrderedSet[int]:
+    """Returns an ordered set of indices, each representing a node in the region which will be flattened"""
+    flattened_node_to_ind = {n: i for i, n in enumerate(region)}
+    node_indices: OrderedSet[int] = OrderedSet()
+    queue = deque(_get_children_getitems(node))
+    while queue:
+        cur_node = queue.popleft()
+        if any(user in region for user in cur_node.users):
+            node_indices.add(flattened_node_to_ind[cur_node])
+        for child in _get_children_getitems(cur_node):
+            queue.append(child)
+    return node_indices
+
+
+def _create_getitem_nodes(
+    node: Node, subgraph_tuple_node: Node, subgraph: torch.fx.Graph
+) -> tuple[list[Node], dict[tuple[int, ...], int]]:
+    tup = node.meta["example_value"]
+    assert isinstance(tup, tuple), "_create_getitem_nodes expects a tuple"
+
+    getitem_nodes: list[Node] = []
+    queue = deque([(e, (i,), subgraph_tuple_node) for i, e in enumerate(tup)])
+    path_to_output_index = {}
+    while queue:
+        cur_elem, path, parent = queue.popleft()
+
+        with subgraph.inserting_after(parent):
+            new_getitem_node = subgraph.create_node(
+                "call_function", operator.getitem, (parent, path[-1]), {}
+            )
+            new_getitem_node.meta["example_value"] = cur_elem
+
+        path_to_output_index[path] = len(getitem_nodes)
+        getitem_nodes.append(new_getitem_node)
+
+        if isinstance(cur_elem, tuple):
+            queue.extend(
+                [(e, path + (i,), new_getitem_node) for i, e in enumerate(cur_elem)]  # type: ignore[arg-type,misc]
+            )
+
+    return getitem_nodes, path_to_output_index  # type: ignore[return-value]
+
+
+def _replace_tuple_outputs(
+    node: Node,
+    output_index: int,
+    tuple_spec: dict[tuple[int, ...], int],
+    invoke_subgraph_node: Node,
+    graph: torch.fx.Graph,
+) -> OrderedSet[Node]:
+    assert _is_tuple_node(node), "_replace_tuple_outputs expects a tuple node"
+
+    queue = deque((c, (c.args[1],)) for c in _get_children_getitems(node))
+    erased_nodes: OrderedSet[Node] = OrderedSet()
+    while queue:
+        cur_node, path = queue.pop()
+        for c in _get_children_getitems(cur_node):
+            queue.append((c, path + (c.args[1],)))  # type: ignore[return-value, arg-type]
+
+        with graph.inserting_after(invoke_subgraph_node):
+            subgraph_output = graph.create_node(
+                "call_function",
+                operator.getitem,
+                (invoke_subgraph_node, output_index + tuple_spec[path]),  # type: ignore[index]
+                {},
+            )
+        cur_node.replace_all_uses_with(subgraph_output, propagate_meta=True)
+        graph.erase_node(cur_node)
+        erased_nodes.add(cur_node)
+
+    graph.erase_node(node)
+    erased_nodes.add(node)
+
+    return erased_nodes
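For intuition, the path-to-index spec built by _create_getitem_nodes maps each tuple path to its slot in the flattened output list, in the same BFS order as the deque traversal above. A standalone illustration (not part of the diff):

from collections import deque

def tuple_paths(tup):
    # Mirrors _create_getitem_nodes' BFS: enumerate top-level elements first,
    # then descend into nested tuples, assigning flattened output indices.
    spec = {}
    queue = deque((e, (i,)) for i, e in enumerate(tup))
    while queue:
        elem, path = queue.popleft()
        spec[path] = len(spec)  # position in the flattened output list
        if isinstance(elem, tuple):
            queue.extend((e, path + (i,)) for i, e in enumerate(elem))
    return spec

# A nested example value (a, (b, c)) flattens to four outputs:
# {(0,): 0, (1,): 1, (1, 0): 2, (1, 1): 3}
print(tuple_paths(("a", ("b", "c"))))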