Compare commits

...

4 Commits

SHA1        Message                                                        Date
da300559c6  [Hierarchical compile] Ensure output nodes are sorted last     2025-04-14 22:51:45 -07:00
            ghstack-source-id: 0259a8bf2855f4d8983d4e937dfec94ec8cbef70
            Pull Request resolved: https://github.com/pytorch/pytorch/pull/151295
2a854bfd9c  [Hierarchical Compile] Handle autocast ctx manager             2025-04-14 13:54:19 -07:00
2ae6a16975  [Hierarchical Compile] Fix small bug                           2025-04-14 13:54:19 -07:00
2d92df58c9  Disable optimizer and enable graph deduplication               2025-04-14 13:54:19 -07:00
5 changed files with 371 additions and 6 deletions

View File

@@ -1729,6 +1729,7 @@ class BenchmarkRunner:
self.optimizer = torch.optim.Adam(
params, lr=0.01, capturable=True, foreach=True
)
self.optimizer.step = torch._dynamo.disable(self.optimizer.step)
else:
self.optimizer = None
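
Note on the hunk above: torch._dynamo.disable(fn) wraps a callable so that Dynamo skips it during tracing, which is why the optimizer step now runs eagerly under torch.compile. A minimal, standalone sketch of the same pattern (the model and optimizer here are illustrative, not taken from the benchmark runner):

import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Wrap the bound method so compiled code calls the original, untraced step().
optimizer.step = torch._dynamo.disable(optimizer.step)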

View File

@@ -660,6 +660,259 @@ class <lambda>(torch.nn.Module):
"""cycle detected in path: deque([arg0_1, invoke_subgraph_1, getitem_1, sum_2, add_2, invoke_subgraph, getitem, sum_1, add_1, add_2])""",
)
def test_autocast_ordering(self):
from torch._dynamo.graph_deduplication import (
_populate_additional_deps,
_stable_topological_sort,
)
def inner_fn(x, y):
x0 = x.view(x.size())
return x0.view(x.size())
def inner_fn2(x, y):
x = x * 2
y = y * 2
return x.sum() + y.sum()
def fn(x, y):
o0 = inner_fn(x, y)
o1 = inner_fn(x, y)
o2 = inner_fn2(x, y)
o3 = inner_fn2(x, y)
return o0 + o1 + o2.sum() + o3.sum()
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
x_clone = x.clone()
y_clone = y.clone()
_, _, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
mod = fw_graphs[0]
def get_node(name):
return next(n for n in mod.graph.nodes if n.name == name)
sum_1 = get_node("sum_1")
enter_autocast = mod.graph.call_function(torch.amp._enter_autocast)
sum_1.append(enter_autocast)
sum_2 = get_node("sum_2")
exit_autocast = mod.graph.call_function(torch.amp._exit_autocast)
sum_2.append(exit_autocast)
additional_deps = _populate_additional_deps(mod.graph)
invoke_subgraph = get_node("invoke_subgraph")
invoke_subgraph.append(enter_autocast)
getitem_1 = get_node("getitem_1")
getitem_1.append(exit_autocast)
self.assertExpectedInline(
graph_str(mod),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
_enter_autocast = torch.amp.autocast_mode._enter_autocast(); _enter_autocast = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
_exit_autocast = torch.amp.autocast_mode._exit_autocast(); _exit_autocast = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
return (add_2,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
return (add,)
""",
)
_stable_topological_sort(mod.graph, additional_deps)
self.assertExpectedInline(
graph_str(mod),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
_enter_autocast = torch.amp.autocast_mode._enter_autocast(); _enter_autocast = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
_exit_autocast = torch.amp.autocast_mode._exit_autocast(); _exit_autocast = None
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
return (add_2,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
return (add,)
""",
)
def test_output_nodes_last(self):
from torch._dynamo.graph_deduplication import _stable_topological_sort
def inner_fn(x, y):
x0 = x.view(x.size())
return x0.view(x.size())
def inner_fn2(x, y):
x = x * 2
y = y * 2
return x.sum() + y.sum()
def fn(x, y):
o0 = inner_fn(x, y)
o1 = inner_fn(x, y)
o2 = inner_fn2(x, y)
o3 = inner_fn2(x, y)
return o0 + o1 + o2.sum() + o3.sum()
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
x_clone = x.clone()
y_clone = y.clone()
_, _, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
mod = fw_graphs[0]
output = next(n for n in mod.graph.nodes if n.op == "output")
add_2 = next(n for n in mod.graph.nodes if n.name == "sum_2")
add_2.append(output)
self.assertExpectedInline(
graph_str(mod),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
return (add_2,)
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
return (add,)
""",
)
_stable_topological_sort(mod.graph, {})
self.assertExpectedInline(
graph_str(mod),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
return (add_2,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
return (add,)
""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@@ -419,7 +419,7 @@ enable_cpp_framelocals_guard_eval = True
# Whether to automatically find and replace identical graph
# regions with a call to invoke_subgraph
use_graph_deduplication = False
use_graph_deduplication = True
# Whether to track nodes for deduplication (testing only)
# This flag is ignored if use_graph_deduplication is True
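
The flag lives in torch._dynamo.config, so it can also be scoped with config.patch instead of flipping the module-level default. A hedged sketch, assuming the branch in this PR where use_graph_deduplication exists:

import torch
import torch._dynamo.config as dynamo_config

# Enable deduplication for a single compilation rather than globally.
with dynamo_config.patch(use_graph_deduplication=True):
    compiled = torch.compile(lambda x: x.sin() + x.sin())
    compiled(torch.randn(4))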

View File

@@ -9,13 +9,15 @@ structures across different parts of the network.
import logging
import operator
from collections.abc import Iterable
from typing import Any
from collections import defaultdict
from collections.abc import Generator, Iterable
from typing import Any, Optional
import torch
import torch.fx
from torch._dynamo import config
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
from torch.utils._ordered_set import OrderedSet
from .graph_region_tracker import Node, Region
from .graph_utils import _detect_cycles, _flatten_args_kwargs
@@ -51,11 +53,11 @@ The deduplication mutates the output_graph argument in place.
Returns a mapping of nodes to their subgraph output replacement node to remap outputs
when they are created in output_graph.
"""
from torch._inductor.pattern_matcher import stable_topological_sort
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
output_graph.graph
)
node_to_additional_deps = _populate_additional_deps(output_graph.graph)
sub_gms: dict[str, torch.fx.GraphModule] = {}
@@ -87,9 +89,10 @@ when they are created in output_graph.
inds_with_external_users,
sub_gm,
subgraph_name,
node_to_additional_deps,
)
stable_topological_sort(output_graph.graph)
_stable_topological_sort(output_graph.graph, node_to_additional_deps)
return sub_gms
@@ -101,6 +104,7 @@ def _replace_region_with_subgraph(
inds_with_external_users: list[int],
sub_gm: torch.fx.GraphModule,
subgraph_name: str,
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]],
) -> None:
sub_args = []
for node_ind, arg_ind in node_ind_arg_ind:
@@ -133,6 +137,12 @@ def _replace_region_with_subgraph(
# Erase in reverse topological order
for node in reversed(region):
graph.erase_node(node)
node_to_additional_deps.pop(node)
for dep_list in node_to_additional_deps.values():
try:
dep_list.remove(node)
except ValueError:
pass
if config.graph_deduplication_lint:
_detect_cycles(graph)
@@ -222,3 +232,104 @@ def _create_subgraph(
node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
_create_subgraph_outputs(subgraph, inds_with_external_users)
return subgraph, node_ind_input_inds
def _args(
n: torch.fx.Node,
node_to_additional_deps: Optional[dict[torch.fx.Node, list[torch.fx.Node]]] = None,
) -> list[torch.fx.node.Argument]:
if node_to_additional_deps is None:
node_to_additional_deps = {}
args: list[torch.fx.node.Argument] = []
torch.fx.map_arg((n.args, n.kwargs), args.append)
if n in node_to_additional_deps:
args.extend(node_to_additional_deps[n])
return args
def _stable_topological_sort(
graph: torch.fx.Graph,
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]],
) -> None:
# Nodes are in exactly one of these four collections:
# - Nodes in `pending` are waiting to be processed (in reverse order):
pending = list(reversed(graph.nodes))
# - Nodes in `ready` have been processed and are already in the correct
# order.
ready = OrderedSet[torch.fx.Node]()
# - `waiting` is a mapping from a dependency to nodes which depend on that
# dependency.
waiting = defaultdict(list)
# - `outputs` are always at the end of the graph
outputs = OrderedSet[torch.fx.Node]()
# The cursor indicates the last processed node so we can add new nodes
# after it.
cursor = None
while pending:
node = pending.pop()
if node.target == "output":
outputs.add(node)
assert not node.users, "output nodes should have no users"
continue
waiting_for = [
x for x in _args(node, node_to_additional_deps) if x not in ready
]
if waiting_for:
# We have unprocessed input nodes. Might as well wait for the last
# arg so an already sorted list will only recheck this node once.
waiting[waiting_for[-1]].append(node)
else:
ready.add(node)
if cursor and cursor.next is not node:
cursor.append(node)
cursor = node
# Mark the nodes that have been waiting for this node to finish as
# ready to check again.
pending.extend(reversed(waiting.pop(node, ())))
ready.update(outputs)
assert not waiting and len(ready) == len(graph.nodes)
def _populate_additional_deps(
graph: torch.fx.Graph,
) -> dict[torch.fx.Node, list[torch.fx.Node]]:
import torch.amp
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(
list
)
all_nodes = list(graph.nodes)
# These are targets of the nodes which need to stay in the same relative place in the graph
global_state_targets = {torch.amp._enter_autocast, torch.amp._exit_autocast}
all_nodes_dep_on: list[torch.fx.Node] = []
def prev_cur_nodes(
all_nodes: list[torch.fx.Node],
) -> Generator[tuple[list[torch.fx.Node], torch.fx.Node]]:
prev_nodes: list[torch.fx.Node] = []
next_nodes = list(reversed(all_nodes))
while next_nodes:
cur_node = next_nodes.pop()
yield prev_nodes, cur_node
prev_nodes.append(cur_node)
for prev_nodes, cur_node in prev_cur_nodes(all_nodes):
args_unique = _args(cur_node)
additional_deps = node_to_additional_deps[cur_node]
additional_deps.extend(n for n in all_nodes_dep_on if n not in args_unique)
if cur_node.target in global_state_targets:
additional_deps.extend(n for n in prev_nodes if n not in args_unique)
all_nodes_dep_on.append(cur_node)
return node_to_additional_deps
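
Taken together, _populate_additional_deps records ordering constraints that plain data dependencies miss (every node keeps its position relative to _enter_autocast/_exit_autocast), and _stable_topological_sort re-sorts the graph while honoring those extra edges and leaving output nodes at the end. A minimal usage sketch on a traced module (the toy function is illustrative; the helper imports mirror the test above):

import torch
import torch.fx
from torch._dynamo.graph_deduplication import (
    _populate_additional_deps,
    _stable_topological_sort,
)

def f(x):
    return x.sin() + x.cos()

gm = torch.fx.symbolic_trace(f)
# Record the extra ordering edges before any pass mutates node order ...
deps = _populate_additional_deps(gm.graph)
# ... then sort: data deps plus the recorded deps decide placement, outputs go last.
_stable_topological_sort(gm.graph, deps)
gm.recompile()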

View File

@@ -397,7 +397,7 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[No
node_to_recursive_ancestors[node].update(
node_to_recursive_ancestors[arg]
)
node_to_recursive_ancestors[node].add(node)
node_to_recursive_ancestors[node].add(arg)
return node_to_recursive_ancestors
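
The final hunk is the "Fix small bug" commit: while building the recursive ancestor map, the loop added the node itself to its own ancestor set instead of the input being walked. A self-contained sketch of the intended accumulation (the function body is reconstructed around the two lines shown, so treat anything outside them as an assumption):

from collections import defaultdict

import torch.fx

def populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[torch.fx.Node, set[torch.fx.Node]]:
    # graph.nodes is topologically ordered, so each input's ancestor set is
    # complete by the time its consumers are visited.
    ancestors: dict[torch.fx.Node, set[torch.fx.Node]] = defaultdict(set)
    for node in graph.nodes:
        for arg in node.all_input_nodes:
            ancestors[node].update(ancestors[arg])
            ancestors[node].add(arg)  # the fix: record the input, not the node itself
    return ancestors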