mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 13:34:57 +08:00
Compare commits
7 Commits
ciflow/tru
...
mlazos/hc-
| Author | SHA1 | Date | |
|---|---|---|---|
| f3949c0bc1 | |||
| b50e411346 | |||
| 0cc273cf0c | |||
| b347aaa731 | |||
| 4161b0fd72 | |||
| ffe1faa370 | |||
| 8ba70748dc |
@ -1,8 +1,8 @@
|
||||
add_loop_eager,compile_time_instruction_count,3167000000,0.015
|
||||
add_loop_eager,compile_time_instruction_count,3035000000,0.015
|
||||
|
||||
|
||||
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,6066000000,0.025
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,5928000000,0.025
|
||||
|
||||
|
||||
|
||||
@ -14,11 +14,11 @@ add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44480000000,0.025
|
||||
|
||||
|
||||
|
||||
add_loop_inductor_gpu,compile_time_instruction_count,26050000000,0.015
|
||||
add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1018000000,0.015
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,1011000000,0.015
|
||||
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ basic_modules_ListOfLinears_inductor_gpu,compile_time_instruction_count,10370000
|
||||
|
||||
|
||||
|
||||
update_hint_regression,compile_time_instruction_count,1723000000,0.02
|
||||
update_hint_regression,compile_time_instruction_count,1715000000,0.02
|
||||
|
||||
|
||||
|
||||
@ -42,15 +42,15 @@ float_args,compile_time_instruction_count,439200000,0.015
|
||||
|
||||
|
||||
|
||||
sum_floordiv_regression,compile_time_instruction_count,1024000000,0.015
|
||||
sum_floordiv_regression,compile_time_instruction_count,1009000000,0.015
|
||||
|
||||
|
||||
|
||||
symint_sum,compile_time_instruction_count,3278000000,0.015
|
||||
symint_sum,compile_time_instruction_count,3252000000,0.015
|
||||
|
||||
|
||||
|
||||
symint_sum_loop,compile_time_instruction_count,4300000000,0.015
|
||||
symint_sum_loop,compile_time_instruction_count,4262000000,0.015
|
||||
|
||||
|
||||
|
||||
@ -58,11 +58,11 @@ aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2091000000
|
||||
|
||||
|
||||
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5944000000,0.015
|
||||
aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5981000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8586000000,0.015
|
||||
aotdispatcher_partitioner_cpu,compile_time_instruction_count,8630000000,0.015
|
||||
|
||||
|
||||
|
||||
@ -70,8 +70,8 @@ aotdispatcher_partitioner_cpu2,compile_time_instruction_count,1900000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3795000000,0.015
|
||||
aotdispatcher_training_nosubclass_cpu,compile_time_instruction_count,3818000000,0.015
|
||||
|
||||
|
||||
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10280000000,0.015
|
||||
aotdispatcher_training_subclass_cpu,compile_time_instruction_count,10350000000,0.015
|
||||
|
||||
|
@ -1,11 +1,17 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
# flake8: noqa: B950
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch._dynamo.graph_deduplication import _flatten_args_kwargs
|
||||
from torch._dynamo.graph_utils import _detect_cycles
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
|
||||
from torch._dynamo.testing import (
|
||||
AotEagerAndRecordGraphs,
|
||||
extract_graph_and_tracker,
|
||||
normalize_gm,
|
||||
)
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
def extract_graph(fn, *args, **kwargs):
|
||||
@ -19,9 +25,32 @@ def graph_str(gm):
|
||||
|
||||
|
||||
class GraphDededuplicationTests(TestCase):
|
||||
def setUp(self):
|
||||
self.exit_stack = contextlib.ExitStack()
|
||||
self.exit_stack.enter_context(
|
||||
torch._dynamo.config.patch("use_graph_deduplication", True)
|
||||
)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
self.exit_stack.close()
|
||||
super().tearDown()
|
||||
|
||||
def run_and_return_graphs(self, fn, *args, **kwargs):
|
||||
with torch._dynamo.config.patch("use_graph_deduplication", True):
|
||||
return extract_graph(fn, *args, **kwargs)
|
||||
return extract_graph(fn, *args, **kwargs)
|
||||
|
||||
def run_and_get_simple_graph(self):
|
||||
def fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=False)
|
||||
y = torch.rand(10, 20, requires_grad=False)
|
||||
|
||||
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
|
||||
return fw_graphs[0]
|
||||
|
||||
def test_single_subgraph(self):
|
||||
def inner_fn(x, y):
|
||||
@ -432,12 +461,6 @@ class GraphModule(torch.nn.Module):
|
||||
)
|
||||
|
||||
def test_input_mutation(self):
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
def inner_fn2(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 1
|
||||
@ -447,9 +470,6 @@ class GraphModule(torch.nn.Module):
|
||||
|
||||
def fn(x, y):
|
||||
x0 = torch.sin(x)
|
||||
_y0 = torch.cos(y)
|
||||
# o0 = inner_fn(x0, y0)
|
||||
# o1 = inner_fn(x0, o0)
|
||||
o2 = inner_fn2(x0, y)
|
||||
o3 = inner_fn2(x0.clone(), y.clone())
|
||||
return o2 + o3
|
||||
@ -583,28 +603,13 @@ class <lambda>(torch.nn.Module):
|
||||
""",
|
||||
)
|
||||
|
||||
def test_flatten_with_slices(self):
|
||||
tree = [{"x": 3}, ["x", slice(1, 2, 3), 1], [4, 5, 6, [slice(3, 4, 5)]]]
|
||||
out = _flatten_args_kwargs(tree)
|
||||
def test_cycle_detection_no_cycle(self):
|
||||
mod = self.run_and_get_simple_graph()
|
||||
self.assertExpectedInline(
|
||||
str(out), """[3, 'x', 1, 2, 3, 1, 4, 5, 6, 3, 4, 5]"""
|
||||
_detect_cycles(mod.graph, {}), """no cycle detected"""
|
||||
)
|
||||
|
||||
def test_cycle_detection_no_cycle(self):
|
||||
def fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=False)
|
||||
y = torch.rand(10, 20, requires_grad=False)
|
||||
|
||||
_, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
|
||||
mod = fw_graphs[0]
|
||||
self.assertExpectedInline(_detect_cycles(mod.graph), """no cycle detected""")
|
||||
|
||||
def test_cycle_detection_simple(self):
|
||||
def test_cycle_detection_single_node(self):
|
||||
def fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
@ -621,8 +626,64 @@ class <lambda>(torch.nn.Module):
|
||||
args = add_node.args
|
||||
add_node.args = (args[0], add_2)
|
||||
self.assertExpectedInline(
|
||||
_detect_cycles(mod.graph),
|
||||
"""cycle detected in path: deque([arg0_1, add, sum_1, add_2, add])""",
|
||||
_detect_cycles(mod.graph, {add_2: OrderedSet([add_2])}),
|
||||
"""cycle detected in path: deque([output, add_2, add_2])""",
|
||||
)
|
||||
|
||||
def test_cycle_detection_two_node(self):
    """A cycle made of two nodes linked only through additional deps is reported.

    ``add`` is rewired to consume ``add_2`` as its second argument, and the
    additional-dependency mapping makes ``add`` and ``add_2`` depend on each
    other, so the detector must report the ``add_2 -> add -> add_2`` loop.
    """
    # Shared helper builds the same simple forward graph the other
    # cycle-detection tests start from.
    mod = self.run_and_get_simple_graph()
    add_node = next(n for n in mod.graph.nodes if n.name == "add")
    add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
    args = add_node.args
    add_node.args = (args[0], add_2)
    self.assertExpectedInline(
        _detect_cycles(
            mod.graph,
            {add_2: OrderedSet([add_node]), add_node: OrderedSet([add_2])},
        ),
        """cycle detected in path: deque([output, add_2, add, add_2])""",
    )
|
||||
|
||||
def test_cycle_detection_arg_and_additional_deps(self):
    """A cycle formed by one real argument edge plus one additional dep is reported.

    ``add`` is rewired to consume ``add_2`` as its second argument (a real
    graph edge), while the additional-dependency mapping makes ``add_2``
    depend on ``add``; together the two edges close a loop.
    """
    # Shared helper builds the same simple forward graph the other
    # cycle-detection tests start from.
    mod = self.run_and_get_simple_graph()
    add_node = next(n for n in mod.graph.nodes if n.name == "add")
    add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
    args = add_node.args
    add_node.args = (args[0], add_2)
    self.assertExpectedInline(
        _detect_cycles(mod.graph, {add_2: OrderedSet([add_node])}),
        """cycle detected in path: deque([output, add_2, add, add_2])""",
    )
|
||||
|
||||
def test_cycle_detection_simple(self):
|
||||
mod = self.run_and_get_simple_graph()
|
||||
add_node = next(n for n in mod.graph.nodes if n.name == "add")
|
||||
add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
|
||||
args = add_node.args
|
||||
add_node.args = (args[0], add_2)
|
||||
self.assertExpectedInline(
|
||||
_detect_cycles(mod.graph, {}),
|
||||
"""cycle detected in path: deque([output, add_2, sum_1, add, add_2])""",
|
||||
)
|
||||
|
||||
def test_cycle_detection_complex(self):
|
||||
@ -656,8 +717,8 @@ class <lambda>(torch.nn.Module):
|
||||
args = invoke_subgraph_node.args
|
||||
invoke_subgraph_node.args = (add_2, args[1])
|
||||
self.assertExpectedInline(
|
||||
_detect_cycles(mod.graph),
|
||||
"""cycle detected in path: deque([arg0_1, invoke_subgraph_1, getitem_1, sum_2, add_2, invoke_subgraph, getitem, sum_1, add_1, add_2])""",
|
||||
_detect_cycles(mod.graph, {}),
|
||||
"""cycle detected in path: deque([output, add_2, add_1, sum_1, getitem, invoke_subgraph, add_2])""",
|
||||
)
|
||||
|
||||
def test_autocast_ordering(self):
|
||||
@ -699,7 +760,7 @@ class <lambda>(torch.nn.Module):
|
||||
sum_2 = get_node("sum_2")
|
||||
exit_autocast = mod.graph.call_function(torch.amp._exit_autocast)
|
||||
sum_2.append(exit_autocast)
|
||||
additional_deps = _populate_additional_deps(mod.graph)
|
||||
additional_deps = _populate_additional_deps(mod.graph, {})
|
||||
invoke_subgraph = get_node("invoke_subgraph")
|
||||
invoke_subgraph.append(enter_autocast)
|
||||
getitem_1 = get_node("getitem_1")
|
||||
@ -914,6 +975,137 @@ class <lambda>(torch.nn.Module):
|
||||
""",
|
||||
)
|
||||
|
||||
def test_mutation_ordering(self):
|
||||
from torch._dynamo.graph_deduplication import _stable_topological_sort
|
||||
|
||||
def inner_fn(x, y):
|
||||
x0 = x.view(x.size())
|
||||
return x0.view(x.size())
|
||||
|
||||
def inner_fn2(x, y):
|
||||
x = x * 2
|
||||
y = y * 2
|
||||
return x.sum() + y.sum()
|
||||
|
||||
def fn(x, y):
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = inner_fn(x, y)
|
||||
x.add_(x)
|
||||
o2 = inner_fn2(x, y)
|
||||
y.mul_(y)
|
||||
o3 = inner_fn2(x, y)
|
||||
return o0 + o1 + o2.sum() + o3.sum()
|
||||
|
||||
x = torch.rand(10, 10)
|
||||
y = torch.rand(10, 20)
|
||||
x_clone = x.clone()
|
||||
y_clone = y.clone()
|
||||
|
||||
graph, _ = extract_graph_and_tracker(fn, x_clone, y_clone)
|
||||
|
||||
def graph_code(graph):
|
||||
return graph.python_code("self").src
|
||||
|
||||
def get_node(name):
|
||||
return next(n for n in graph.nodes if n.name == name)
|
||||
|
||||
self.assertExpectedInline(
|
||||
graph_code(graph),
|
||||
"""\
|
||||
|
||||
|
||||
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
subgraph_0 = self.subgraph_0
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
x0 = l_x_.view((10, 10))
|
||||
o0 = x0.view((10, 10)); x0 = None
|
||||
x0_1 = l_x_.view((10, 10))
|
||||
o1 = x0_1.view((10, 10)); x0_1 = None
|
||||
add_ = l_x_.add_(l_x_); add_ = None
|
||||
add_2 = o0 + o1; o0 = o1 = None
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
|
||||
mul_ = l_y_.mul_(l_y_); mul_ = None
|
||||
getitem = invoke_subgraph[0]; invoke_subgraph = None
|
||||
sum_5 = getitem.sum(); getitem = None
|
||||
add_3 = add_2 + sum_5; add_2 = sum_5 = None
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
|
||||
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
sum_6 = getitem_1.sum(); getitem_1 = None
|
||||
add_4 = add_3 + sum_6; add_3 = sum_6 = None
|
||||
return (add_4,)
|
||||
""",
|
||||
)
|
||||
|
||||
# Shuffle nodes in the graph
|
||||
add_ = get_node("add_")
|
||||
mul_ = get_node("mul_")
|
||||
o1 = get_node("o1")
|
||||
o1.append(mul_)
|
||||
add_2 = get_node("add_2")
|
||||
add_2.append(add_)
|
||||
|
||||
self.assertExpectedInline(
|
||||
graph_code(graph),
|
||||
"""\
|
||||
|
||||
|
||||
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
subgraph_0 = self.subgraph_0
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
x0 = l_x_.view((10, 10))
|
||||
o0 = x0.view((10, 10)); x0 = None
|
||||
x0_1 = l_x_.view((10, 10))
|
||||
o1 = x0_1.view((10, 10)); x0_1 = None
|
||||
mul_ = l_y_.mul_(l_y_); mul_ = None
|
||||
add_2 = o0 + o1; o0 = o1 = None
|
||||
add_ = l_x_.add_(l_x_); add_ = None
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
|
||||
getitem = invoke_subgraph[0]; invoke_subgraph = None
|
||||
sum_5 = getitem.sum(); getitem = None
|
||||
add_3 = add_2 + sum_5; add_2 = sum_5 = None
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
|
||||
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
sum_6 = getitem_1.sum(); getitem_1 = None
|
||||
add_4 = add_3 + sum_6; add_3 = sum_6 = None
|
||||
return (add_4,)
|
||||
""",
|
||||
)
|
||||
_stable_topological_sort(
|
||||
graph, torch._dynamo.graph_deduplication.last_node_to_additional_deps
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
graph_code(graph),
|
||||
"""\
|
||||
|
||||
|
||||
|
||||
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
|
||||
subgraph_0 = self.subgraph_0
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
x0 = l_x_.view((10, 10))
|
||||
o0 = x0.view((10, 10)); x0 = None
|
||||
x0_1 = l_x_.view((10, 10))
|
||||
o1 = x0_1.view((10, 10)); x0_1 = None
|
||||
add_2 = o0 + o1; o0 = o1 = None
|
||||
add_ = l_x_.add_(l_x_); add_ = None
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
|
||||
mul_ = l_y_.mul_(l_y_); mul_ = None
|
||||
getitem = invoke_subgraph[0]; invoke_subgraph = None
|
||||
sum_5 = getitem.sum(); getitem = None
|
||||
add_3 = add_2 + sum_5; add_2 = sum_5 = None
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
|
||||
getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
sum_6 = getitem_1.sum(); getitem_1 = None
|
||||
add_4 = add_3 + sum_6; add_3 = sum_6 = None
|
||||
return (add_4,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -49,6 +49,10 @@ class GraphRegionTrackerTests(TestCase):
|
||||
region_groups = tree_map(lambda n: n.name, region_groups)
|
||||
return str(region_groups)
|
||||
|
||||
def get_mutation_tracking(self, fn, *args, **kwargs):
|
||||
_, region_tracker = extract_graph_and_tracker(fn, *args, **kwargs)
|
||||
return str(region_tracker.node_to_mutated_arg_positions)
|
||||
|
||||
def test_get_regions_single_region_group(self):
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
@ -57,12 +61,12 @@ class GraphRegionTrackerTests(TestCase):
|
||||
return z
|
||||
|
||||
def fn(x, y):
|
||||
_o0 = inner_fn(x, y)
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = torch.sin(y)
|
||||
o2 = inner_fn(x, o1)
|
||||
o3 = inner_fn(x, y)
|
||||
o4 = o3 * o3
|
||||
return o2 * o4
|
||||
return o2 * o4 + o0
|
||||
|
||||
self.assertExpectedInline(
|
||||
self.get_result(
|
||||
@ -295,6 +299,45 @@ class GraphRegionTrackerTests(TestCase):
|
||||
[['x1', 'y1', 'sum_1', 'o4'], ['x1_1', 'y1_1', 'sum_2', 'o5']]]""",
|
||||
)
|
||||
|
||||
def test_mutation_tracking_simple(self):
|
||||
def fn(x, y, z):
|
||||
x0 = torch.sin(x)
|
||||
y0 = torch.cos(y)
|
||||
z.sin_()
|
||||
y0.add_(z)
|
||||
return x0.sum() + y0.sum()
|
||||
|
||||
self.assertExpectedInline(
|
||||
self.get_mutation_tracking(
|
||||
fn,
|
||||
torch.rand(10, 10),
|
||||
torch.rand(10, 20),
|
||||
torch.ones(10, 20),
|
||||
),
|
||||
"""{sin_: OrderedSet([0]), add_: OrderedSet([0])}""",
|
||||
)
|
||||
|
||||
def test_mutation_tracking_allow_in_graph(self):
|
||||
@torch._dynamo.allow_in_graph
|
||||
def fn_mut(x, y):
|
||||
x.add_(y)
|
||||
return x.sum() + y.sum()
|
||||
|
||||
def fn(x, y):
|
||||
z = x + y
|
||||
o0 = fn_mut(z, y)
|
||||
z.sin_()
|
||||
return x + o0
|
||||
|
||||
self.assertExpectedInline(
|
||||
self.get_mutation_tracking(
|
||||
fn,
|
||||
torch.rand(20, 10),
|
||||
torch.rand(20, 10),
|
||||
),
|
||||
"""{o0: OrderedSet([0]), sin_: OrderedSet([0])}""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
@ -11,20 +11,28 @@ import logging
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch._dynamo import config
|
||||
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .graph_region_tracker import Node, Region
|
||||
from .graph_utils import _detect_cycles, _flatten_args_kwargs
|
||||
from .graph_utils import _detect_cycles, _get_flat_args, _get_flat_args_unique
|
||||
|
||||
|
||||
# Represents an index into the region
|
||||
# to select a node and then
|
||||
# an index into that node's
|
||||
# flattened arguments
|
||||
UsageIndex = tuple[int, int]
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
last_node_to_additional_deps: Optional[dict[Node, OrderedSet[Node]]] = None
|
||||
|
||||
|
||||
def apply_graph_deduplication(output_graph) -> dict[str, torch.fx.GraphModule]: # type: ignore[no-untyped-def]
|
||||
"""
|
||||
@ -57,7 +65,12 @@ when they are created in output_graph.
|
||||
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
|
||||
output_graph.graph
|
||||
)
|
||||
node_to_additional_deps = _populate_additional_deps(output_graph.graph)
|
||||
node_to_mutated_arg_positions = (
|
||||
output_graph.region_tracker.node_to_mutated_arg_positions
|
||||
)
|
||||
node_to_additional_deps = _populate_additional_deps(
|
||||
output_graph.graph, output_graph.region_tracker.node_to_mutated_arg_positions
|
||||
)
|
||||
|
||||
sub_gms: dict[str, torch.fx.GraphModule] = {}
|
||||
|
||||
@ -66,11 +79,11 @@ when they are created in output_graph.
|
||||
region = region_group[0]
|
||||
(
|
||||
subgraph,
|
||||
node_ind_arg_inds,
|
||||
external_node_usages,
|
||||
) = _create_subgraph(region, inds_with_external_users)
|
||||
|
||||
# Ignore regions with no args for now, could they possibly be evaluated at compile time?
|
||||
if not list(node_ind_arg_inds):
|
||||
if not list(external_node_usages):
|
||||
continue
|
||||
|
||||
sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
|
||||
@ -80,19 +93,27 @@ when they are created in output_graph.
|
||||
get_subgraph_node = output_graph.graph.create_node(
|
||||
"get_attr", subgraph_name, (), {}
|
||||
)
|
||||
|
||||
for region in region_group:
|
||||
_replace_region_with_subgraph(
|
||||
output_graph.graph,
|
||||
region,
|
||||
get_subgraph_node,
|
||||
node_ind_arg_inds.keys(),
|
||||
external_node_usages,
|
||||
inds_with_external_users,
|
||||
sub_gm,
|
||||
subgraph_name,
|
||||
node_to_additional_deps,
|
||||
node_to_mutated_arg_positions,
|
||||
)
|
||||
|
||||
_stable_topological_sort(output_graph.graph, node_to_additional_deps)
|
||||
# This is to expose the updated node_to_additional_deps to tests
|
||||
global last_node_to_additional_deps
|
||||
last_node_to_additional_deps = node_to_additional_deps
|
||||
|
||||
_stable_topological_sort(
|
||||
output_graph.graph,
|
||||
node_to_additional_deps,
|
||||
)
|
||||
return sub_gms
|
||||
|
||||
|
||||
@ -100,29 +121,34 @@ def _replace_region_with_subgraph(
|
||||
graph: torch.fx.Graph,
|
||||
region: Region,
|
||||
get_subgraph_node: Node,
|
||||
node_ind_arg_ind: Iterable[tuple[int, int]],
|
||||
external_node_usages: Iterable[OrderedSet[UsageIndex]],
|
||||
inds_with_external_users: list[int],
|
||||
sub_gm: torch.fx.GraphModule,
|
||||
subgraph_name: str,
|
||||
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]],
|
||||
node_to_additional_deps: dict[Node, OrderedSet[Node]],
|
||||
node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
|
||||
) -> None:
|
||||
sub_args = []
|
||||
for node_ind, arg_ind in node_ind_arg_ind:
|
||||
for usages in external_node_usages:
|
||||
node_ind, usage_ind = next(iter(usages))
|
||||
node = region[node_ind]
|
||||
flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
|
||||
sub_args.append(flattened_args_kwargs[arg_ind])
|
||||
flattened_args_kwargs = _get_flat_args(node, {})
|
||||
for user_ind, node_usage_ind in usages:
|
||||
user = region[user_ind]
|
||||
if user in node_to_mutated_arg_positions:
|
||||
if node_usage_ind in node_to_mutated_arg_positions[user]:
|
||||
log.debug(
|
||||
"NYI: Failed to substitute region %s due to mutation", region
|
||||
)
|
||||
return
|
||||
sub_args.append(flattened_args_kwargs[usage_ind])
|
||||
|
||||
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
|
||||
fake_inputs = [node.meta["example_value"] for node in sub_args]
|
||||
|
||||
if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
|
||||
log.debug(
|
||||
"NYI: Failed to substitute region %s due to input alias or mutation",
|
||||
region,
|
||||
)
|
||||
# Input/Output aliasing not supported in HOPs today
|
||||
# Note: we should use the nodes in the original graph (the region here)
|
||||
# because we use the original traced example values for this check
|
||||
if _has_aliasing(region, sub_args, inds_with_external_users):
|
||||
return
|
||||
|
||||
from torch._inductor.pattern_matcher import stable_topological_sort
|
||||
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
|
||||
|
||||
invoke_subgraph_node = graph.create_node(
|
||||
"call_function",
|
||||
@ -140,38 +166,40 @@ def _replace_region_with_subgraph(
|
||||
# Erase in reverse topological order
|
||||
for node in reversed(region):
|
||||
graph.erase_node(node)
|
||||
node_to_additional_deps.pop(node)
|
||||
for dep_list in node_to_additional_deps.values():
|
||||
# Remove any nodes with additional deps
|
||||
# This is safe; we've guaranteed that there is
|
||||
# no input mutation, so all additional deps
|
||||
# will be internal to the subgraph
|
||||
node_to_additional_deps.pop(node, None)
|
||||
for deps in node_to_additional_deps.values():
|
||||
try:
|
||||
dep_list.remove(node)
|
||||
except ValueError:
|
||||
deps.remove(node)
|
||||
deps.add(invoke_subgraph_node)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if config.graph_deduplication_lint:
|
||||
_detect_cycles(graph)
|
||||
stable_topological_sort(graph)
|
||||
graph.lint()
|
||||
|
||||
if config.graph_deduplication_lint:
|
||||
print(_detect_cycles(graph, node_to_additional_deps))
|
||||
_stable_topological_sort(graph, node_to_additional_deps)
|
||||
graph.lint()
|
||||
|
||||
|
||||
def _get_external_inputs(
|
||||
region: Region,
|
||||
) -> dict[Node, tuple[int, int]]:
|
||||
external_node_to_indices = dict()
|
||||
) -> dict[Node, OrderedSet[UsageIndex]]:
|
||||
external_node_to_usages = defaultdict[Node, OrderedSet[UsageIndex]](OrderedSet)
|
||||
region_unique = set(region)
|
||||
for node_ind, node in enumerate(region):
|
||||
flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
|
||||
flattened_args_kwargs = _get_flat_args(node, {})
|
||||
for arg_ind, in_node in enumerate(flattened_args_kwargs):
|
||||
if (
|
||||
isinstance(in_node, Node)
|
||||
and in_node not in region_unique
|
||||
and in_node not in external_node_to_indices
|
||||
):
|
||||
external_node_to_indices[in_node] = (node_ind, arg_ind)
|
||||
if isinstance(in_node, Node) and in_node not in region_unique:
|
||||
# in_node may occur in multiple nodes' flat_args
|
||||
# track this so we can check if the arg is mutated
|
||||
# Previously, we only needed to track one occurrence
|
||||
# to be able to map that node to a placeholder
|
||||
external_node_to_usages[in_node].add((node_ind, arg_ind))
|
||||
|
||||
return external_node_to_indices
|
||||
return external_node_to_usages
|
||||
|
||||
|
||||
def _get_all_output_indices(regions: list[Region]) -> list[int]:
|
||||
@ -194,17 +222,14 @@ def _get_inds_with_external_users(region: Region, inds_unique: set[int]) -> None
|
||||
|
||||
def _copy_nodes_and_remap_inputs(
|
||||
subgraph: torch.fx.Graph, region: Region
|
||||
) -> dict[tuple[int, int], Any]:
|
||||
external_inputs_to_indices = _get_external_inputs(region)
|
||||
indices_to_placeholder_ind: dict[tuple[int, int], Any] = {}
|
||||
) -> list[OrderedSet[UsageIndex]]:
|
||||
external_input_to_usages = _get_external_inputs(region)
|
||||
external_node_usages = list[OrderedSet[UsageIndex]]()
|
||||
region_to_subgraph_node = {}
|
||||
for node in external_inputs_to_indices.keys():
|
||||
for node, usage_indices in external_input_to_usages.items():
|
||||
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
|
||||
region_to_subgraph_node[node] = placeholder
|
||||
arg_indices = external_inputs_to_indices[node]
|
||||
# Note: insertion order matches the order in which placeholders were created
|
||||
# for the calling convention of the subgraph
|
||||
indices_to_placeholder_ind[arg_indices] = None
|
||||
external_node_usages.append(usage_indices)
|
||||
|
||||
def map_arg(node: Node) -> Node:
|
||||
if node in region_to_subgraph_node:
|
||||
@ -216,7 +241,7 @@ def _copy_nodes_and_remap_inputs(
|
||||
subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
|
||||
region_to_subgraph_node[node] = subgraph_node
|
||||
|
||||
return indices_to_placeholder_ind
|
||||
return external_node_usages
|
||||
|
||||
|
||||
def _create_subgraph_outputs(
|
||||
@ -230,30 +255,16 @@ def _create_subgraph_outputs(
|
||||
def _create_subgraph(
|
||||
region: Region,
|
||||
inds_with_external_users: list[int],
|
||||
) -> tuple[torch.fx.Graph, dict[tuple[int, int], Any]]:
|
||||
) -> tuple[torch.fx.Graph, list[OrderedSet[UsageIndex]]]:
|
||||
subgraph: torch.fx.Graph = torch.fx.Graph()
|
||||
node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
|
||||
external_node_usages = _copy_nodes_and_remap_inputs(subgraph, region)
|
||||
_create_subgraph_outputs(subgraph, inds_with_external_users)
|
||||
return subgraph, node_ind_input_inds
|
||||
|
||||
|
||||
def _args(
|
||||
n: torch.fx.Node,
|
||||
node_to_additional_deps: Optional[dict[torch.fx.Node, list[torch.fx.Node]]] = None,
|
||||
) -> list[torch.fx.node.Argument]:
|
||||
if node_to_additional_deps is None:
|
||||
node_to_additional_deps = {}
|
||||
|
||||
args: list[torch.fx.node.Argument] = []
|
||||
torch.fx.map_arg((n.args, n.kwargs), args.append)
|
||||
if n in node_to_additional_deps:
|
||||
args.extend(node_to_additional_deps[n])
|
||||
return args
|
||||
return subgraph, external_node_usages
|
||||
|
||||
|
||||
def _stable_topological_sort(
|
||||
graph: torch.fx.Graph,
|
||||
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]],
|
||||
node_to_additional_deps: dict[Node, OrderedSet[Node]],
|
||||
) -> None:
|
||||
# Nodes are in exactly one of these four collections:
|
||||
|
||||
@ -262,14 +273,14 @@ def _stable_topological_sort(
|
||||
|
||||
# - Nodes in `ready` have been processed and are already in the correct
|
||||
# order.
|
||||
ready = OrderedSet[torch.fx.Node]()
|
||||
ready = OrderedSet[Node]()
|
||||
|
||||
# - `waiting` is a mapping from a dependency to nodes which depend on that
|
||||
# dependency.
|
||||
waiting = defaultdict(list)
|
||||
|
||||
# - `outputs` are always at the end of the graph
|
||||
outputs = OrderedSet[torch.fx.Node]()
|
||||
outputs = OrderedSet[Node]()
|
||||
|
||||
# The cursor indicates the last processed node so we can add new nodes
|
||||
# after it.
|
||||
@ -283,7 +294,9 @@ def _stable_topological_sort(
|
||||
continue
|
||||
|
||||
waiting_for = [
|
||||
x for x in _args(node, node_to_additional_deps) if x not in ready
|
||||
x
|
||||
for x in _get_flat_args_unique(node, node_to_additional_deps)
|
||||
if x not in ready
|
||||
]
|
||||
if waiting_for:
|
||||
# We have unprocessed input nodes. Might as well wait for the last
|
||||
@ -303,23 +316,29 @@ def _stable_topological_sort(
|
||||
|
||||
|
||||
def _populate_additional_deps(
|
||||
graph: torch.fx.Graph,
|
||||
) -> dict[torch.fx.Node, list[torch.fx.Node]]:
|
||||
graph: torch.fx.Graph, node_to_mutated_arg_positions: dict[Node, OrderedSet[int]]
|
||||
) -> dict[Node, OrderedSet[Node]]:
|
||||
node_to_additional_deps: dict[Node, OrderedSet[Node]] = defaultdict(OrderedSet)
|
||||
_add_mutation_dependencies(node_to_mutated_arg_positions, node_to_additional_deps)
|
||||
_add_global_state_dependencies(graph, node_to_additional_deps)
|
||||
return node_to_additional_deps
|
||||
|
||||
|
||||
def _add_global_state_dependencies(
|
||||
graph: torch.fx.Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
|
||||
) -> None:
|
||||
import torch.amp
|
||||
|
||||
node_to_additional_deps: dict[torch.fx.Node, list[torch.fx.Node]] = defaultdict(
|
||||
list
|
||||
)
|
||||
all_nodes = list(graph.nodes)
|
||||
|
||||
# These are targets of the nodes which need to stay in the same relative place in the graph
|
||||
global_state_targets = {torch.amp._enter_autocast, torch.amp._exit_autocast}
|
||||
all_nodes_dep_on: list[torch.fx.Node] = []
|
||||
all_nodes_dep_on: list[Node] = []
|
||||
|
||||
def prev_cur_nodes(
|
||||
all_nodes: list[torch.fx.Node],
|
||||
) -> Generator[tuple[list[torch.fx.Node], torch.fx.Node]]:
|
||||
prev_nodes: list[torch.fx.Node] = []
|
||||
all_nodes: list[Node],
|
||||
) -> Generator[tuple[list[Node], Node], None, None]:
|
||||
prev_nodes: list[Node] = []
|
||||
next_nodes = list(reversed(all_nodes))
|
||||
|
||||
while next_nodes:
|
||||
@ -328,11 +347,93 @@ def _populate_additional_deps(
|
||||
prev_nodes.append(cur_node)
|
||||
|
||||
for prev_nodes, cur_node in prev_cur_nodes(all_nodes):
|
||||
args_unique = _args(cur_node)
|
||||
additional_deps = node_to_additional_deps[cur_node]
|
||||
additional_deps.extend(n for n in all_nodes_dep_on if n not in args_unique)
|
||||
args_unique = _get_flat_args_unique(cur_node, {})
|
||||
new_deps = [n for n in all_nodes_dep_on if n not in args_unique]
|
||||
|
||||
if new_deps:
|
||||
additional_deps = node_to_additional_deps[cur_node]
|
||||
additional_deps.update(new_deps)
|
||||
|
||||
if cur_node.target in global_state_targets:
|
||||
additional_deps.extend(n for n in prev_nodes if n not in args_unique)
|
||||
additional_deps = node_to_additional_deps[cur_node]
|
||||
additional_deps.update(n for n in prev_nodes if n not in args_unique)
|
||||
all_nodes_dep_on.append(cur_node)
|
||||
|
||||
return node_to_additional_deps
|
||||
|
||||
def _add_mutation_dependencies(
    node_to_mutated_arg_positions: dict[Node, OrderedSet[int]],
    node_to_additional_deps: dict[Node, OrderedSet[Node]],
) -> None:
    """Record ordering constraints around every node that mutates an argument.

    Args:
        node_to_mutated_arg_positions: maps a mutating node to the positions,
            within its flattened args/kwargs, of the arguments it mutates.
        node_to_additional_deps: mapping updated in place; each key gains the
            nodes it must be ordered after, beyond its real argument edges.
    """
    for mutating_node, mutated_positions in node_to_mutated_arg_positions.items():
        flat_args = _get_flat_args(mutating_node, {})

        for position in mutated_positions:
            # Every other user of the mutated argument must keep its place
            # relative to the mutation.
            for consumer in flat_args[position].users:
                if consumer is mutating_node:
                    continue

                if consumer < mutating_node:
                    # Usage occurs before the mutation: the mutation has to
                    # stay ordered after it.
                    node_to_additional_deps[mutating_node].add(consumer)
                elif consumer > mutating_node:
                    # Usage occurs after the mutation: it has to stay
                    # ordered after the mutation.
                    node_to_additional_deps[consumer].add(mutating_node)
|
||||
|
||||
|
||||
def _has_aliasing(
    region: Region, inputs: list[Node], inds_with_external_users: list[int]
) -> bool:
    """Return True if the region's inputs/outputs share tensor storage.

    Three situations are detected, each of which is logged at debug level
    and causes ``True`` to be returned:

    * input-input aliasing: two nodes in ``inputs`` whose
      ``meta["example_value"]`` tensors are backed by the same storage;
    * output-output aliasing: two region nodes selected by
      ``inds_with_external_users`` whose example tensors share storage;
    * input-output aliasing: an input and one of those outputs share storage.

    Non-tensor example values are ignored. Returns ``False`` when none of
    the above is found.
    """
    input_storages: dict[StorageWeakRef, Node] = {}

    for node in inputs:
        example_value = node.meta["example_value"]
        if isinstance(example_value, torch.Tensor):
            storage = StorageWeakRef(example_value._typed_storage())
            if storage in input_storages:
                # input-input aliasing
                log.debug(
                    "NYI: Failed to substitute region %s due to input-input aliasing detected at nodes %s, %s",
                    region,
                    input_storages[storage],
                    node,
                )
                return True
            input_storages[storage] = node

    output_storages: dict[StorageWeakRef, Node] = {}
    for i in inds_with_external_users:
        out_node = region[i]
        if out_node:
            example_value = out_node.meta["example_value"]
            assert not isinstance(example_value, list)
            if isinstance(example_value, torch.Tensor):
                storage = StorageWeakRef(example_value._typed_storage())
                if storage in output_storages:
                    # output-output aliasing
                    log.debug(
                        "NYI: Failed to substitute region %s due to output-output aliasing detected at nodes %s, %s",
                        region,
                        output_storages[storage],
                        out_node,
                    )
                    return True
                output_storages[storage] = out_node

    # Storages that appear on both sides indicate input-output aliasing.
    intersected_storages = input_storages.keys() & output_storages.keys()
    if len(intersected_storages) > 0:
        # input-output aliasing
        aliased_pairs = [
            (input_storages[s], output_storages[s]) for s in intersected_storages
        ]
        aliased = ", ".join([f"{i} and {o}" for i, o in aliased_pairs])
        log.debug(
            "NYI: Failed to substitute region %s due to input-output aliasing detected at nodes %s",
            region,
            aliased,
        )
        return True

    return False
|
||||
|
||||
@ -25,9 +25,10 @@ from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar
|
||||
import torch._logging
|
||||
import torch.fx
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
||||
from .graph_utils import _flatten_args_kwargs
|
||||
from .graph_utils import _get_flat_args_unique
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -150,6 +151,9 @@ class BackwardBfsArgIter:
|
||||
def create(origin: Node) -> "BackwardBfsArgIter":
    """Build an iterator for ``origin`` with its children already queued.

    The iterator is constructed for ``origin``, the arguments of ``origin``
    are queued via ``add_children``, and the first item the iterator yields
    is consumed before the iterator is returned.
    """
    it = BackwardBfsArgIter(origin)
    it.add_children(origin)
    # pop the origin node, since it is the origin of
    # the region and does not need to be considered for addition.
    # The pop is performed outside the assert so that it still happens
    # when assertions are stripped (``python -O``).
    popped = it.next()
    assert popped
    return it
|
||||
|
||||
def next(self) -> Optional[Node]:
|
||||
@ -164,17 +168,11 @@ class BackwardBfsArgIter:
|
||||
return self._cur
|
||||
|
||||
def add_children(self, node: Node) -> None:
    """Queue each distinct Node argument of ``node`` for traversal.

    ``_get_flat_args_unique`` collects the Node arguments found in both
    ``node.args`` and ``node.kwargs`` (deduplicated, in first-seen order),
    so no separate pass over the kwargs is needed; a second pass would
    hand the same kwarg nodes to ``_append`` twice.
    """
    for arg in _get_flat_args_unique(node, {}):
        if isinstance(arg, Node):
            self._append(arg)
|
||||
|
||||
def _append(self, arg: Node) -> None:
|
||||
if self._cur is None:
|
||||
self._cur = arg
|
||||
@ -199,6 +197,8 @@ class GraphRegionTracker:
|
||||
def __init__(self) -> None:
    """Create an empty tracker with no recorded duplicates or mutations."""
    # hash of a node -> nodes sharing that hash
    self.hash_to_duplicates: dict[str, IdenticalNodes] = defaultdict(list)
    # node -> the duplicate group it belongs to
    self.node_to_duplicates: dict[Node, IdenticalNodes] = dict()
    # Note: position is in flattened args/kwargs list
    self.node_to_mutated_arg_positions: dict[Node, OrderedSet[int]] = dict()
    self.input_pickler = InputPickler()
|
||||
|
||||
def _hash_node(
|
||||
@ -240,6 +240,28 @@ class GraphRegionTracker:
|
||||
except NodeHashException as e:
|
||||
log.debug("Unable to hash node %s with exception %s", node, e)
|
||||
|
||||
def track_node_mutations(
|
||||
self,
|
||||
node: Node,
|
||||
flat_args_kwargs: list[Any],
|
||||
id_to_initial_version: dict[int, int],
|
||||
) -> None:
|
||||
"""
|
||||
This function tracks which argument positions are mutated by the given node. Subgraph HOP does not support
|
||||
input mutations today so we will skip regions which have inputs that are mutated.
|
||||
"""
|
||||
mutated_arg_positions = OrderedSet[int]()
|
||||
for i, arg in enumerate(flat_args_kwargs):
|
||||
val_id = id(arg)
|
||||
if (
|
||||
val_id in id_to_initial_version
|
||||
and id_to_initial_version[val_id] != arg._version
|
||||
):
|
||||
mutated_arg_positions.add(i)
|
||||
|
||||
if mutated_arg_positions:
|
||||
self.node_to_mutated_arg_positions[node] = mutated_arg_positions
|
||||
|
||||
def get_identical_regions(self, graph: torch.fx.Graph) -> list[list[Region]]:
|
||||
"""
|
||||
This function is responsible for extracting the largest regions of identical nodes from the given graph.
|
||||
@ -303,6 +325,38 @@ class GraphRegionTracker:
|
||||
return f"GraphRegionTracker(hash_to_duplicates={self.hash_to_duplicates}, node_to_duplicates={self.node_to_duplicates})"
|
||||
|
||||
|
||||
class RegionWrapper:
    """Bookkeeping for a single region while it is being grown.

    Keeps the region's member nodes, the union of their recursive
    ancestors, and an argument iterator that supplies the next candidate
    node to consider.
    """

    def __init__(
        self, region: Region, node_to_recursive_ancestors: dict[Node, set[Node]]
    ) -> None:
        assert len(region) == 1, "all regions should start with one node"
        origin = region[0]
        self.node_to_recursive_ancestors = node_to_recursive_ancestors
        self.iter = BackwardBfsArgIter.create(origin)
        self.nodes_unique = OrderedSet([origin])
        # copy so that later updates do not modify the shared ancestor map
        self.ancestors = set(node_to_recursive_ancestors[origin])
        self.region = region

    def next_candidate(self) -> Optional[Node]:
        """Return the next node produced by the argument iterator, if any."""
        return self.iter.next()

    def will_inclusion_create_cycle(self, node: Node) -> bool:
        """Return True if a user of ``node`` outside the region is an ancestor of the region."""
        return any(
            user in self.ancestors
            for user in node.users
            if user not in self.nodes_unique
        )

    def add(self, node: Node) -> None:
        """Add ``node`` to the region and extend the tracked state with it."""
        self.nodes_unique.add(node)
        self.region.append(node)
        self.iter.add_children(node)
        self.ancestors.update(self.node_to_recursive_ancestors[node])
|
||||
|
||||
|
||||
def fully_expand_region_group(
|
||||
regions: list[Region],
|
||||
seen_nodes: set[Node],
|
||||
@ -314,20 +368,12 @@ def fully_expand_region_group(
|
||||
|
||||
# All regions should start with 1 node
|
||||
assert all(len(region) == 1 for region in regions)
|
||||
region_iters = []
|
||||
for region in regions:
|
||||
(origin,) = region # Only works for 1 element sets
|
||||
region_iters.append(BackwardBfsArgIter.create(origin))
|
||||
region_wrappers = [
|
||||
RegionWrapper(region, node_to_recursive_ancestors) for region in regions
|
||||
]
|
||||
|
||||
nodes_to_add: list[Node] = []
|
||||
|
||||
# we already have the origin node in each region
|
||||
for region_it in region_iters:
|
||||
node = region_it.next()
|
||||
assert node
|
||||
region_it.add_children(node)
|
||||
|
||||
current_node = region_iters[0].next()
|
||||
nodes_to_add = OrderedSet[Node]()
|
||||
current_node = region_wrappers[0].next_candidate()
|
||||
|
||||
# No children
|
||||
if current_node is None:
|
||||
@ -337,46 +383,51 @@ def fully_expand_region_group(
|
||||
# regions are only expanded if the node to add is valid
|
||||
# for ALL regions
|
||||
while current_node:
|
||||
add_node = not _will_create_cycle(
|
||||
current_node, regions[0], node_to_recursive_ancestors
|
||||
add_to_all_regions = not region_wrappers[0].will_inclusion_create_cycle(
|
||||
current_node
|
||||
)
|
||||
nodes_to_add.clear()
|
||||
nodes_to_add.append(current_node)
|
||||
nodes_to_add_set = set(nodes_to_add)
|
||||
for ind, region_it in enumerate(region_iters[1:]):
|
||||
ind += 1 # compensate for the 0th region
|
||||
node = region_it.next()
|
||||
nodes_to_add.add(current_node)
|
||||
for region_wrapper in region_wrappers[1:]:
|
||||
candidate = region_wrapper.next_candidate()
|
||||
|
||||
debug_log("--------------------")
|
||||
debug_log("considering adding: %s, cur_node: %s", node, current_node)
|
||||
debug_log("previously claimed nodes: %s", node in seen_nodes)
|
||||
if node:
|
||||
debug_log("is_identical: %s", is_identical_fn(node, current_node))
|
||||
add_node &= (
|
||||
node not in seen_nodes
|
||||
and node not in nodes_to_add_set
|
||||
and node.op != "placeholder"
|
||||
and is_identical_fn(node, current_node)
|
||||
and not _will_create_cycle(
|
||||
node, regions[ind], node_to_recursive_ancestors
|
||||
)
|
||||
)
|
||||
nodes_to_add.append(node)
|
||||
nodes_to_add_set.add(node)
|
||||
else:
|
||||
add_node = False
|
||||
debug_log(
|
||||
"considering candidate: %s, cur_node: %s", candidate, current_node
|
||||
)
|
||||
|
||||
if not candidate or not add_to_all_regions:
|
||||
add_to_all_regions = False
|
||||
continue
|
||||
|
||||
debug_log(
|
||||
"candidate in previously claimed nodes?: %s", candidate in seen_nodes
|
||||
)
|
||||
debug_log("is_identical: %s", is_identical_fn(candidate, current_node))
|
||||
|
||||
add_to_all_regions &= (
|
||||
candidate not in seen_nodes
|
||||
and candidate not in nodes_to_add
|
||||
and candidate.op != "placeholder"
|
||||
and is_identical_fn(candidate, current_node)
|
||||
and not region_wrapper.will_inclusion_create_cycle(candidate)
|
||||
)
|
||||
nodes_to_add.add(candidate)
|
||||
|
||||
debug_log(f"add_to_all_regions: {add_to_all_regions}")
|
||||
debug_log("--------------------")
|
||||
|
||||
if add_node:
|
||||
for region, region_it, node in zip(regions, region_iters, nodes_to_add):
|
||||
region.append(node)
|
||||
if add_to_all_regions:
|
||||
assert len(region_wrappers) == len(nodes_to_add), (
|
||||
"Numer of nodes to add must equal the number of regions"
|
||||
)
|
||||
for region_wrapper, node in zip(region_wrappers, nodes_to_add):
|
||||
region_wrapper.add(node)
|
||||
debug_log("adding %s's children", node)
|
||||
debug_log("%s %s", node.args, list(node.kwargs.items()))
|
||||
region_it.add_children(node)
|
||||
seen_nodes.add(node)
|
||||
|
||||
current_node = region_iters[0].next()
|
||||
current_node = region_wrappers[0].next_candidate()
|
||||
|
||||
# Ensure regions are sorted in topological order
|
||||
for region in regions:
|
||||
@ -391,7 +442,7 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[No
|
||||
for node in graph.nodes:
|
||||
node_to_recursive_ancestors[node] = set()
|
||||
for node in graph.nodes:
|
||||
all_args = _flatten_args_kwargs((node.args, node.kwargs))
|
||||
all_args = _get_flat_args_unique(node, {})
|
||||
for arg in all_args:
|
||||
if isinstance(arg, Node):
|
||||
node_to_recursive_ancestors[node].update(
|
||||
@ -399,20 +450,3 @@ def _populate_recursive_ancestor_map(graph: torch.fx.Graph) -> dict[Node, set[No
|
||||
)
|
||||
node_to_recursive_ancestors[node].add(arg)
|
||||
return node_to_recursive_ancestors
|
||||
|
||||
|
||||
def _will_create_cycle(
|
||||
node_to_add: Node,
|
||||
region: Region,
|
||||
node_to_recursive_ancestors: dict[Node, set[Node]],
|
||||
) -> bool:
|
||||
region_set: set[Node] = set(region)
|
||||
region_ancestors: set[Node] = set(
|
||||
tree_flatten([list(node_to_recursive_ancestors[node]) for node in region])[0]
|
||||
)
|
||||
external_users = [user for user in node_to_add.users if user not in region_set]
|
||||
for user in external_users:
|
||||
if user in region_ancestors:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
from torch.fx import Graph, Node
|
||||
from torch.utils._pytree import tree_flatten
|
||||
from torch.fx import Graph, map_arg, Node
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
||||
# flattens with support for slices
|
||||
@ -10,26 +10,29 @@ from torch.utils._pytree import tree_flatten
|
||||
# be register/unregister slices as pytree nodes
|
||||
# but there is no unregister API in the pytorch
|
||||
# pytree impl
|
||||
def _flatten_args_kwargs(args: Any) -> list[Node]:
|
||||
fully_flattened = []
|
||||
|
||||
def flatten(args: Any) -> None:
|
||||
flattened, _ = tree_flatten(args)
|
||||
for arg in flattened:
|
||||
if isinstance(arg, slice):
|
||||
start = arg.start
|
||||
stop = arg.stop
|
||||
step = arg.step
|
||||
flatten((start, stop, step))
|
||||
else:
|
||||
fully_flattened.append(arg)
|
||||
|
||||
flatten(args)
|
||||
|
||||
return fully_flattened
|
||||
def _get_flat_args(
|
||||
node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
|
||||
) -> list[Node]:
|
||||
args = list[Any]()
|
||||
map_arg((node.args, node.kwargs), args.append)
|
||||
if node in node_to_additional_deps:
|
||||
args.extend(node_to_additional_deps[node])
|
||||
return args
|
||||
|
||||
|
||||
def _detect_cycles(graph: Graph) -> str:
|
||||
def _get_flat_args_unique(
|
||||
node: Node, node_to_additional_deps: dict[Node, OrderedSet[Node]]
|
||||
) -> OrderedSet[Node]:
|
||||
args = OrderedSet[Node]()
|
||||
map_arg((node.args, node.kwargs), args.add)
|
||||
if node in node_to_additional_deps:
|
||||
args.update(node_to_additional_deps[node])
|
||||
return args
|
||||
|
||||
|
||||
def _detect_cycles(
|
||||
graph: Graph, node_to_additional_deps: dict[Node, OrderedSet[Node]]
|
||||
) -> str:
|
||||
current_path: deque[Node] = deque()
|
||||
current_path_set: set[Node] = set()
|
||||
pending: deque[tuple[Node, Node]] = deque()
|
||||
@ -45,25 +48,30 @@ def _detect_cycles(graph: Graph) -> str:
|
||||
def current_path_head() -> Node:
|
||||
return current_path[-1]
|
||||
|
||||
for origin in graph.find_nodes(op="placeholder"):
|
||||
for origin in graph.find_nodes(op="output"):
|
||||
current_path.clear()
|
||||
current_path_set.clear()
|
||||
add_to_current_path(origin)
|
||||
for child in origin.users:
|
||||
for child in _get_flat_args_unique(origin, node_to_additional_deps):
|
||||
pending.append((child, origin))
|
||||
|
||||
while pending:
|
||||
cur_node, parent = pending.pop()
|
||||
|
||||
while current_path_head() != parent:
|
||||
# handle backtracking
|
||||
while current_path and current_path_head() != parent:
|
||||
pop_current_path()
|
||||
|
||||
if not isinstance(cur_node, Node):
|
||||
continue
|
||||
|
||||
if cur_node in current_path_set:
|
||||
current_path.append(cur_node)
|
||||
return f"cycle detected in path: {current_path}"
|
||||
|
||||
add_to_current_path(cur_node)
|
||||
for child in cur_node.users:
|
||||
|
||||
for child in _get_flat_args_unique(cur_node, node_to_additional_deps):
|
||||
pending.append((child, cur_node))
|
||||
|
||||
return "no cycle detected"
|
||||
|
||||
@ -92,6 +92,8 @@ from torch.nn.modules.lazy import LazyModuleMixin
|
||||
from torch.utils._triton import has_triton, has_triton_package
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
from .graph_utils import _get_flat_args
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import (
|
||||
@ -3151,6 +3153,20 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
tx, (node.args, node.kwargs), allow_non_graph_fake
|
||||
)
|
||||
|
||||
if (
|
||||
torch._dynamo.config.use_graph_deduplication
|
||||
or torch._dynamo.config.track_nodes_for_deduplication
|
||||
):
|
||||
flat_args_kwargs = get_fake_values_from_nodes(
|
||||
tx, _get_flat_args(node, {}), allow_non_graph_fake
|
||||
)
|
||||
id_to_initial_version = {
|
||||
id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
|
||||
}
|
||||
else:
|
||||
flat_args_kwargs = []
|
||||
id_to_initial_version = {}
|
||||
|
||||
nnmodule = None
|
||||
if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
|
||||
# If the first argument is nn.Module, should copy to fake mode.
|
||||
@ -3290,6 +3306,17 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
||||
_ = pytree.tree_map_only(
|
||||
torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
|
||||
)
|
||||
|
||||
if (
|
||||
torch._dynamo.config.use_graph_deduplication
|
||||
or torch._dynamo.config.track_nodes_for_deduplication
|
||||
):
|
||||
tx.output.region_tracker.track_node_mutations(
|
||||
node,
|
||||
flat_args_kwargs,
|
||||
id_to_initial_version,
|
||||
)
|
||||
|
||||
return ret_val
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user