Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00

[Dynamo][Hierarchical Compile] Flatten tuple inputs for regions (#158812)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158812
Approved by: https://github.com/anijain2305
ghstack dependencies: #158810, #158811

Committed by PyTorch MergeBot
parent 664005662a
commit 450517f346
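For orientation, a minimal sketch (written for this mirror, not code from the commit) of the pattern the change targets: duplicated regions whose inputs are elements of a tuple produced by an op such as torch.split. With this change the tuple is flattened into per-element subgraph inputs, so the regions can be deduplicated into a single invoke_subgraph. The config flag name is taken from the test below.

import torch

torch._dynamo.config.use_graph_deduplication = True  # flag patched in the test below

def inner(x, y):
    # torch.split returns a tuple; its getitem elements feed each region
    x0, x1 = torch.split(x, 5)
    return x0 + x1 + y

def fn(x, y):
    # Four identical regions, all consuming elements of a split tuple
    o1 = inner(x, y)
    o2 = inner(x, y)
    o3 = inner(x, y)
    o4 = inner(x, y)
    return o1.sum() + o2.sum() + o3.sum() + o4.sum()

compiled = torch.compile(fn, backend="eager")
print(compiled(torch.rand(10, 10), torch.rand(5, 10)))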
@@ -4,7 +4,9 @@ import contextlib
 
 import torch
 import torch.fx
+from torch._dynamo.graph_deduplication import apply_graph_deduplication
 from torch._dynamo.graph_utils import _detect_cycles
+from torch._dynamo.output_graph import FakeRootModule
 from torch._dynamo.test_case import TestCase
 from torch._dynamo.testing import (
     AotEagerAndRecordGraphs,
@@ -1129,6 +1131,82 @@ def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
         result_eager = fn(*inps)
         self.assertEqual(result_compiled, result_eager)
 
+    def test_tuple_inputs(self):
+        with (
+            torch._dynamo.config.patch("use_graph_deduplication", False),
+            torch._dynamo.config.patch("track_nodes_for_deduplication", True),
+        ):
+
+            def inner(x, y):
+                x0, x1 = torch.split(x, 5)
+                return x0 + x1 + y
+
+            def fn(x, y):
+                o1 = inner(x, y)
+                o2 = inner(x, y)
+                o3 = inner(x, y)
+                o4 = inner(x, y)
+                return o1.sum() + o2.sum() + o3.sum() + o4.sum()
+
+            graph, tracker = extract_graph_and_tracker(
+                fn, torch.rand(10, 10), torch.rand(5, 10)
+            )
+
+            class MockOutputGraph:
+                def __init__(self):
+                    self.graph = graph
+                    self.region_tracker = tracker
+                    self.nn_modules = FakeRootModule({})
+
+                def install_subgraph(self, name, subgraph):
+                    return ""
+
+            splits = [
+                n
+                for n in graph.nodes
+                if n.op == "call_function" and n.target == torch.split
+            ]
+            for split in splits:
+                tracker.node_to_duplicates.pop(split)
+
+            apply_graph_deduplication(MockOutputGraph())
+            self.assertExpectedInline(
+                graph,
+                """\
+graph():
+    %_unnamed : [num_users=4] = get_attr[target=]
+    %l_x_ : torch.Tensor [num_users=4] = placeholder[target=L_x_]
+    %l_y_ : torch.Tensor [num_users=4] = placeholder[target=L_y_]
+    %split : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
+    %x0 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 0), kwargs = {})
+    %x1 : [num_users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
+    %split_1 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
+    %x0_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_1, 0), kwargs = {})
+    %x1_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_1, 1), kwargs = {})
+    %split_2 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
+    %x0_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split_2, 0), kwargs = {})
+    %x1_2 : [num_users=1] = call_function[target=operator.getitem](args = (%split_2, 1), kwargs = {})
+    %split_3 : [num_users=2] = call_function[target=torch.functional.split](args = (%l_x_, 5), kwargs = {})
+    %x0_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split_3, 0), kwargs = {})
+    %x1_3 : [num_users=1] = call_function[target=operator.getitem](args = (%split_3, 1), kwargs = {})
+    %invoke_subgraph : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0, %x1, %l_y_), kwargs = {})
+    %getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph, 0), kwargs = {})
+    %sum_1 : [num_users=1] = call_method[target=sum](args = (%getitem_8,), kwargs = {})
+    %invoke_subgraph_1 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_1, %x1_1, %l_y_), kwargs = {})
+    %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_1, 0), kwargs = {})
+    %sum_2 : [num_users=1] = call_method[target=sum](args = (%getitem_9,), kwargs = {})
+    %add_8 : [num_users=1] = call_function[target=operator.add](args = (%sum_1, %sum_2), kwargs = {})
+    %invoke_subgraph_2 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_2, %x1_2, %l_y_), kwargs = {})
+    %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_2, 0), kwargs = {})
+    %sum_3 : [num_users=1] = call_method[target=sum](args = (%getitem_10,), kwargs = {})
+    %add_9 : [num_users=1] = call_function[target=operator.add](args = (%add_8, %sum_3), kwargs = {})
+    %invoke_subgraph_3 : [num_users=1] = call_function[target=torch.ops.higher_order.invoke_subgraph](args = (%_unnamed, , %x0_3, %x1_3, %l_y_), kwargs = {})
+    %getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%invoke_subgraph_3, 0), kwargs = {})
+    %sum_4 : [num_users=1] = call_method[target=sum](args = (%getitem_11,), kwargs = {})
+    %add_10 : [num_users=1] = call_function[target=operator.add](args = (%add_9, %sum_4), kwargs = {})
+    return (add_10,)""",
+            )
+
     def test_param_transfer_to_submodule(self):
         def inner_fn(x, y):
             return x + y + y + x
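Reading the expected graph above: each invoke_subgraph call receives the flattened tuple elements (%x0, %x1) directly, never the split tuple node itself. A hypothetical checker for that invariant (mine, not part of the test):

import torch
import torch.fx

def tuple_inputs_are_flattened(graph: torch.fx.Graph) -> bool:
    # True if no invoke_subgraph call takes the tuple-producing split node
    # itself as an operand; only its flattened getitem projections.
    for node in graph.nodes:
        if node.target is torch.ops.higher_order.invoke_subgraph:
            for arg in node.args[2:]:  # skip subgraph attr and identifier
                if isinstance(arg, torch.fx.Node) and arg.target is torch.split:
                    return False
    return True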
@@ -80,6 +80,7 @@ when they are created in output_graph.
         (
             subgraph,
             external_node_usages,
+            node_usage_to_tuple_elems,
             ind_to_tuple_spec,
         ) = _create_subgraph(region, inds_with_external_users)
@@ -101,6 +102,7 @@ when they are created in output_graph.
             region,
             get_subgraph_node,
             external_node_usages,
+            node_usage_to_tuple_elems,
             ind_to_tuple_spec,
             inds_with_external_users,
             subgraph_name,
@@ -124,6 +126,7 @@ def _replace_region_with_subgraph(
     region: Region,
     get_subgraph_node: Node,
     external_node_usages: Iterable[OrderedSet[UsageIndex]],
+    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]],
     ind_to_tuple_spec: dict[int, dict[tuple[int, ...], int]],
     inds_with_external_users: list[int],
     subgraph_name: str,
@@ -131,6 +134,7 @@ def _replace_region_with_subgraph(
     node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
 ) -> None:
     sub_args = []
+    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
     for usages in external_node_usages:
         usage = next(iter(usages))
         node_ind, usage_ind = usage
@@ -144,13 +148,19 @@ def _replace_region_with_subgraph(
                         "NYI: Failed to substitute region %s due to mutation", region
                     )
                     return
 
-        sub_args.append(flattened_args_kwargs[usage_ind])
+        if usage in node_usage_to_tuple_elems:
+            tuple_elems = [region[i] for i in node_usage_to_tuple_elems[usage]]
+            flattened_getitem_nodes.update(tuple_elems)
+            sub_args.extend(tuple_elems)
+        else:
+            sub_args.append(flattened_args_kwargs[usage_ind])
 
     # Input/Output aliasing not supported in HOPs today
     # Note: we should use the nodes in the original graph (the region here)
     # because we use the original traced example values for this check
-    if _has_aliasing(region, sub_args, inds_with_external_users):
+    if _has_aliasing(
+        region, sub_args, inds_with_external_users, flattened_getitem_nodes
+    ):
         return
 
     invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
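Background for the flattened_args_kwargs[usage_ind] indexing above: the surrounding code (not shown in this hunk) flattens each node's (args, kwargs) with pytree, so usage_ind addresses a flat list of leaves. A small self-contained demonstration of that flattening:

import torch.utils._pytree as pytree

args_kwargs = ((1, 2), {"scale": 3})
leaves, spec = pytree.tree_flatten(args_kwargs)
assert leaves == [1, 2, 3]                           # flat, index-addressable
assert pytree.tree_unflatten(leaves, spec) == args_kwargs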
@@ -183,6 +193,10 @@ def _replace_region_with_subgraph(
 
     # Erase in reverse topological order
     for node in reversed(region):
+        if node in flattened_getitem_nodes:
+            # Don't erase these, since they will still be used
+            continue
+
         if node not in flattened_output_nodes:
             graph.erase_node(node)
@@ -244,17 +258,39 @@ def _create_subgraph(
     region: Region,
     inds_with_external_users: list[int],
 ) -> tuple[
-    torch.fx.Graph, list[OrderedSet[UsageIndex]], dict[int, dict[tuple[int, ...], int]]
+    torch.fx.Graph,
+    list[OrderedSet[UsageIndex]],
+    dict[UsageIndex, OrderedSet[int]],
+    dict[int, dict[tuple[int, ...], int]],
 ]:
     subgraph: torch.fx.Graph = torch.fx.Graph()
     external_input_to_usages = _get_external_inputs(region)
     external_node_usages = list[OrderedSet[UsageIndex]]()
     region_to_subgraph_node = {}
+    flattened_getitem_nodes: OrderedSet[Node] = OrderedSet()
+    node_usage_to_tuple_elems: dict[UsageIndex, OrderedSet[int]] = {}
 
     for node, usage_indices in external_input_to_usages.items():
-        placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
-        region_to_subgraph_node[node] = placeholder
-        # We don't handle tuples as inputs today
+        if _is_tuple_node(node):
+            # If a node is a tuple we will possibly create multiple placeholders for them
+            # and track which nodes we won't copy into the subgraph because they are flattened away
+            # Later, when replacing each region with this subgraph, we will create a getitem node
+            # externally which will perform the flattening on the outer nodes.
+            flattened_node_indices = _get_flattened_node_indices(node, region)
+            for ind in flattened_node_indices:
+                placeholder = subgraph.placeholder(
+                    f"supgraph_input_{node.name}_flattened_{ind}"
+                )
+                region_to_subgraph_node[region[ind]] = placeholder
+                flattened_getitem_nodes.add(region[ind])
+            node_usage_to_tuple_elems[next(iter(usage_indices))] = (
+                flattened_node_indices
+            )
+        else:
+            placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
+            region_to_subgraph_node[node] = placeholder
 
         external_node_usages.append(usage_indices)
 
     def map_arg(node: Node) -> Node:
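A toy standalone rendering (hypothetical names, plain FX without Dynamo) of the placeholder-per-tuple-element scheme above: instead of one tuple-typed input, the subgraph gets one placeholder per flattened element, mirroring the f"..._flattened_{ind}" naming.

import torch
import torch.fx

subgraph = torch.fx.Graph()
p0 = subgraph.placeholder("subgraph_input_split_flattened_0")
p1 = subgraph.placeholder("subgraph_input_split_flattened_1")
y = subgraph.placeholder("subgraph_input_l_y_")
tmp = subgraph.call_function(torch.add, (p0, p1))
out = subgraph.call_function(torch.add, (tmp, y))
subgraph.output((out,))
print(subgraph)  # three flat placeholders, no tuple-typed input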
@@ -285,7 +321,7 @@ def _create_subgraph(
 
     subgraph.output(tuple(output_list))
 
-    return subgraph, external_node_usages, ind_to_tuple_spec
+    return subgraph, external_node_usages, node_usage_to_tuple_elems, ind_to_tuple_spec
 
 
 def _stable_topological_sort(
@@ -413,10 +449,12 @@ def _has_aliasing(
     region: Region,
     inputs: list[Node],
     inds_with_external_users: list[int],
+    flattened_getitem_nodes: OrderedSet[Node],
 ) -> bool:
     input_storages: dict[StorageWeakRef, Node] = dict()
 
     for node in inputs:
+        if node in flattened_getitem_nodes:
+            continue
         example_value = node.meta["example_value"]
         if isinstance(example_value, torch.Tensor):
             storage = StorageWeakRef(example_value._typed_storage())
@@ -430,11 +468,11 @@ def _has_aliasing(
                 )
                 return True
             input_storages[storage] = node
 
     output_storages: dict[StorageWeakRef, Node] = dict()
     for i in inds_with_external_users:
         out_node = region[i]
+
+        if out_node in flattened_getitem_nodes:
+            continue
         if out_node:
             example_value = out_node.meta["example_value"]
             assert not isinstance(example_value, list)
@@ -450,7 +488,6 @@ def _has_aliasing(
                 )
                 return True
             output_storages[storage] = out_node
-
     intersected_storages = input_storages.keys() & output_storages.keys()
     if len(intersected_storages) > 0:
         # input-output aliasing
@@ -464,7 +501,6 @@ def _has_aliasing(
             aliased,
         )
         return True
-
     return False
 
 
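For intuition about the storage-based check above (an illustration, not code from this file): two example values alias exactly when their typed storages hash and compare equal under StorageWeakRef, which is what keying the input/output dicts on StorageWeakRef detects.

import torch
from torch.multiprocessing.reductions import StorageWeakRef

base = torch.rand(4)
view = base[:2]        # shares base's storage
other = torch.rand(4)  # separate storage

keys = {StorageWeakRef(t._typed_storage()) for t in (base, view, other)}
assert len(keys) == 2  # base and view collapse to one storage key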
@@ -478,6 +514,20 @@ def _get_children_getitems(node: Node) -> Generator[Node, None, None]:
             yield user
 
 
+def _get_flattened_node_indices(node: Node, region: Region) -> OrderedSet[int]:
+    """Returns an ordered set of indices, each representing a node in the region which will be flattened"""
+    flattened_node_to_ind = {n: i for i, n in enumerate(region)}
+    node_indices: OrderedSet[int] = OrderedSet()
+    queue = deque(_get_children_getitems(node))
+    while queue:
+        cur_node = queue.popleft()
+        if any(user in region for user in cur_node.users):
+            node_indices.add(flattened_node_to_ind[cur_node])
+        for child in _get_children_getitems(cur_node):
+            queue.append(child)
+    return node_indices
+
+
 def _create_getitem_nodes(
     node: Node, subgraph_tuple_node: Node, subgraph: torch.fx.Graph
 ) -> tuple[list[Node], dict[tuple[int, ...], int]]:
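To see what _get_flattened_node_indices computes, a hypothetical standalone run (assuming, as the indexing above suggests, that a region behaves like a list of nodes) on a split-like graph:

import operator
from collections import deque

import torch
import torch.fx

g = torch.fx.Graph()
x = g.placeholder("x")
split = g.call_function(torch.split, (x, 5))
x0 = g.call_function(operator.getitem, (split, 0))
x1 = g.call_function(operator.getitem, (split, 1))
out = g.call_function(operator.add, (x0, x1))
g.output(out)

region = [x0, x1, out]  # hypothetical region; split is its external tuple input
node_to_ind = {n: i for i, n in enumerate(region)}

# Mirror of the helper: BFS over split's getitem children, keeping those
# whose users live inside the region.
flattened = set()
queue = deque(u for u in split.users if u.target is operator.getitem)
while queue:
    cur = queue.popleft()
    if any(user in region for user in cur.users):
        flattened.add(node_to_ind[cur])
    queue.extend(u for u in cur.users if u.target is operator.getitem)

print(flattened)  # {0, 1}: x0 and x1 become flattened placeholders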