Compare commits

...

7 Commits

Author SHA1 Message Date
f3949c0bc1 [Dynamo] Optimize dedupe region ancestor tracking
ghstack-source-id: 887bbae5244a1e4df25aad4121427b2b4c190196
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152589
2025-05-13 01:31:29 -07:00
b50e411346 [Dynamo] Fix typing in graph_deduplication.py
ghstack-source-id: 4021125c82b729284722948fe2c610b1dbc8bede
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152572
2025-05-12 22:20:00 -07:00
0cc273cf0c [Hierarchical Compile] Replace tracing alias and mutation check with dynamo impl
ghstack-source-id: 5c15aad0148bd1c660044b31e57547c46c177ff9
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152570
2025-05-12 22:19:59 -07:00
b347aaa731 [Hierarchical Compile] Take into account mutation deps in cycle detection
ghstack-source-id: f1e27f13c060784a05c4f4cbeb76c7c3c8c48a11
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152506
2025-05-12 22:19:59 -07:00
4161b0fd72 [Hierarchical Compile] Add mutation dependencies to topological sorting
ghstack-source-id: f849ea10343a8760d58f879a311c78ff802ed89a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152410

Fix mutation ordering bug
2025-05-12 22:19:58 -07:00
ffe1faa370 [Hierarchical Compilation] Use universal flatten APIs
ghstack-source-id: f3ceea9e34fe1a587c49faeb06fd4644ad03e1c8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152505
2025-05-12 22:19:58 -07:00
8ba70748dc [Hierarchical Compilation] Track node mutations (#152389)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152389
Approved by: https://github.com/anijain2305
ghstack-source-id: 43322da49d1171561c507636e24c7ecb910a00ce
2025-05-12 22:19:58 -07:00
7 changed files with 635 additions and 230 deletions

View File: benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

@@ -1,8 +1,8 @@
add_loop_eager,compile_time_instruction_count,3167000000,0.015
add_loop_eager,compile_time_instruction_count,3035000000,0.015
add_loop_eager_dynamic,compile_time_instruction_count,6066000000,0.025
add_loop_eager_dynamic,compile_time_instruction_count,5928000000,0.025
@@ -14,11 +14,11 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44480000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,26050000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1018000000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1011000000,0.015
@@ -34,7 +34,7 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
update_hint_regression,compile_time_instruction_count,1723000000,0.02
update_hint_regression,compile_time_instruction_count,1715000000,0.02
@@ -42,15 +42,15 @@ float_args,compile_time_instruction_count,439200000,0.015
sum_floordiv_regression,compile_time_instruction_count,1024000000,0.015
sum_floordiv_regression,compile_time_instruction_count,1009000000,0.015
symint_sum,compile_time_instruction_count,3278000000,0.015
symint_sum,compile_time_instruction_count,3252000000,0.015
symint_sum_loop,compile_time_instruction_count,4300000000,0.015
symint_sum_loop,compile_time_instruction_count,4262000000,0.015
@@ -58,11 +58,11 @@ aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5944000000,0.015
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8586000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8630000000,0.015
@@ -70,8 +70,8 @@ aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3795000000,0.015
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10280000000,0.015
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10350000000,0.015
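
Note: each row reads as benchmark name, metric, expected instruction count, and a relative noise tolerance; the column meanings are inferred from the layout, since no header appears in these hunks. A quick check of the first updated row, as a sketch:

old, new, tol = 3_167_000_000, 3_035_000_000, 0.015
change = (old - new) / old
# 4.2% improvement, outside the ±1.5% noise band, so the baseline is re-recorded
print(f"{change:.1%} vs ±{tol:.1%}")

Rows whose counts moved by more than their tolerance (add_loop_eager, add_loop_eager_dynamic, symint_sum, and the rest shown here) get new expected values; unchanged rows are diff context.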


View File: test/dynamo/test_graph_deduplication.py

@@ -1,11 +1,17 @@
# Owner(s): ["module: dynamo"]
# flake8: noqa: B950
import contextlib
import torch
import torch.fx
from torch._dynamo.graph_deduplication import _flatten_args_kwargs
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
from torch._dynamo.testing import (
AotEagerAndRecordGraphs,
extract_graph_and_tracker,
normalize_gm,
)
from torch.utils._ordered_set import OrderedSet
def extract_graph(fn, *args, **kwargs):
@@ -19,9 +25,32 @@ def graph_str(gm):
class GraphDededuplicationTests(TestCase):
def setUp(self):
self.exit_stack = contextlib.ExitStack()
self.exit_stack.enter_context(
torch._dynamo.config.patch("use_graph_deduplication", True)
)
super().setUp()
def tearDown(self):
self.exit_stack.close()
super().tearDown()
def run_and_return_graphs(self, fn, *args, **kwargs):
with torch._dynamo.config.patch("use_graph_deduplication", True):
return extract_graph(fn, *args, **kwargs)
return extract_graph(fn, *args, **kwargs)
def run_and_get_simple_graph(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
return fw_graphs[0]
def test_single_subgraph(self):
def inner_fn(x, y):
@@ -432,12 +461,6 @@ class GraphModule(torch.nn.Module):
)
def test_input_mutation(self):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def inner_fn2(x, y):
x0 = x + 1
y0 = y + 1
@@ -447,9 +470,6 @@ class GraphModule(torch.nn.Module):
def fn(x, y):
x0 = torch.sin(x)
_y0 = torch.cos(y)
# o0 = inner_fn(x0, y0)
# o1 = inner_fn(x0, o0)
o2 = inner_fn2(x0, y)
o3 = inner_fn2(x0.clone(), y.clone())
return o2 + o3
@@ -583,28 +603,13 @@ class <lambda>(torch.nn.Module):
""",
)
def test_flatten_with_slices(self):
tree = [{"x": 3}, ["x", slice(1, 2, 3), 1], [4, 5, 6, [slice(3, 4, 5)]]]
out = _flatten_args_kwargs(tree)
def test_cycle_detection_no_cycle(self):
mod = self.run_and_get_simple_graph()
self.assertExpectedInline(
str(out), """[3, 'x', 1, 2, 3, 1, 4, 5, 6, 3, 4, 5]"""
_detect_cycles(mod.graph, {}), """no cycle detected"""
)
def test_cycle_detection_no_cycle(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
mod = fw_graphs[0]
self.assertExpectedInline(_detect_cycles(mod.graph), """no cycle detected""")
def test_cycle_detection_simple(self):
def test_cycle_detection_single_node(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
@@ -621,8 +626,64 @@ class <lambda>(torch.nn.Module):
args = add_node.args
add_node.args = (args[0], add_2)
self.assertExpectedInline(
_detect_cycles(mod.graph),
"""cycle detected in path: deque([arg0_1, add, sum_1, add_2, add])""",
_detect_cycles(mod.graph, {add_2: OrderedSet([add_2])}),
"""cycle detected in path: deque([output, add_2, add_2])""",
)
def test_cycle_detection_two_node(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
mod = fw_graphs[0]
add_node = next(n for n in mod.graph.nodes if n.name == "add")
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
args = add_node.args
add_node.args = (args[0], add_2)
self.assertExpectedInline(
_detect_cycles(
mod.graph,
{add_2: OrderedSet([add_node]), add_node: OrderedSet([add_2])},
),
"""cycle detected in path: deque([output, add_2, add, add_2])""",
)
def test_cycle_detection_arg_and_additional_deps(self):
def fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
mod = fw_graphs[0]
add_node = next(n for n in mod.graph.nodes if n.name == "add")
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
args = add_node.args
add_node.args = (args[0], add_2)
self.assertExpectedInline(
_detect_cycles(mod.graph, {add_2: OrderedSet([add_node])}),
"""cycle detected in path: deque([output, add_2, add, add_2])""",
)
def test_cycle_detection_simple(self):
mod = self.run_and_get_simple_graph()
add_node = next(n for n in mod.graph.nodes if n.name == "add")
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
args = add_node.args
add_node.args = (args[0], add_2)
self.assertExpectedInline(
_detect_cycles(mod.graph, {}),
"""cycle detected in path: deque([output, add_2, sum_1, add, add_2])""",
)
def test_cycle_detection_complex(self):
@@ -656,8 +717,8 @@ class <lambda>(torch.nn.Module):
args = invoke_subgraph_node.args
invoke_subgraph_node.args = (add_2, args[1])
self.assertExpectedInline(
_detect_cycles(mod.graph),
"""cycle detected in path: deque([arg0_1, invoke_subgraph_1, getitem_1, sum_2, add_2, invoke_subgraph, getitem, sum_1, add_1, add_2])""",
_detect_cycles(mod.graph, {}),
"""cycle detected in path: deque([output, add_2, add_1, sum_1, getitem, invoke_subgraph, add_2])""",
)
def test_autocast_ordering(self):
@@ -699,7 +760,7 @@ class <lambda>(torch.nn.Module):
sum_2 = get_node("sum_2")
exit_autocast = mod.graph.call_function(torch.amp._exit_autocast)
sum_2.append(exit_autocast)
additional_deps = _populate_additional_deps(mod.graph)
additional_deps = _populate_additional_deps(mod.graph, {})
invoke_subgraph = get_node("invoke_subgraph")
invoke_subgraph.append(enter_autocast)
getitem_1 = get_node("getitem_1")
@@ -914,6 +975,137 @@ class <lambda>(torch.nn.Module):
""",
)
def test_mutation_ordering(self):
from torch._dynamo.graph_deduplication import _stable_topological_sort
def inner_fn(x, y):
x0 = x.view(x.size())
return x0.view(x.size())
def inner_fn2(x, y):
x = x * 2
y = y * 2
return x.sum() + y.sum()
def fn(x, y):
o0 = inner_fn(x, y)
o1 = inner_fn(x, y)
x.add_(x)
o2 = inner_fn2(x, y)
y.mul_(y)
o3 = inner_fn2(x, y)
return o0 + o1 + o2.sum() + o3.sum()
x = torch.rand(10, 10)
y = torch.rand(10, 20)
x_clone = x.clone()
y_clone = y.clone()
graph, _ = extract_graph_and_tracker(fn, x_clone, y_clone)
def graph_code(graph):
return graph.python_code("self").src
def get_node(name):
return next(n for n in graph.nodes if n.name == name)
self.assertExpectedInline(
graph_code(graph),
"""\
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
subgraph_0 = self.subgraph_0
l_x_ = L_x_
l_y_ = L_y_
x0 = l_x_.view((10, 10))
o0 = x0.view((10, 10)); x0 = None
x0_1 = l_x_.view((10, 10))
o1 = x0_1.view((10, 10)); x0_1 = None
add_ = l_x_.add_(l_x_); add_ = None
add_2 = o0 + o1; o0 = o1 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
mul_ = l_y_.mul_(l_y_); mul_ = None
getitem = invoke_subgraph[0]; invoke_subgraph = None
sum_5 = getitem.sum(); getitem = None
add_3 = add_2 + sum_5; add_2 = sum_5 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_6 = getitem_1.sum(); getitem_1 = None
add_4 = add_3 + sum_6; add_3 = sum_6 = None
return (add_4,)
""",
)
# Shuffle nodes in the graph
add_ = get_node("add_")
mul_ = get_node("mul_")
o1 = get_node("o1")
o1.append(mul_)
add_2 = get_node("add_2")
add_2.append(add_)
self.assertExpectedInline(
graph_code(graph),
"""\
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
subgraph_0 = self.subgraph_0
l_x_ = L_x_
l_y_ = L_y_
x0 = l_x_.view((10, 10))
o0 = x0.view((10, 10)); x0 = None
x0_1 = l_x_.view((10, 10))
o1 = x0_1.view((10, 10)); x0_1 = None
mul_ = l_y_.mul_(l_y_); mul_ = None
add_2 = o0 + o1; o0 = o1 = None
add_ = l_x_.add_(l_x_); add_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
getitem = invoke_subgraph[0]; invoke_subgraph = None
sum_5 = getitem.sum(); getitem = None
add_3 = add_2 + sum_5; add_2 = sum_5 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_6 = getitem_1.sum(); getitem_1 = None
add_4 = add_3 + sum_6; add_3 = sum_6 = None
return (add_4,)
""",
)
_stable_topological_sort(
graph, torch._dynamo.graph_deduplication.last_node_to_additional_deps
)
self.assertExpectedInline(
graph_code(graph),
"""\
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
subgraph_0 = self.subgraph_0
l_x_ = L_x_
l_y_ = L_y_
x0 = l_x_.view((10, 10))
o0 = x0.view((10, 10)); x0 = None
x0_1 = l_x_.view((10, 10))
o1 = x0_1.view((10, 10)); x0_1 = None
add_2 = o0 + o1; o0 = o1 = None
add_ = l_x_.add_(l_x_); add_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
mul_ = l_y_.mul_(l_y_); mul_ = None
getitem = invoke_subgraph[0]; invoke_subgraph = None
sum_5 = getitem.sum(); getitem = None
add_3 = add_2 + sum_5; add_2 = sum_5 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_6 = getitem_1.sum(); getitem_1 = None
add_4 = add_3 + sum_6; add_3 = sum_6 = None
return (add_4,)
""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
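
As a minimal sketch of the new two-argument _detect_cycles call shape these tests exercise (post-PR signature per the torch/_dynamo/graph_utils.py diff further down; node names follow FX's default naming, so the exact deque contents differ from dynamo-produced graphs):

import torch
from torch.fx import symbolic_trace
from torch._dynamo.graph_utils import _detect_cycles
from torch.utils._ordered_set import OrderedSet

def fn(x, y):
    x0 = x + 1                      # node "add"
    y0 = y + 2                      # node "add_1"
    return x0.sum() + y0.sum()      # nodes "sum", "sum_1", "add_2"

gm = symbolic_trace(fn)
print(_detect_cycles(gm.graph, {}))             # "no cycle detected"

# An additional dependency pointing an early node at a later one closes a cycle,
# which the backward traversal from the output node reports:
add = next(n for n in gm.graph.nodes if n.name == "add")
add_2 = next(n for n in gm.graph.nodes if n.name == "add_2")
print(_detect_cycles(gm.graph, {add: OrderedSet([add_2])}))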

View File: test/dynamo/test_graph_region_tracker.py

@@ -49,6 +49,10 @@ class GraphRegionTrackerTests(TestCase):
region_groups = tree_map(lambda n: n.name, region_groups)
return str(region_groups)
def get_mutation_tracking(self, fn, *args, **kwargs):
_, region_tracker = extract_graph_and_tracker(fn, *args, **kwargs)
return str(region_tracker.node_to_mutated_arg_positions)
def test_get_regions_single_region_group(self):
def inner_fn(x, y):
x0 = x + 1
@@ -57,12 +61,12 @@
return z
def fn(x, y):
_o0 = inner_fn(x, y)
o0 = inner_fn(x, y)
o1 = torch.sin(y)
o2 = inner_fn(x, o1)
o3 = inner_fn(x, y)
o4 = o3 * o3
return o2 * o4
return o2 * o4 + o0
self.assertExpectedInline(
self.get_result(
@@ -295,6 +299,45 @@ class GraphRegionTrackerTests(TestCase):
[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
)
def test_mutation_tracking_simple(self):
def fn(x, y, z):
x0 = torch.sin(x)
y0 = torch.cos(y)
z.sin_()
y0.add_(z)
return x0.sum() + y0.sum()
self.assertExpectedInline(
self.get_mutation_tracking(
fn,
torch.rand(10, 10),
torch.rand(10, 20),
torch.ones(10, 20),
),
"""{sin_: OrderedSet([0]), add_: OrderedSet([0])}""",
)
def test_mutation_tracking_allow_in_graph(self):
@torch._dynamo.allow_in_graph
def fn_mut(x, y):
x.add_(y)
return x.sum() + y.sum()
def fn(x, y):
z = x + y
o0 = fn_mut(z, y)
z.sin_()
return x + o0
self.assertExpectedInline(
self.get_mutation_tracking(
fn,
torch.rand(20, 10),
torch.rand(20, 10),
),
"""{o0: OrderedSet([0]), sin_: OrderedSet([0])}""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
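
End to end, these tests reduce to the following shape: with tracking enabled, the region tracker maps each mutating node to the flattened-argument positions it mutates. A hedged sketch using the extract_graph_and_tracker helper the tests import, with the config name taken from the torch/_dynamo/utils.py diff below (the printed value is illustrative):

import torch
from torch._dynamo.testing import extract_graph_and_tracker

def fn(x, y):
    y.add_(x)                        # in-place: mutates flat-arg position 0 of "add_"
    return x.sum() + y.sum()

with torch._dynamo.config.patch("track_nodes_for_deduplication", True):
    graph, tracker = extract_graph_and_tracker(fn, torch.rand(4), torch.rand(4))
print(tracker.node_to_mutated_arg_positions)     # e.g. {add_: OrderedSet([0])}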

View File: torch/_dynamo/graph_deduplication.py

@@ -11,20 +11,28 @@ import logging
import operator
from collections import defaultdict
from collections.abc import Generator, Iterable
from typing import Any, Optional
from typing import Optional
import torch
import torch.fx
from torch._dynamo import config
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._ordered_set import OrderedSet
from .graph_region_tracker import Node, Region
from .graph_utils import _detect_cycles, _flatten_args_kwargs
from .graph_utils import _detect_cycles, _get_flat_args, _get_flat_args_unique
# Represents an index into the region
# to select a node and then
# an index into that node's
# flattened arguments
UsageIndex = tuple[int, int]
log = logging.getLogger(__name__)
last_node_to_additional_deps: Optional[dict[Node, OrderedSet[Node]]] = None
def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def]
"""
@@ -57,7 +65,12 @@ when they are created in output_graph.
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
output_graph.graph
)
node_to_additional_deps = _populate_additional_deps(output_graph.graph)
node_to_mutated_arg_positions = (
output_graph.region_tracker.node_to_mutated_arg_positions
)
node_to_additional_deps = _populate_additional_deps(
output_graph.graph, output_graph.region_tracker.node_to_mutated_arg_positions
)
sub_gms: dict[str, torch.fx.GraphModule] = {}
@@ -66,11 +79,11 @@ when they are created in output_graph.
region = region_group[0]
(
subgraph,
node_ind_arg_inds,
external_node_usages,
) = _create_subgraph(region, inds_with_external_users)
# Ignore regions with no args for now; could they possibly be evaluated at compile time?
if not list(node_ind_arg_inds):
if not list(external_node_usages):
continue
sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
@@ -80,19 +93,27 @@ when they are created in output_graph.
get_subgraph_node = output_graph.graph.create_node(
"get_attr", subgraph_name, (), {}
)
for region in region_group:
_replace_region_with_subgraph(
output_graph.graph,
region,
get_subgraph_node,
node_ind_arg_inds.keys(),
external_node_usages,
inds_with_external_users,
sub_gm,
subgraph_name,
node_to_additional_deps,
node_to_mutated_arg_positions,
)
_stable_topological_sort(output_graph.graph, node_to_additional_deps)
# This is to expose the updated node_to_additional_deps to tests
global last_node_to_additional_deps
last_node_to_additional_deps = node_to_additional_deps
_stable_topological_sort(
output_graph.graph,
node_to_additional_deps,
)
return sub_gms
@@ -100,29 +121,34 @@ def _replace_region_with_subgraph(
graph: torch.fx.Graph,
region: Region,
get_subgraph_node: Node,
node_ind_arg_ind: Iterable[tuple[int, int]],
external_node_usages: Iterable[OrderedSet[UsageIndex]],
inds_with_external_users: list[int],
sub_gm: torch.fx.GraphModule,
subgraph_name: str,
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]],
node_to_additional_deps: dict[Node, OrderedSet[Node]],
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
) -> None:
sub_args = []
for node_ind, arg_ind in node_ind_arg_ind:
for usages in external_node_usages:
node_ind, usage_ind = next(iter(usages))
node = region[node_ind]
flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
sub_args.append(flattened_args_kwargs[arg_ind])
flattened_args_kwargs = _get_flat_args(node, {})
for user_ind, node_usage_ind in usages:
user = region[user_ind]
if user in node_to_mutated_arg_positions:
if node_usage_ind in node_to_mutated_arg_positions[user]:
log.debug(
"NYI: Failed to substitute region %s due to mutation", region
)
return
sub_args.append(flattened_args_kwargs[usage_ind])
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
fake_inputs = [node.meta["example_value"] for node in sub_args]
if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
log.debug(
"NYI: Failed to substitute region %s due to input alias or mutation",
region,
)
# Input/Output aliasing not supported in HOPs today
# Note: we should use the nodes in the original graph (the region here)
# because we use the original traced example values for this check
if _has_aliasing(region, sub_args, inds_with_external_users):
return
from torch._inductor.pattern_matcher import stable_topological_sort
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
invoke_subgraph_node = graph.create_node(
"call_function",
@@ -140,38 +166,40 @@ def _replace_region_with_subgraph(
# Erase in reverse topological order
for node in reversed(region):
graph.erase_node(node)
node_to_additional_deps.pop(node)
for dep_list in node_to_additional_deps.values():
# Remove any nodes with additional deps
# This is safe; we've guaranteed that there is
# no input mutation, so all additional deps
# will be internal to the subgraph
node_to_additional_deps.pop(node, None)
for deps in node_to_additional_deps.values():
try:
dep_list.remove(node)
except ValueError:
deps.remove(node)
deps.add(invoke_subgraph_node)
except KeyError:
pass
if config.graph_deduplication_lint:
_detect_cycles(graph)
stable_topological_sort(graph)
graph.lint()
if config.graph_deduplication_lint:
print(_detect_cycles(graph, node_to_additional_deps))
_stable_topological_sort(graph, node_to_additional_deps)
graph.lint()
def _get_external_inputs(
region: Region,
) -> dict[Node, tuple[int, int]]:
external_node_to_indices = dict()
) -> dict[Node, OrderedSet[UsageIndex]]:
external_node_to_usages = defaultdict[Node, OrderedSet[UsageIndex]](OrderedSet)
region_unique = set(region)
for node_ind, node in enumerate(region):
flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
flattened_args_kwargs = _get_flat_args(node, {})
for arg_ind, in_node in enumerate(flattened_args_kwargs):
if (
isinstance(in_node, Node)
and in_node not in region_unique
and in_node not in external_node_to_indices
):
external_node_to_indices[in_node] = (node_ind, arg_ind)
if isinstance(in_node, Node) and in_node not in region_unique:
# in_node may occur in multiple nodes' flat_args
# track this so we can check if the arg is mutated
# Previously, we only needed to track one occurrence
# to be able to map that node to a placeholder
external_node_to_usages[in_node].add((node_ind, arg_ind))
return external_node_to_indices
return external_node_to_usages
def _get_all_output_indices(regions: list[Region]) -> list[int]:
@@ -194,17 +222,14 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None
def _copy_nodes_and_remap_inputs(
subgraph: torch.fx.Graph, region: Region
) -> dict[tuple[int, int], Any]:
external_inputs_to_indices = _get_external_inputs(region)
indices_to_placeholder_ind: dict[tuple[int, int], Any] = {}
) -> list[OrderedSet[UsageIndex]]:
external_input_to_usages = _get_external_inputs(region)
external_node_usages = list[OrderedSet[UsageIndex]]()
region_to_subgraph_node = {}
for node in external_inputs_to_indices.keys():
for node, usage_indices in external_input_to_usages.items():
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
region_to_subgraph_node[node] = placeholder
arg_indices = external_inputs_to_indices[node]
# Note: insertion order matches the order in which placeholders were created
# for the calling convention of the subgraph
indices_to_placeholder_ind[arg_indices] = None
external_node_usages.append(usage_indices)
def map_arg(node: Node) -> Node:
if node in region_to_subgraph_node:
@@ -216,7 +241,7 @@ def _copy_nodes_and_remap_inputs(
subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
region_to_subgraph_node[node] = subgraph_node
return indices_to_placeholder_ind
return external_node_usages
def _create_subgraph_outputs(
@@ -230,30 +255,16 @@ def _create_subgraph_outputs(
def _create_subgraph(
region: Region,
inds_with_external_users: list[int],
) -> tuple[torch.fx.Graph, dict[tuple[int, int], Any]]:
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
subgraph: torch.fx.Graph = torch.fx.Graph()
node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region)
_create_subgraph_outputs(subgraph, inds_with_external_users)
return subgraph, node_ind_input_inds
def _args(
n: torch.fx.Node,
node_to_additional_deps: Optional[dict[torch.fx.Node, list[torch.fx.Node]]] = None,
) -> list[torch.fx.node.Argument]:
if node_to_additional_deps is None:
node_to_additional_deps = {}
args: list[torch.fx.node.Argument] = []
torch.fx.map_arg((n.args, n.kwargs), args.append)
if n in node_to_additional_deps:
args.extend(node_to_additional_deps[n])
return args
return subgraph, external_node_usages
def _stable_topological_sort(
graph: torch.fx.Graph,
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]],
node_to_additional_deps: dict[Node, OrderedSet[Node]],
) -> None:
# Nodes are in exactly one of these four collections:
@@ -262,14 +273,14 @@ def _stable_topological_sort(
# - Nodes in `ready` have been processed and are already in the correct
# order.
ready = OrderedSet[torch.fx.Node]()
ready = OrderedSet[Node]()
# - `waiting` is a mapping from a dependency to nodes which depend on that
# dependency.
waiting = defaultdict(list)
# - `outputs` are always at the end of the graph
outputs = OrderedSet[torch.fx.Node]()
outputs = OrderedSet[Node]()
# The cursor indicates the last processed node so we can add new nodes
# after it.
@@ -283,7 +294,9 @@
continue
waiting_for = [
x for x in _args(node, node_to_additional_deps) if x not in ready
x
for x in _get_flat_args_unique(node, node_to_additional_deps)
if x not in ready
]
if waiting_for:
# We have unprocessed input nodes. Might as well wait for the last
@@ -303,23 +316,29 @@
def _populate_additional_deps(
graph: torch.fx.Graph,
) -> dict[torch.fx.Node, list[torch.fx.Node]]:
graph: torch.fx.Graph, node_to_mutated_arg_positions: dict[Node, OrderedSet[int]]
) -> dict[Node, OrderedSet[Node]]:
node_to_additional_deps: dict[Node, OrderedSet[Node]] = defaultdict(OrderedSet)
_add_mutation_dependencies(node_to_mutated_arg_positions, node_to_additional_deps)
_add_global_state_dependencies(graph, node_to_additional_deps)
return node_to_additional_deps
def _add_global_state_dependencies(
graph: torch.fx.Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> None:
import torch.amp
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(
list
)
all_nodes = list(graph.nodes)
# These are targets of the nodes which need to stay in the same relative place in the graph
global_state_targets = {torch.amp._enter_autocast, torch.amp._exit_autocast}
all_nodes_dep_on: list[torch.fx.Node] = []
all_nodes_dep_on: list[Node] = []
def prev_cur_nodes(
all_nodes: list[torch.fx.Node],
) -> Generator[tuple[list[torch.fx.Node], torch.fx.Node]]:
prev_nodes: list[torch.fx.Node] = []
all_nodes: list[Node],
) -> Generator[tuple[list[Node], Node], None, None]:
prev_nodes: list[Node] = []
next_nodes = list(reversed(all_nodes))
while next_nodes:
@@ -328,11 +347,93 @@ def _populate_additional_deps(
prev_nodes.append(cur_node)
for prev_nodes, cur_node in prev_cur_nodes(all_nodes):
args_unique = _args(cur_node)
additional_deps = node_to_additional_deps[cur_node]
additional_deps.extend(n for n in all_nodes_dep_on if n not in args_unique)
args_unique = _get_flat_args_unique(cur_node, {})
new_deps = [n for n in all_nodes_dep_on if n not in args_unique]
if new_deps:
additional_deps = node_to_additional_deps[cur_node]
additional_deps.update(new_deps)
if cur_node.target in global_state_targets:
additional_deps.extend(n for n in prev_nodes if n not in args_unique)
additional_deps = node_to_additional_deps[cur_node]
additional_deps.update(n for n in prev_nodes if n not in args_unique)
all_nodes_dep_on.append(cur_node)
return node_to_additional_deps
def _add_mutation_dependencies(
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
node_to_additional_deps: dict[Node, OrderedSet[Node]],
) -> None:
for node, indices in node_to_mutated_arg_positions.items():
flat_args_kwargs = _get_flat_args(node, {})
# for all mutated args,
# add dependency on usages which occur after node to ensure
# node will always be ordered before them
# also add node as a dependency on usages which
# occur before node to ensure node is ordered after them
for index in indices:
mutated_arg = flat_args_kwargs[index]
for user in mutated_arg.users:
if user is node:
continue
elif user < node:
node_to_additional_deps[node].add(user)
elif user > node:
node_to_additional_deps[user].add(node)
def _has_aliasing(
region: Region, inputs: list[Node], inds_with_external_users: list[int]
) -> bool:
input_storages: dict[StorageWeakRef, Node] = dict()
for node in inputs:
example_value = node.meta["example_value"]
if isinstance(example_value, torch.Tensor):
storage = StorageWeakRef(example_value._typed_storage())
if storage in input_storages:
# input-input aliasing
log.debug(
"NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s, %s",
region,
input_storages[storage],
node,
)
return True
input_storages[storage] = node
output_storages: dict[StorageWeakRef, Node] = dict()
for i in inds_with_external_users:
out_node = region[i]
if out_node:
example_value = out_node.meta["example_value"]
assert not isinstance(example_value, list)
if isinstance(example_value, torch.Tensor):
storage = StorageWeakRef(example_value._typed_storage())
if storage in output_storages:
# output-output aliasing
log.debug(
"NYI: Failed to substitute region %s due to output-output aliasing detected at nodes %s, %s",
region,
output_storages[storage],
out_node,
)
return True
output_storages[storage] = out_node
intersected_storages = input_storages.keys() & output_storages.keys()
if len(intersected_storages) > 0:
# input-output aliasing
aliased = [
(input_storages[s], output_storages[s]) for s in intersected_storages
]
aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
log.debug(
"NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s",
region,
aliased,
)
return True
return False
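
The ordering rule in _add_mutation_dependencies, restated in isolation: for every argument a node mutates, the mutating node gains a dependency on each earlier reader of that argument, and each later reader gains a dependency on the mutating node, pinning the mutation between them. A framework-free toy version (all names hypothetical):

from collections import defaultdict

def add_mutation_deps(order, mutates, readers):
    # order: nodes in topological order; mutates: node -> args it mutates;
    # readers: arg -> all nodes that use it. Returns node -> must-run-before set.
    pos = {n: i for i, n in enumerate(order)}
    deps = defaultdict(set)
    for node, args in mutates.items():
        for arg in args:
            for user in readers[arg]:
                if user == node:
                    continue
                if pos[user] < pos[node]:
                    deps[node].add(user)    # mutation stays after earlier readers
                else:
                    deps[user].add(node)    # later readers stay after the mutation
    return dict(deps)

print(add_mutation_deps(
    ["read1", "mut", "read2"],
    {"mut": ["x"]},
    {"x": ["read1", "mut", "read2"]},
))  # {'mut': {'read1'}, 'read2': {'mut'}}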

View File: torch/_dynamo/graph_region_tracker.py

@@ -25,9 +25,10 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
import torch._logging
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_flatten
from .graph_utils import _flatten_args_kwargs
from .graph_utils import _get_flat_args_unique
T = TypeVar("T")
@@ -150,6 +151,9 @@ class BackwardBfsArgIter:
def create(origin: Node) -> "BackwardBfsArgIter":
it = BackwardBfsArgIter(origin)
it.add_children(origin)
# pop the origin node, since it is the origin of
# the region and does not need to be considered for addition
assert it.next()
return it
def next(self) -> Optional[Node]:
@@ -164,17 +168,11 @@
return self._cur
def add_children(self, node: Node) -> None:
arg: Any
flat_args, _ = tree_flatten(node.args)
flat_args = _get_flat_args_unique(node, {})
for arg in flat_args:
if isinstance(arg, Node):
self._append(arg)
flat_kwargs, _ = tree_flatten(node.kwargs)
for kwarg in flat_kwargs:
if isinstance(kwarg, Node):
self._append(kwarg)
def _append(self, arg: Node) -> None:
if self._cur is None:
self._cur = arg
@@ -199,6 +197,8 @@ class GraphRegionTracker:
def __init__(self) -> None:
self.hash_to_duplicates: dict[str, IdenticalNodes] = defaultdict(list)
self.node_to_duplicates: dict[Node, IdenticalNodes] = {}
# Note: position is in flattened args/kwargs list
self.node_to_mutated_arg_positions: dict[Node, OrderedSet[int]] = {}
self.input_pickler = InputPickler()
def _hash_node(
@@ -240,6 +240,28 @@
except NodeHashException as e:
log.debug("Unable to hash node %s with exception %s", node, e)
def track_node_mutations(
self,
node: Node,
flat_args_kwargs: list[Any],
id_to_initial_version: dict[int, int],
) -> None:
"""
This function tracks which argument positions are mutated by the given node. Subgraph HOP does not support
input mutations today, so we skip regions whose inputs are mutated.
"""
mutated_arg_positions = OrderedSet[int]()
for i, arg in enumerate(flat_args_kwargs):
val_id = id(arg)
if (
val_id in id_to_initial_version
and id_to_initial_version[val_id] != arg._version
):
mutated_arg_positions.add(i)
if mutated_arg_positions:
self.node_to_mutated_arg_positions[node] = mutated_arg_positions
def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]:
"""
This function is responsible for extracting the largest regions of identical nodes from the given graph.
@@ -303,6 +325,38 @@
return f"GraphRegionTracker(hash_to_duplicates={self.hash_to_duplicates}, node_to_duplicates={self.node_to_duplicates})"
class RegionWrapper:
"""Holds state for regions e.g. ancestors and new candidate nodes for consideration"""
def __init__(
self, region: Region, node_to_recursive_ancestors: dict[Node, set[Node]]
) -> None:
assert len(region) == 1, "all regions should start with one node"
node = region[0]
self.node_to_recursive_ancestors = node_to_recursive_ancestors
self.iter = BackwardBfsArgIter.create(node)
self.nodes_unique = OrderedSet([node])
self.ancestors = set(node_to_recursive_ancestors[node])
self.region = region
def next_candidate(self) -> Optional[Node]:
return self.iter.next()
def will_inclusion_create_cycle(self, node: Node) -> bool:
external_users = [user for user in node.users if user not in self.nodes_unique]
for user in external_users:
if user in self.ancestors:
return True
return False
def add(self, node: Node) -> None:
self.nodes_unique.add(node)
self.region.append(node)
self.iter.add_children(node)
self.ancestors.update(self.node_to_recursive_ancestors[node])
def fully_expand_region_group(
regions: list[Region],
seen_nodes: set[Node],
@@ -314,20 +368,12 @@ def fully_expand_region_group(
# All regions should start with 1 node
assert all(len(region) == 1 for region in regions)
region_iters = []
for region in regions:
(origin,) = region # Only works for 1 element sets
region_iters.append(BackwardBfsArgIter.create(origin))
region_wrappers = [
RegionWrapper(region, node_to_recursive_ancestors) for region in regions
]
nodes_to_add: list[Node] = []
# we already have the origin node in each region
for region_it in region_iters:
node = region_it.next()
assert node
region_it.add_children(node)
current_node = region_iters[0].next()
nodes_to_add = OrderedSet[Node]()
current_node = region_wrappers[0].next_candidate()
# No children
if current_node is None:
@@ -337,46 +383,51 @@
# regions are only expanded if the node to add is valid
# for ALL regions
while current_node:
add_node = not _will_create_cycle(
current_node, regions[0], node_to_recursive_ancestors
add_to_all_regions = not region_wrappers[0].will_inclusion_create_cycle(
current_node
)
nodes_to_add.clear()
nodes_to_add.append(current_node)
nodes_to_add_set = set(nodes_to_add)
for ind, region_it in enumerate(region_iters[1:]):
ind += 1 # compensate for the 0th region
node = region_it.next()
nodes_to_add.add(current_node)
for region_wrapper in region_wrappers[1:]:
candidate = region_wrapper.next_candidate()
debug_log("--------------------")
debug_log("considering adding: %s, cur_node: %s", node, current_node)
debug_log("previously claimed nodes: %s", node in seen_nodes)
if node:
debug_log("is_identical: %s", is_identical_fn(node, current_node))
add_node &= (
node not in seen_nodes
and node not in nodes_to_add_set
and node.op != "placeholder"
and is_identical_fn(node, current_node)
and not _will_create_cycle(
node, regions[ind], node_to_recursive_ancestors
)
)
nodes_to_add.append(node)
nodes_to_add_set.add(node)
else:
add_node = False
debug_log(
"considering candidate: %s, cur_node: %s", candidate, current_node
)
if not candidate or not add_to_all_regions:
add_to_all_regions = False
continue
debug_log(
"candidate in previously claimed nodes?: %s", candidate in seen_nodes
)
debug_log("is_identical: %s", is_identical_fn(candidate, current_node))
add_to_all_regions &= (
candidate not in seen_nodes
and candidate not in nodes_to_add
and candidate.op != "placeholder"
and is_identical_fn(candidate, current_node)
and not region_wrapper.will_inclusion_create_cycle(candidate)
)
nodes_to_add.add(candidate)
debug_log(f"add_to_all_regions: {add_to_all_regions}")
debug_log("--------------------")
if add_node:
for region, region_it, node in zip(regions, region_iters, nodes_to_add):
region.append(node)
if add_to_all_regions:
assert len(region_wrappers) == len(nodes_to_add), (
"Numer of nodes to add must equal the number of regions"
)
for region_wrapper, node in zip(region_wrappers, nodes_to_add):
region_wrapper.add(node)
debug_log("adding %s's children", node)
debug_log("%s %s", node.args, list(node.kwargs.items()))
region_it.add_children(node)
seen_nodes.add(node)
current_node = region_iters[0].next()
current_node = region_wrappers[0].next_candidate()
# Ensure regions are sorted in topological order
for region in regions:
@@ -391,7 +442,7 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[No
for node in graph.nodes:
node_to_recursive_ancestors[node] = set()
for node in graph.nodes:
all_args = _flatten_args_kwargs((node.args, node.kwargs))
all_args = _get_flat_args_unique(node, {})
for arg in all_args:
if isinstance(arg, Node):
node_to_recursive_ancestors[node].update(
@@ -399,20 +450,3 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[No
)
node_to_recursive_ancestors[node].add(arg)
return node_to_recursive_ancestors
def _will_create_cycle(
node_to_add: Node,
region: Region,
node_to_recursive_ancestors: dict[Node, set[Node]],
) -> bool:
region_set: set[Node] = set(region)
region_ancestors: set[Node] = set(
tree_flatten([list(node_to_recursive_ancestors[node]) for node in region])[0]
)
external_users = [user for user in node_to_add.users if user not in region_set]
for user in external_users:
if user in region_ancestors:
return True
return False
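
The check RegionWrapper.will_inclusion_create_cycle carries over from the removed _will_create_cycle is the same rule made incremental: the wrapper maintains a running ancestor set, so each test is a membership lookup instead of re-flattening every member's ancestors. The rule itself, as a toy sketch with hypothetical names:

def will_inclusion_create_cycle(node_users, region, region_ancestors):
    # Fusing the node into the region is unsafe if any consumer of the node
    # outside the region is already an ancestor of the region: the fused
    # subgraph would then both feed that consumer and depend on it.
    return any(u not in region and u in region_ancestors for u in node_users)

region = {"b", "c"}
ancestors = {"a"}                                              # recursive ancestors of b, c
print(will_inclusion_create_cycle({"a"}, region, ancestors))   # True: closes a cycle
print(will_inclusion_create_cycle({"d"}, region, ancestors))   # False: safe to fuse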

View File: torch/_dynamo/graph_utils.py

@@ -1,8 +1,8 @@
from collections import deque
from typing import Any
from torch.fx import Graph, Node
from torch.utils._pytree import tree_flatten
from torch.fx import Graph, map_arg, Node
from torch.utils._ordered_set import OrderedSet
# flattens with support for slices
@@ -10,26 +10,29 @@ from torch.utils._pytree import tree_flatten
# be register/unregister slices as pytree nodes
# but there is no unregister API in the pytorch
# pytree impl
def _flatten_args_kwargs(args: Any) -> list[Node]:
fully_flattened = []
def flatten(args: Any) -> None:
flattened, _ = tree_flatten(args)
for arg in flattened:
if isinstance(arg, slice):
start = arg.start
stop = arg.stop
step = arg.step
flatten((start, stop, step))
else:
fully_flattened.append(arg)
flatten(args)
return fully_flattened
def _get_flat_args(
node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> list[Node]:
args = list[Any]()
map_arg((node.args, node.kwargs), args.append)
if node in node_to_additional_deps:
args.extend(node_to_additional_deps[node])
return args
def _detect_cycles(graph: Graph) -> str:
def _get_flat_args_unique(
node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> OrderedSet[Node]:
args = OrderedSet[Node]()
map_arg((node.args, node.kwargs), args.add)
if node in node_to_additional_deps:
args.update(node_to_additional_deps[node])
return args
def _detect_cycles(
graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
) -> str:
current_path: deque[Node] = deque()
current_path_set: set[Node] = set()
pending: deque[tuple[Node, Node]] = deque()
@@ -45,25 +48,30 @@ def _detect_cycles(graph: Graph) -> str:
def current_path_head() -> Node:
return current_path[-1]
for origin in graph.find_nodes(op="placeholder"):
for origin in graph.find_nodes(op="output"):
current_path.clear()
current_path_set.clear()
add_to_current_path(origin)
for child in origin.users:
for child in _get_flat_args_unique(origin, node_to_additional_deps):
pending.append((child, origin))
while pending:
cur_node, parent = pending.pop()
while current_path_head() != parent:
# handle backtracking
while current_path and current_path_head() != parent:
pop_current_path()
if not isinstance(cur_node, Node):
continue
if cur_node in current_path_set:
current_path.append(cur_node)
return f"cycle detected in path: {current_path}"
add_to_current_path(cur_node)
for child in cur_node.users:
for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
pending.append((child, cur_node))
return "no cycle detected"

View File: torch/_dynamo/utils.py

@@ -92,6 +92,8 @@ from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._triton import has_triton, has_triton_package
from torch.utils.hooks import RemovableHandle
from .graph_utils import _get_flat_args
if typing.TYPE_CHECKING:
from collections.abc import (
@@ -3151,6 +3153,20 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
tx, (node.args, node.kwargs), allow_non_graph_fake
)
if (
torch._dynamo.config.use_graph_deduplication
or torch._dynamo.config.track_nodes_for_deduplication
):
flat_args_kwargs = get_fake_values_from_nodes(
tx, _get_flat_args(node, {}), allow_non_graph_fake
)
id_to_initial_version = {
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
}
else:
flat_args_kwargs = []
id_to_initial_version = {}
nnmodule = None
if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
# If the first argument is nn.Module, should copy to fake mode.
@@ -3290,6 +3306,17 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
_ = pytree.tree_map_only(
torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
)
if (
torch._dynamo.config.use_graph_deduplication
or torch._dynamo.config.track_nodes_for_deduplication
):
tx.output.region_tracker.track_node_mutations(
node,
flat_args_kwargs,
id_to_initial_version,
)
return ret_val
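
The mechanic behind id_to_initial_version, in standalone form: every in-place op bumps a tensor's _version counter, so snapshotting versions before a node runs and re-checking afterwards reveals exactly which flattened argument positions were mutated. A minimal sketch (the helper name is hypothetical):

import torch

def mutated_positions(fn, *args):
    # Snapshot each tensor argument's version counter before the call ...
    before = {id(a): a._version for a in args if isinstance(a, torch.Tensor)}
    fn(*args)
    # ... and report the positions whose counter moved.
    return {
        i
        for i, a in enumerate(args)
        if isinstance(a, torch.Tensor) and a._version != before[id(a)]
    }

x, y = torch.rand(3), torch.rand(3)
print(mutated_positions(lambda a, b: a.add_(b), x, y))   # {0}: only `a` was mutated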