Compare commits

...

4 Commits

Author SHA1 Message Date
ac2ffd17e8 [Hierarchical Compile] Apply deduplication after output node creation
ghstack-source-id: 43241c7f09ad1e13584fbea47e36c083abe02750
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150306
2025-03-31 14:23:02 -07:00
b9381401a7 [Hierarchical Compile] Add cycle detection to graph region expansion
ghstack-source-id: 78aeb6aadb96068dd2cfa8d5569641cd87e9ad11
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150305
2025-03-31 14:23:02 -07:00
2d9b7170fe [Hierarchical Compile] Add cycle detection function for debug
Remove print

ghstack-source-id: 37d977625ff866a720386ea1862c6762b69f1e95
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150304
2025-03-31 14:23:02 -07:00
fbeaa3931a [Hierarchical Compile] Remove spammy debug log
ghstack-source-id: 1d49e23e6630890bfc112c105730e8300b8fb3d1
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150303
2025-03-31 14:23:01 -07:00
5 changed files with 280 additions and 119 deletions

View File

@@ -3,6 +3,7 @@
import torch
import torch.fx
from torch._dynamo.graph_deduplication import _flatten_args_kwargs
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
@@ -59,18 +60,15 @@ class GraphModule(torch.nn.Module):
subgraph_0 = self.subgraph_0
l_x_ = L_x_
l_y_ = L_y_
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, \
'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None
o1: "f32[10, 20]" = torch.sin(l_y_)
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, \
'subgraph_0', (l_x_, o1)); o1 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, o1)); o1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, \
'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
@@ -265,31 +263,27 @@ class GraphModule(torch.nn.Module):
y0: "f32[10, 20]" = torch.sin(l_y_)
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', \
(x0, y0)); invoke_subgraph_3 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
(l_x_, l_y_))
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
o1: "f32[]" = torch.sin(getitem); getitem = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
(l_x_, y0))
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, y0))
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', \
(x0, y0)); subgraph_1 = x0 = y0 = None
mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None
getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
(l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); invoke_subgraph_3 = None
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); subgraph_1 = x0 = y0 = None
getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
mul_3: "f32[10, 10]" = mul_2 * getitem_4; mul_2 = getitem_4 = None
add_13: "f32[10, 10]" = mul_3 + getitem_2; mul_3 = getitem_2 = None
return (add_13,)
@@ -328,27 +322,29 @@ class GraphModule(torch.nn.Module):
___forward_subgraph_0_post_graph = self.___forward_subgraph_0_post_graph
invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph = None
getitem_1: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None
getitem: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None
sin_1: "f32[]" = torch.ops.aten.sin.default(getitem_1)
sin_1: "f32[]" = torch.ops.aten.sin.default(getitem)
___forward_subgraph_0_post_graph_1 = self.___forward_subgraph_0_post_graph
invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_1, '___forward_subgraph_0_post_graph', (primals_1, sin)); ___forward_subgraph_0_post_graph_1 = None
getitem_2: "f32[]" = invoke_subgraph_10[0]; invoke_subgraph_10 = None
___forward_subgraph_1_post_graph = self.___forward_subgraph_1_post_graph
invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_post_graph, '___forward_subgraph_1_post_graph', (cos, sin)); ___forward_subgraph_1_post_graph = cos = sin = None
getitem_19: "f32[]" = invoke_subgraph_11[3]
getitem_18: "f32[10, 20]" = invoke_subgraph_11[2]
getitem_17: "f32[10, 10]" = invoke_subgraph_11[1]
getitem_3: "f32[10, 10]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None
___forward_subgraph_0_post_graph_2 = self.___forward_subgraph_0_post_graph
invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_2, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph_2 = None
getitem_4: "f32[]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None
getitem_1: "f32[]" = invoke_subgraph_10[0]; invoke_subgraph_10 = None
mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_2); sin_1 = None
mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_3); mul = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_4); mul_1 = getitem_4 = None
return (add, primals_1, primals_2, getitem_1, getitem_2, getitem_19, getitem_18, getitem_17, getitem_3)
mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_1); sin_1 = None
___forward_subgraph_0_post_graph_2 = self.___forward_subgraph_0_post_graph
invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_post_graph_2, '___forward_subgraph_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_post_graph_2 = None
getitem_2: "f32[]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None
___forward_subgraph_1_post_graph = self.___forward_subgraph_1_post_graph
invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_post_graph, '___forward_subgraph_1_post_graph', (cos, sin)); ___forward_subgraph_1_post_graph = cos = sin = None
getitem_19: "f32[]" = invoke_subgraph_12[3]
getitem_18: "f32[10, 20]" = invoke_subgraph_12[2]
getitem_17: "f32[10, 10]" = invoke_subgraph_12[1]
getitem_4: "f32[10, 10]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None
mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_4); mul = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_2); mul_1 = getitem_2 = None
return (add, primals_1, primals_2, getitem, getitem_1, getitem_19, getitem_18, getitem_17, getitem_4)
class ___forward_subgraph_0_post_graph(torch.nn.Module):
def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
@@ -475,12 +471,7 @@ class <lambda>(torch.nn.Module):
add_3: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, add_1); add_1 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
'subgraph_0', (add_2, add_3)); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2); add_2 = None
clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2)
clone_1: "f32[10, 20]" = torch.ops.aten.clone.default(add_3)
add_4: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, 1)
@@ -491,9 +482,11 @@ class <lambda>(torch.nn.Module):
add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (add_2, add_3)); repeated_subgraph0 = add_2 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
add_8: "f32[]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
@@ -551,18 +544,19 @@ class <lambda>(torch.nn.Module):
view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
return (add_2,)
@@ -585,6 +579,76 @@ class <lambda>(torch.nn.Module):
str(out), """[3, 'x', 1, 2, 3, 1, 4, 5, 6, 3, 4, 5]"""
)
def test_cycle_detection_no_cycle(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
mod = fw_graphs[0]
self.assertExpectedInline(_detect_cycles(mod.graph), """no cycle detected""")
def test_cycle_detection_simple(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
mod = fw_graphs[0]
add_node = next(n for n in mod.graph.nodes if n.name == "add")
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
args = add_node.args
add_node.args = (args[0], add_2)
self.assertExpectedInline(
_detect_cycles(mod.graph),
"""cycle detected in path: deque([arg0_1, add, sum_1, add_2, add])""",
)
def test_cycle_detection_complex(self):
def inner_fn(x, y):
x0 = x.view(x.size())
return x0.view(x.size())
def inner_fn2(x, y):
x = x * 2
y = y * 2
return x.sum() + y.sum()
def fn(x, y):
o0 = inner_fn(x, y)
o1 = inner_fn(x, y)
o2 = inner_fn2(x, y)
o3 = inner_fn2(x, y)
return o0 + o1 + o2.sum() + o3.sum()
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
x_clone = x.clone()
y_clone = y.clone()
_, _, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
mod = fw_graphs[0]
invoke_subgraph_node = next(
n for n in mod.graph.nodes if n.name == "invoke_subgraph"
)
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
args = invoke_subgraph_node.args
invoke_subgraph_node.args = (add_2, args[1])
self.assertExpectedInline(
_detect_cycles(mod.graph),
"""cycle detected in path: deque([arg0_1, invoke_subgraph_1, getitem_1, sum_2, add_2, invoke_subgraph, getitem, sum_1, add_1, add_2])""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
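
Note: the new tests above exercise _detect_cycles by pointing an early node's args at a later node. A minimal standalone sketch of the same pattern, assuming only the torch._dynamo.graph_utils helper added in this stack (the node names come from symbolic_trace and are illustrative):

import torch
import torch.fx
from torch._dynamo.graph_utils import _detect_cycles

class M(torch.nn.Module):
    def forward(self, x, y):
        x0 = x + 1
        y0 = y + 2
        return x0.sum() + y0.sum()

gm = torch.fx.symbolic_trace(M())
print(_detect_cycles(gm.graph))  # a freshly traced graph reports "no cycle detected"

# Rewire an early node to consume a later one, as the tests do, forcing a back-edge.
nodes = {n.name: n for n in gm.graph.nodes}
add_node, add_2 = nodes["add"], nodes["add_2"]
add_node.args = (add_node.args[0], add_2)
print(_detect_cycles(gm.graph))  # reports the cyclic path add -> ... -> add_2 -> add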

View File

@@ -12,18 +12,19 @@ import operator
from collections.abc import Iterable
from typing import Any
import torch
import torch.fx
from torch._dynamo import config
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
from torch.utils._pytree import tree_flatten
from .graph_region_tracker import Node, Region
from .graph_utils import _detect_cycles, _flatten_args_kwargs
log = logging.getLogger(__name__)
def apply_graph_deduplication(output_graph) -> dict[Node, Node]: # type: ignore[no-untyped-def]
def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def]
"""
This is the main entry point for applying the graph deduplication pass. \
Deduplication occurs in two phases:
@@ -50,15 +51,14 @@ The deduplication mutates the output_graph argument in place.
Returns a mapping of nodes to their subgraph output replacement node to remap outputs
when they are created in output_graph.
"""
from torch._inductor.pattern_matcher import stable_topological_sort
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
output_graph.graph
)
# Used to track which nodes were replaced with subgraph outputs
# today, we have to register the new subgraph submodules before the
# graph outputs have been created, so we pass the replacement mapping
# back to output graph to do the replacements at the site of output creation
output_replacements: dict[Node, Node] = {}
sub_gms: dict[str, torch.fx.GraphModule] = {}
for region_group in duplicated_region_groups:
inds_with_external_users = _get_all_output_indices(region_group)
region = region_group[0]
@@ -66,8 +66,14 @@ when they are created in output_graph.
subgraph,
node_ind_arg_inds,
) = _create_subgraph(region, inds_with_external_users)
# Ignore regions with no args for now, could they possibly be evaluated at compile time?
if not list(node_ind_arg_inds):
continue
sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
subgraph_name = output_graph.install_subgraph("subgraph", sub_gm)
sub_gms[subgraph_name] = sub_gm
with output_graph.graph.inserting_before():
get_subgraph_node = output_graph.graph.create_node(
"get_attr", subgraph_name, (), {}
@@ -81,34 +87,10 @@ when they are created in output_graph.
inds_with_external_users,
sub_gm,
subgraph_name,
output_replacements,
)
return output_replacements
# flattens with support for slices
# Note: a better way to do this would
# be register/unregister slices as pytree nodes
# but there is no unregister API in the pytorch
# pytree impl
def _flatten_args_kwargs(args: Any) -> list[Node]:
fully_flattened = []
def flatten(args: Any) -> None:
flattened, _ = tree_flatten(args)
for arg in flattened:
if isinstance(arg, slice):
start = arg.start
stop = arg.stop
step = arg.step
flatten((start, stop, step))
else:
fully_flattened.append(arg)
flatten(args)
return fully_flattened
stable_topological_sort(output_graph.graph)
return sub_gms
def _replace_region_with_subgraph(
@@ -119,7 +101,6 @@ def _replace_region_with_subgraph(
inds_with_external_users: list[int],
sub_gm: torch.fx.GraphModule,
subgraph_name: str,
output_replacements: dict[Node, Node],
) -> None:
sub_args = []
for node_ind, arg_ind in node_ind_arg_ind:
@@ -137,23 +118,26 @@
)
return
latest_region_node = region[-1]
with graph.inserting_after(latest_region_node):
invoke_subgraph_node = graph.create_node(
"call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
)
with graph.inserting_after(invoke_subgraph_node):
for ind, external_user_ind in enumerate(inds_with_external_users):
node = region[external_user_ind]
subgraph_output = graph.create_node(
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
)
output_replacements[node] = subgraph_output
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
from torch._inductor.pattern_matcher import stable_topological_sort
# Erase in reverse topological order
for node in reversed(region):
graph.erase_node(node)
invoke_subgraph_node = graph.create_node(
"call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
)
for ind, external_user_ind in enumerate(inds_with_external_users):
node = region[external_user_ind]
subgraph_output = graph.create_node(
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
)
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
# Erase in reverse topological order
for node in reversed(region):
graph.erase_node(node)
if config.graph_deduplication_lint:
_detect_cycles(graph)
stable_topological_sort(graph)
graph.lint()
if config.graph_deduplication_lint:
graph.lint()
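
For readers less familiar with the pass, the mechanic in _replace_region_with_subgraph above is: outline the region into a submodule, call it once where the region used to be, fan the results back out through operator.getitem, replace the old nodes' uses, and erase the region. Below is a simplified, self-contained sketch of that mechanic using a plain call_module node instead of the invoke_subgraph HOP and none of Dynamo's bookkeeping; replace_region_with_submodule and the toy function f are illustrative names, not part of the PR:

import operator
import torch
import torch.fx

def replace_region_with_submodule(gm, region, inputs, outputs, name):
    # Build the subgraph: placeholders for the region inputs, copies of the
    # region nodes, and a tuple output for the nodes with external users.
    subgraph = torch.fx.Graph()
    env = {}
    for inp in inputs:
        env[inp] = subgraph.placeholder(inp.name)
    for node in region:
        env[node] = subgraph.node_copy(node, lambda n: env[n])
    subgraph.output(tuple(env[o] for o in outputs))
    sub_gm = torch.fx.GraphModule(gm, subgraph)
    gm.add_submodule(name, sub_gm)

    # Call the submodule after the last region node, route each output through
    # operator.getitem, and swap the old nodes' uses over to the projections,
    # mirroring how the pass wires up invoke_subgraph outputs.
    graph = gm.graph
    with graph.inserting_after(region[-1]):
        call = graph.call_module(name, tuple(inputs))
    for i, out in enumerate(outputs):
        with graph.inserting_after(call):
            proj = graph.call_function(operator.getitem, (call, i))
        out.replace_all_uses_with(proj)

    # Erase in reverse topological order, as the pass does.
    for node in reversed(region):
        graph.erase_node(node)
    graph.lint()
    gm.recompile()
    return gm

def f(x):
    a = x + 1
    b = a * 2
    return b - 3

gm = torch.fx.symbolic_trace(f)
nodes = {n.name: n for n in gm.graph.nodes}
replace_region_with_submodule(
    gm,
    region=[nodes["add"], nodes["mul"]],
    inputs=[nodes["x"]],
    outputs=[nodes["mul"]],
    name="subgraph_0",
)
print(gm.code)  # the add/mul pair is now a single call_module plus a getitem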

View File

@ -27,6 +27,8 @@ import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import tree_flatten
from .graph_utils import _flatten_args_kwargs
T = TypeVar("T")
@@ -253,6 +255,8 @@
"""
topological_ranking = {node: i for i, node in enumerate(graph.nodes)}
region_groups_with_rank = []
# needed to detect if replacing a region will create cycles
node_to_recursive_ancestors = _populate_recursive_ancestor_map(graph)
# Create region groups; a region group is a group
# of regions that are all identical. In this initial state
@@ -281,7 +285,12 @@
# overlap.
seen_nodes: set[Node] = set()
for region_group in region_groups:
fully_expand_region_group(region_group, seen_nodes, self._is_identical)
fully_expand_region_group(
region_group,
seen_nodes,
node_to_recursive_ancestors,
self._is_identical,
)
# sort topologically
for region in region_group:
region.sort(key=lambda n: topological_ranking[n])
@@ -297,6 +306,7 @@
def fully_expand_region_group(
regions: list[Region],
seen_nodes: set[Node],
node_to_recursive_ancestors: dict[Node, set[Node]],
is_identical_fn: Callable[[Node, Node], bool],
) -> None:
debug_log("--------------------------------------------------")
@@ -327,17 +337,19 @@
# regions are only expanded if the node to add is valid
# for ALL regions
while current_node:
add_node = True
add_node = not _will_create_cycle(
current_node, regions[0], node_to_recursive_ancestors
)
nodes_to_add.clear()
nodes_to_add.append(current_node)
nodes_to_add_set = set(nodes_to_add)
for region_it in region_iters[1:]:
for ind, region_it in enumerate(region_iters[1:]):
ind += 1 # compensate for the 0th region
node = region_it.next()
debug_log("--------------------")
debug_log("considering adding: %s, cur_node: %s", node, current_node)
debug_log("previously claimed nodes: %s", node in seen_nodes)
debug_log("%s", seen_nodes)
if node:
debug_log("is_identical: %s", is_identical_fn(node, current_node))
add_node &= (
@@ -345,6 +357,9 @@
and node not in nodes_to_add_set
and node.op != "placeholder"
and is_identical_fn(node, current_node)
and not _will_create_cycle(
node, regions[ind], node_to_recursive_ancestors
)
)
nodes_to_add.append(node)
nodes_to_add_set.add(node)
@@ -369,3 +384,35 @@
debug_log("end expand new region group: %s", regions)
debug_log("--------------------------------------------------")
def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[Node]]:
node_to_recursive_ancestors: dict[Node, set[Node]] = {}
for node in graph.nodes:
node_to_recursive_ancestors[node] = set()
for node in graph.nodes:
all_args = _flatten_args_kwargs((node.args, node.kwargs))
for arg in all_args:
if isinstance(arg, Node):
node_to_recursive_ancestors[node].update(
node_to_recursive_ancestors[arg]
)
node_to_recursive_ancestors[node].add(node)
return node_to_recursive_ancestors
def _will_create_cycle(
node_to_add: Node,
region: Region,
node_to_recursive_ancestors: dict[Node, set[Node]],
) -> bool:
region_set: set[Node] = set(region)
region_ancestors: set[Node] = set(
tree_flatten([list(node_to_recursive_ancestors[node]) for node in region])[0]
)
external_users = [user for user in node_to_add.users if user not in region_set]
for user in external_users:
if user in region_ancestors:
return True
return False
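
The two helpers above are the heart of the new safety check: a node may only join a region if none of its users outside the region is an ancestor of the region, otherwise the fused invoke_subgraph node would both feed that user and depend on it. A small standalone sketch of the same idea over a toy graph described with plain dicts (recursive_ancestors, will_create_cycle and the letter names are illustrative, not the PR's API):

def recursive_ancestors(args_of):
    # args_of maps each node to its inputs, listed in topological order,
    # mirroring _populate_recursive_ancestor_map walking graph.nodes.
    ancestors = {}
    for node, args in args_of.items():
        ancestors[node] = {node}
        for arg in args:
            ancestors[node] |= ancestors[arg]
    return ancestors

def will_create_cycle(node_to_add, region, users_of, ancestors):
    # Mirrors _will_create_cycle: unsafe if a user of node_to_add outside the
    # region is an ancestor of the region, i.e. the fused node would feed a
    # node that it also (transitively) depends on.
    region_ancestors = set().union(*(ancestors[n] for n in region))
    return any(
        user in region_ancestors
        for user in users_of[node_to_add]
        if user not in set(region)
    )

# Safe: two independent duplicate chains off the same input (a and c are duplicates).
#   p -> a -> b      p -> c -> d
args_of = {"p": [], "a": ["p"], "b": ["a"], "c": ["p"], "d": ["c"]}
users_of = {"p": ["a", "c"], "a": ["b"], "b": [], "c": ["d"], "d": []}
print(will_create_cycle("a", ["c"], users_of, recursive_ancestors(args_of)))  # False

# Unsafe: the first duplicate feeds the second through b, so a fused (a, c) node
# would both produce b's input and consume b's output.
#   p -> a -> b -> c
args_of = {"p": [], "a": ["p"], "b": ["a"], "c": ["b"]}
users_of = {"p": ["a"], "a": ["b"], "b": ["c"], "c": []}
print(will_create_cycle("a", ["c"], users_of, recursive_ancestors(args_of)))  # True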

View File

@@ -0,0 +1,69 @@
from collections import deque
from typing import Any
from torch.fx import Graph, Node
from torch.utils._pytree import tree_flatten
# flattens with support for slices
# Note: a better way to do this would
# be register/unregister slices as pytree nodes
# but there is no unregister API in the pytorch
# pytree impl
def _flatten_args_kwargs(args: Any) -> list[Node]:
fully_flattened = []
def flatten(args: Any) -> None:
flattened, _ = tree_flatten(args)
for arg in flattened:
if isinstance(arg, slice):
start = arg.start
stop = arg.stop
step = arg.step
flatten((start, stop, step))
else:
fully_flattened.append(arg)
flatten(args)
return fully_flattened
def _detect_cycles(graph: Graph) -> str:
current_path: deque[Node] = deque()
current_path_set: set[Node] = set()
pending: deque[tuple[Node, Node]] = deque()
def add_to_current_path(node: Node) -> None:
current_path.append(node)
current_path_set.add(node)
def pop_current_path() -> None:
node = current_path.pop()
current_path_set.remove(node)
def current_path_head() -> Node:
return current_path[-1]
for origin in graph.find_nodes(op="placeholder"):
current_path.clear()
current_path_set.clear()
add_to_current_path(origin)
for child in origin.users:
pending.append((child, origin))
while pending:
cur_node, parent = pending.pop()
while current_path_head() != parent:
pop_current_path()
if cur_node in current_path_set:
current_path.append(cur_node)
return f"cycle detected in path: {current_path}"
add_to_current_path(cur_node)
for child in cur_node.users:
pending.append((child, cur_node))
return "no cycle detected"

View File

@@ -240,6 +240,10 @@ class FakeRootModule(torch.nn.Module):
def __repr__(self) -> str:
return "FakeRootModule(...)"
def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]):
for k, v in nn_modules.items():
setattr(self, k, v)
class WrapperBackend:
def __init__(self, backend: CompilerFn):
@@ -1070,8 +1074,6 @@ class OutputGraph:
for value in stack_values:
value.realize()
output_replacements = self.dedup_pass()
# Use nn.Module "proxies" in the constructed GraphModule so that
# the resulting GM does not hold additional strong references to the original modules.
# This prevents a strong ref cycle where Dynamo created code holds on to references
@@ -1155,9 +1157,7 @@
append_prefix_insts()
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(
tx, list(reversed(stack_values)), root, output_replacements
)
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
+ [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
)
# restore all the live local vars
@@ -1190,9 +1190,7 @@
output = []
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
output.extend(
self.compile_and_call_fx_graph(
tx, pass2.graph_output_vars(), root, output_replacements
)
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
)
if len(pass2.graph_outputs) != 0:
@@ -1356,7 +1354,7 @@
tx.speculation_log.clear()
raise exc.CompileCollectiveRestartAnalysis
def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs):
def compile_and_call_fx_graph(self, tx, rv, root):
"""
Generate code from self.graph and return the Instruction()s to
call that generated code.
@@ -1379,9 +1377,8 @@
(self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
{},
)
for old_node, new_node in replaced_outputs.items():
old_node.replace_all_uses_with(new_node)
sub_gms = self.dedup_pass()
root.add_nn_modules(sub_gms)
tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node)
if not config.do_not_emit_runtime_asserts:
@@ -1576,7 +1573,7 @@
if torch._dynamo.config.use_graph_deduplication:
return apply_graph_deduplication(self)
else:
return dict()
return {}
def install_subgraph(self, name, sub_gm):
next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True)
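
Taken together, the output_graph.py changes move deduplication to after the output node exists: compile_and_call_fx_graph now calls self.dedup_pass() itself and installs the returned subgraph modules on the fake root via the new add_nn_modules hook, so the get_attr targets created by the pass still resolve on the already-constructed root when the final GraphModule is built. A generic fx sketch of why those attributes must live on the root (the Root class and names below are illustrative, not Dynamo's actual objects):

import torch
import torch.fx

class Root(torch.nn.Module):
    def add_nn_modules(self, nn_modules):
        # same shape as FakeRootModule.add_nn_modules in this PR
        for k, v in nn_modules.items():
            setattr(self, k, v)

root = Root()
root.add_nn_modules({"subgraph_0": torch.fx.symbolic_trace(lambda x: x * 2 + 1)})

graph = torch.fx.Graph()
x = graph.placeholder("x")
out = graph.call_module("subgraph_0", (x,))
graph.output(out)

# GraphModule copies call_module/get_attr targets from the root it is given,
# so "subgraph_0" must already be installed there.
gm = torch.fx.GraphModule(root, graph)
print(gm(torch.ones(3)))  # tensor([3., 3., 3.])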