Compare commits

...

4 Commits

Author SHA1 Message Date
525d34447e [Dynamo] Refactor to use install subgraph method in higher order ops
ghstack-source-id: 2b4cbdbd5f1986edac2abd62bee940b529678531
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141384
2024-12-10 11:55:58 -08:00
7da79a64aa [Dynamo] Initial deduplication pass impl
ghstack-source-id: 01499b26ef721938dcc9c469e50b446b6bf6eeaa
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141383

graph dedup tests

Remove region tracker

Rm option

Use config

Fix dedup

Fixes2

Fixes for dedup

update comment

Fixes
2024-12-10 11:55:57 -08:00
2fd48c7565 [Dynamo] add debug logging for graph region expansion
ghstack-source-id: 181e08854931debc5a2105d73b2f7790eecd1515
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141382
2024-12-10 11:55:57 -08:00
223d363f12 [Dynamo] Implement graph region tracking for deduplication
Fixes for bfs

Initial region tracking tests

hashing + test fixes

fix for hash

more tests

more fixes

region tracker updates

Update tests

Update tests2

ghstack-source-id: 433c97e4fb4e6c629460d7bec4cc129b32cba861
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141381

Fixes for region tracking

Reuse input hashing

fixes

fixes2

Fixes

Fixes

Fixes2

f

f

f

f2

f2

PR

Add config

Fixes for tracking

Fix test

fixes
2024-12-10 11:55:56 -08:00
11 changed files with 1559 additions and 50 deletions

View File

@ -0,0 +1,591 @@
# Owner(s): ["module: dynamo"]
import torch
import torch.fx
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
def extract_graph(fn, *args, **kwargs):
backend = AotEagerAndRecordGraphs()
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
return result, backend.graphs, backend.fw_graphs
def graph_str(gm):
return normalize_gm(gm.print_readable(print_output=False))
class GraphDeduplicationTests(TestCase):
def run_and_return_graphs(self, fn, *args, **kwargs):
with torch._dynamo.config.patch("use_graph_deduplication", True):
return extract_graph(fn, *args, **kwargs)
def test_single_subgraph(self):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def fn(x, y):
o0 = inner_fn(x, y)
o1 = torch.sin(y)
o2 = inner_fn(x, o1)
o3 = inner_fn(x, y)
o4 = o3 * o3
return o2 * o4
x = torch.rand(10, 10, requires_grad=True)
y = torch.rand(10, 20, requires_grad=True)
x_clone = x.clone().requires_grad_(True)
y_clone = y.clone().requires_grad_(True)
ref_result = fn(x, y)
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
torch.allclose(ref_result, result)
ref_result.sum().backward()
result.sum().backward()
self.assertEqual(len(graphs), 1)
self.assertEqual(len(fw_graphs), 1)
self.assertExpectedInline(
graph_str(graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"):
subgraph_0 = self.subgraph_0
l_x_ = L_x_
l_y_ = L_y_
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
(l_y_, l_x_)); invoke_subgraph = None
o1: "f32[10, 20]" = torch.sin(l_y_)
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
(o1, l_x_)); o1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
(l_y_, l_x_)); subgraph_0 = l_y_ = l_x_ = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
o4: "f32[]" = getitem_2 * getitem_2; getitem_2 = None
mul_1: "f32[]" = getitem_1 * o4; getitem_1 = o4 = None
return (mul_1,)
class subgraph_0(torch.nn.Module):
def forward(self, subgraph_input_l_y_, subgraph_input_l_x_):
y0: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None
x0: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None
sum_2: "f32[]" = y0.sum(); y0 = None
sum_1: "f32[]" = x0.sum(); x0 = None
z: "f32[]" = sum_1 + sum_2; sum_1 = sum_2 = None
return (z,)
""",
)
self.assertExpectedInline(
graph_str(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
'___forward_subgraph_0', (sin, primals_1)); repeated_subgraph0_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
repeated_subgraph0_2 = self.repeated_subgraph0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, \
'___forward_subgraph_0', (primals_2, primals_1)); repeated_subgraph0_2 = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
mul: "f32[]" = torch.ops.aten.mul.Tensor(getitem_2, getitem_2)
mul_1: "f32[]" = torch.ops.aten.mul.Tensor(getitem_1, mul); mul = None
return (mul_1, primals_1, primals_2, sin, getitem_1, getitem_2)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
add: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg0_1, 2); arg0_1 = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg1_1, 1); arg1_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None
return (add_2,)
""",
)
def test_single_subgraph2(self):
def fn(x):
x0 = x + 2
o = inner_fn(x0)
o = torch.cos(o)
o = inner_fn(o)
return torch.sin(o)
def inner_fn(x):
o = x * 7
o += 1
o += 2
return o
x = torch.rand(10, 10, requires_grad=True)
x_clone = x.clone().requires_grad_(True)
ref_result = fn(x)
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone)
torch.allclose(ref_result, result)
ref_result.sum().backward()
result.sum().backward()
self.assertEqual(len(graphs), 1)
self.assertEqual(len(fw_graphs), 1)
self.assertExpectedInline(
graph_str(graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10, 10]"):
subgraph_0 = self.subgraph_0
l_x_ = L_x_
x0: "f32[10, 10]" = l_x_ + 2; l_x_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (x0,)); x0 = None
getitem: "f32[10, 10]" = invoke_subgraph[0]; invoke_subgraph = None
o_3: "f32[10, 10]" = torch.cos(getitem); getitem = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (o_3,)); subgraph_0 = o_3 = None
getitem_1: "f32[10, 10]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sin: "f32[10, 10]" = torch.sin(getitem_1); getitem_1 = None
return (sin,)
class subgraph_0(torch.nn.Module):
def forward(self, subgraph_input_x0):
o: "f32[10, 10]" = subgraph_input_x0 * 7; subgraph_input_x0 = None
o += 1; o_1: "f32[10, 10]" = o; o = None
o_1 += 2; o_2: "f32[10, 10]" = o_1; o_1 = None
return (o_2,)
""",
)
self.assertExpectedInline(
graph_str(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[10, 10]"):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
'___forward_subgraph_0', (add,)); repeated_subgraph0 = None
getitem: "f32[10, 10]" = invoke_subgraph[0]; invoke_subgraph = None
cos: "f32[10, 10]" = torch.ops.aten.cos.default(getitem)
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
'___forward_subgraph_0', (cos,)); repeated_subgraph0_1 = None
getitem_1: "f32[10, 10]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sin: "f32[10, 10]" = torch.ops.aten.sin.default(getitem_1)
cos_1: "f32[10, 10]" = torch.ops.aten.cos.default(getitem_1); getitem_1 = None
sin_1: "f32[10, 10]" = torch.ops.aten.sin.default(getitem); getitem = None
neg: "f32[10, 10]" = torch.ops.aten.neg.default(sin_1); sin_1 = None
return (sin, add, cos, cos_1, neg)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]"):
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 7); arg0_1 = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1,)
""",
)
def test_multiple_subgraphs(self):
def inner_fn(x, y):
x1 = x + 1
y1 = y + 2
z = x1.sum() + y1.sum()
return z
def inner_fn2(a, b):
a0 = a + 2
b0 = b + 3
c = a0 * b0.cos().sum()
return c
def fn(x, y):
x0 = torch.cos(x)
y0 = torch.sin(y)
o1 = inner_fn2(x0, y0)
o0 = inner_fn(x, y)
o1 = torch.sin(o0)
o2 = inner_fn(x, y0)
o3 = inner_fn2(x0, y0)
o4 = inner_fn(x, y)
return o1 * o2 * o3 + o4
x = torch.rand(10, 10, requires_grad=True)
y = torch.rand(10, 20, requires_grad=True)
x_clone = x.clone().requires_grad_(True)
y_clone = y.clone().requires_grad_(True)
ref_result = fn(x, y)
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
torch.allclose(ref_result, result)
ref_result.sum().backward()
result.sum().backward()
self.assertEqual(len(graphs), 1)
self.assertEqual(len(fw_graphs), 1)
self.assertExpectedInline(
graph_str(graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"):
subgraph_1 = self.subgraph_1
subgraph_0 = self.subgraph_0
l_x_ = L_x_
l_y_ = L_y_
x0: "f32[10, 10]" = torch.cos(l_x_)
y0: "f32[10, 20]" = torch.sin(l_y_)
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, \
'subgraph_1', (y0, x0)); invoke_subgraph_3 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, \
'subgraph_0', (l_y_, l_x_))
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
o1: "f32[]" = torch.sin(getitem); getitem = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, \
'subgraph_0', (y0, l_x_))
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, \
'subgraph_1', (y0, x0)); subgraph_1 = y0 = x0 = None
getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
(l_y_, l_x_)); subgraph_0 = l_y_ = l_x_ = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None
mul_3: "f32[10, 10]" = mul_2 * getitem_4; mul_2 = getitem_4 = None
add_13: "f32[10, 10]" = mul_3 + getitem_2; mul_3 = getitem_2 = None
return (add_13,)
class subgraph_1(torch.nn.Module):
def forward(self, subgraph_input_y0, subgraph_input_x0):
b0: "f32[10, 20]" = subgraph_input_y0 + 3; subgraph_input_y0 = None
cos_1: "f32[10, 20]" = b0.cos(); b0 = None
sum_1: "f32[]" = cos_1.sum(); cos_1 = None
a0: "f32[10, 10]" = subgraph_input_x0 + 2; subgraph_input_x0 = None
c: "f32[10, 10]" = a0 * sum_1; a0 = sum_1 = None
return (c,)
class subgraph_0(torch.nn.Module):
def forward(self, subgraph_input_l_y_, subgraph_input_l_x_):
y1: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None
x1: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None
sum_3: "f32[]" = y1.sum(); y1 = None
sum_2: "f32[]" = x1.sum(); x1 = None
z: "f32[]" = sum_2 + sum_3; sum_2 = sum_3 = None
return (z,)
""",
)
self.assertExpectedInline(
graph_str(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
cos: "f32[10, 10]" = torch.ops.aten.cos.default(primals_1)
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
repeated_subgraph1 = self.repeated_subgraph1
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1, \
'___forward_subgraph_0', (primals_2, primals_1)); repeated_subgraph1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sin_1: "f32[]" = torch.ops.aten.sin.default(getitem_1)
repeated_subgraph1_1 = self.repeated_subgraph1
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1_1, \
'___forward_subgraph_0', (sin, primals_1)); repeated_subgraph1_1 = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
'___forward_subgraph_1', (sin, cos)); repeated_subgraph0_1 = None
getitem_3: "f32[10, 10]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
repeated_subgraph1_2 = self.repeated_subgraph1
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1_2, \
'___forward_subgraph_0', (primals_2, primals_1)); repeated_subgraph1_2 = None
getitem_4: "f32[]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_2); sin_1 = None
mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_3); mul = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_4); mul_1 = getitem_4 = None
return (add, primals_1, primals_2, cos, sin, getitem_1, getitem_2, getitem_3)
class repeated_subgraph1(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
add: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg0_1, 2); arg0_1 = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg1_1, 1); arg1_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None
return (add_2,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
add: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg0_1, 3); arg0_1 = None
cos: "f32[10, 20]" = torch.ops.aten.cos.default(add); add = None
sum_1: "f32[]" = torch.ops.aten.sum.default(cos); cos = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg1_1, 2); arg1_1 = None
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add_1, sum_1); add_1 = sum_1 = None
return (mul,)
""",
)
def test_dependent_subgraphs(self):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def fn(x, y):
o0 = inner_fn(x, y)
o1 = inner_fn(x, o0)
return o1
x = torch.rand(10, 10, requires_grad=True)
y = torch.rand(10, 20, requires_grad=True)
x_clone = x.clone().requires_grad_(True)
y_clone = y.clone().requires_grad_(True)
ref_result = fn(x, y)
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
torch.allclose(ref_result, result)
ref_result.sum().backward()
result.sum().backward()
self.assertEqual(len(graphs), 1)
self.assertEqual(len(fw_graphs), 1)
self.assertExpectedInline(
graph_str(fw_graphs[0]),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
add: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_2, 2); primals_2 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
'___forward_subgraph_0', (primals_1, sum_1)); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
add_1: "f32[]" = torch.ops.aten.add.Tensor(getitem, 2); getitem = None
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
'___forward_subgraph_0', (primals_1, sum_2)); repeated_subgraph0_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
return (getitem_1, primals_1, sum_1, sum_2)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[]"):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
add_1: "f32[]" = torch.ops.aten.add.Tensor(sum_1, arg1_1); sum_1 = arg1_1 = None
return (add_1,)
""",
)
def test_input_mutation(self):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def inner_fn2(x, y):
x0 = x + 1
y0 = y + 1
x.add_(x0)
y.add_(y0)
return x.sum() + y.sum()
def fn(x, y):
x0 = torch.sin(x)
y0 = torch.cos(y)
# o0 = inner_fn(x0, y0)
# o1 = inner_fn(x0, o0)
o2 = inner_fn2(x0, y)
o3 = inner_fn2(x0.clone(), y.clone())
return o2 + o3
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
x_clone = x.clone()
y_clone = y.clone()
ref_result = fn(x, y)
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
torch.allclose(ref_result, result)
self.assertEqual(len(graphs), 1)
self.assertEqual(len(fw_graphs), 1)
self.assertExpectedInline(
graph_str(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
sin: "f32[10, 10]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, 1)
add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, 1)
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, add); sin = add = None
add_3: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, add_1); add_1 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
'subgraph_0', (add_3, add_2)); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2); add_2 = None
clone_1: "f32[10, 20]" = torch.ops.aten.clone.default(add_3)
add_4: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, 1)
add_5: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, 1)
add_6: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, add_4); clone = add_4 = None
add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
'subgraph_0', (add_7, add_6)); repeated_subgraph0_1 = add_7 = add_6 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
add_8: "f32[]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
copy_: "f32[10, 20]" = torch.ops.aten.copy_.default(arg1_1, add_3); arg1_1 = add_3 = copy_ = None
return (add_8,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(arg1_1); arg1_1 = None
add: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None
return (add,)
""",
)
def test_input_aliasing(self):
def inner_fn(x, y):
x0 = x.view(x.size())
return x0.view(x.size())
def inner_fn2(x, y):
x = x * 2
y = y * 2
return x.sum() + y.sum()
def fn(x, y):
o0 = inner_fn(x, y)
o1 = inner_fn(x, y)
o2 = inner_fn2(x, y)
o3 = inner_fn2(x, y)
return o0 + o1 + o2.sum() + o3.sum()
x = torch.rand(10, 10, requires_grad=False)
y = torch.rand(10, 20, requires_grad=False)
x_clone = x.clone()
y_clone = y.clone()
ref_result = fn(x, y)
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
torch.allclose(ref_result, result)
self.assertEqual(len(graphs), 1)
self.assertEqual(len(fw_graphs), 1)
self.assertExpectedInline(
graph_str(fw_graphs[0]),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
'subgraph_0', (arg1_1, arg0_1)); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
'subgraph_0', (arg1_1, arg0_1)); repeated_subgraph0_1 = arg1_1 = arg0_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
return (add_2,)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
mul: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
add: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None
return (add,)
""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()
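A minimal standalone sketch of the pattern these tests exercise (illustrative only; the helper and flag names come from the test file above): enable use_graph_deduplication and check that the captured Dynamo graph calls invoke_subgraph for the repeated region.

import torch
import torch._dynamo
from torch._dynamo.testing import AotEagerAndRecordGraphs

def inner_fn(x, y):
    x0 = x + 1
    y0 = y + 2
    return x0.sum() + y0.sum()

def fn(x, y):
    return inner_fn(x, y) * inner_fn(x, y)

backend = AotEagerAndRecordGraphs()
with torch._dynamo.config.patch("use_graph_deduplication", True):
    torch.compile(backend=backend)(fn)(torch.rand(10, 10), torch.rand(10, 10))

# The two inner_fn regions should now be calls to the same installed subgraph.
has_invoke = any(
    node.target is torch.ops.higher_order.invoke_subgraph
    for node in backend.graphs[0].graph.nodes
)
print(has_invoke)  # expected: True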

View File

@ -0,0 +1,284 @@
# Owner(s): ["module: dynamo"]
import contextlib
import torch
import torch.fx
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import extract_graph_and_tracker
from torch.utils._pytree import tree_map
def get_nodes_by_name(graph, names):
nodes = []
for node in graph.nodes:
if node.name in names:
nodes.append(node)
return nodes
unique_ind = 0
def track_same_nodes(names, graph, region_tracker):
global unique_ind
unique_ind += 1
# find nodes in graph with names and track them
# as if they were at the same code location
nodes = get_nodes_by_name(graph, names)
for node in nodes:
region_tracker.track_node("x", unique_ind, node)
class GraphRegionTrackerTests(TestCase):
def setUp(self):
self.exit_stack = contextlib.ExitStack()
self.exit_stack.enter_context(
torch._dynamo.config.patch("track_nodes_for_deduplication", True)
)
super().setUp()
def tearDown(self):
self.exit_stack.close()
super().tearDown()
def get_result(self, fn, *args, **kwargs):
graph, region_tracker = extract_graph_and_tracker(fn, *args, **kwargs)
region_groups = region_tracker.get_identical_regions(graph)
region_groups = tree_map(lambda n: n.name, region_groups)
return str(region_groups)
def test_get_regions_single_region_group(self):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def fn(x, y):
o0 = inner_fn(x, y)
o1 = torch.sin(y)
o2 = inner_fn(x, o1)
o3 = inner_fn(x, y)
o4 = o3 * o3
return o2 * o4
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.ones(10, 20),
),
"""[[['y0', 'x0', 'sum_2', 'sum_1', 'z'], \
['y0_1', 'x0_1', 'sum_4', 'sum_3', 'z_1'], ['y0_2', 'x0_2', 'sum_6', 'sum_5', 'z_2']]]""",
)
def test_get_regions_multiple_region_groups(self):
def inner_fn(x, y):
x1 = x + 1
y1 = y + 2
z = x1.sum() + y1.sum()
return z
def inner_fn2(a, b):
a += 2
b += 3
c = a * b.cos().sum()
return c
def fn(x, y):
x0 = torch.cos(x)
y0 = torch.sin(y)
o1 = inner_fn2(x0, y0)
o0 = inner_fn(x, y)
o1 = torch.sin(o0)
o2 = inner_fn(x, y0)
o2 = inner_fn2(x0, y0)
o3 = inner_fn(x, y)
return o1 * o2 + o3
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.ones(10, 20),
),
"""[[['y1', 'x1', 'sum_3', 'sum_2', 'z'], ['y1_1', 'x1_1', 'sum_5', 'sum_4', 'z_1'], \
['y1_2', 'x1_2', 'sum_8', 'sum_7', 'z_2']], [['b', 'cos_1', 'sum_1', 'a', 'c'], ['b_1', 'cos_2', 'sum_6', 'a_1', 'c_1']]]""",
)
def test_no_single_node_regions(self):
def inner_fn(x):
return x + 1
def fn(x):
o0 = inner_fn(x)
o1 = inner_fn(x)
o2 = inner_fn(x)
return o0 + o1 + o2
self.assertExpectedInline(self.get_result(fn, torch.ones(10, 10)), """[]""")
def test_mismatched_arg_shapes(self):
def inner_fn(x, y):
x1 = x + 1
y1 = y + 2
z = x1.sum() + y1.sum()
return z
def inner_fn2(a, b):
a += 2
b += 3
c = a * b.cos().sum()
return c
def fn(x, y):
x0 = torch.cos(x)
y0 = torch.sin(y)
o1 = inner_fn2(x0, y0)
o0 = inner_fn(x, o1)
o1 = torch.sin(o0)
o2 = inner_fn(x, y0)
o2 = inner_fn2(o2, y0)
o3 = inner_fn(x, y)
return o1 * o2 + o3
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.ones(10, 20),
),
"""[[['y1_1', 'sum_5'], ['y1_2', 'sum_8']], [['x1', 'sum_2', 'z'], ['x1_1', 'sum_4', 'z_1'], \
['x1_2', 'sum_7', 'z_2']], [['b', 'cos_1', 'sum_1'], ['b_1', 'cos_2', 'sum_6']]]""",
)
def test_mismatched_dtypes(self):
def inner_fn(x, y):
x1 = x * 1
y1 = y + 1
return x1 + y1.sum()
def fn(x, y):
x0 = torch.sin(x)
y0 = torch.cos(y)
o0 = inner_fn(x0, y0)
o2 = inner_fn(x0, y0)
o4 = inner_fn(x0, y0)
o5 = inner_fn(x0, y0)
o1 = inner_fn(x0.to(torch.bfloat16), y0.to(torch.bfloat16))
o3 = o1 + o2
return o3 * o0 + o4 + o5
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.ones(10, 20),
),
"""[[['y1', 'sum_1', 'x1', 'o0'], ['y1_1', 'sum_2', 'x1_1', 'o2'], \
['y1_2', 'sum_3', 'x1_2', 'o4'], ['y1_3', 'sum_4', 'x1_3', 'o5']]]""",
)
def test_nested_args(self):
def inner_fn(xs, ys):
out = torch._foreach_add(xs, ys)
return out[0] + out[1].sum()
def fn(x, y, z):
x0 = torch.sin(x)
y0 = torch.cos(y)
z0 = torch.sin(z)
o0 = inner_fn([x0, z0], [x0, y0])
o2 = inner_fn([x0, z0], [x0, y0])
o4 = inner_fn([x0, z0], [x0, y0])
o5 = inner_fn([x0, z0], [x0, y0])
o1 = inner_fn(
[x0.to(torch.bfloat16), z0.to(torch.bfloat16)],
[x0.to(torch.bfloat16), y0.to(torch.bfloat16)],
)
o3 = o1 + o2
return o3 * o0 + o4 + o5
self.assertExpectedInline(
self.get_result(
fn,
torch.rand(10, 10),
torch.rand(10, 20),
torch.ones(10, 20),
),
"""[[['getitem_1', '_foreach_add', 'sum_1', 'getitem', 'o0'], ['getitem_3', \
'_foreach_add_1', 'sum_2', 'getitem_2', 'o2'], ['getitem_5', '_foreach_add_2',\
'sum_3', 'getitem_4', 'o4'], ['getitem_7', '_foreach_add_3', 'sum_4', 'getitem_6', 'o5']]]""",
)
def test_mismatched_global_state(self):
def inner_fn(x, y):
x1 = x * 1
y1 = y + 1
return x1 + y1.sum()
def fn(x, y, c):
x0 = torch.sin(x)
y0 = torch.cos(y)
o4 = inner_fn(x0, y0)
o5 = inner_fn(x0, y0)
if isinstance(c, tuple):
c[0]()
o0 = inner_fn(x0, y0)
o2 = inner_fn(x0, y0)
c[1]()
else:
with c():
o0 = inner_fn(x0, y0)
o2 = inner_fn(x0, y0)
return o0 + o2 + o4 + o5
def create_toggle_fns(property):
old_value = getattr(torch.backends.cuda.matmul, property)
def toggle_property():
setattr(torch.backends.cuda.matmul, property, not old_value)
def reset_property():
setattr(torch.backends.cuda.matmul, property, old_value)
return toggle_property, reset_property
old_dtype = torch.get_default_dtype()
def set_default_dtype_bfloat16():
torch.set_default_dtype(torch.bfloat16)
def reset_default_dtype():
torch.set_default_dtype(old_dtype)
for ctx in [
lambda: torch.set_grad_enabled(False),
torch.autograd.grad_mode.inference_mode,
lambda: torch.autograd.graph.disable_saved_tensors_hooks(
"This is not supported"
),
# lambda: torch.set_num_threads(2), : Unsupported
(set_default_dtype_bfloat16, reset_default_dtype),
(
lambda: torch.use_deterministic_algorithms(True),
lambda: torch.use_deterministic_algorithms(False),
),
# (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
# lambda: torch.use_deterministic_algorithms(False)), : Unsupported
create_toggle_fns("allow_bf16_reduced_precision_reduction"),
create_toggle_fns("allow_fp16_reduced_precision_reduction"),
create_toggle_fns("allow_tf32"),
]:
self.assertExpectedInline(
self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
"""[[['y1_2', 'sum_3', 'x1_2', 'o0'], ['y1_3', 'sum_4', 'x1_3', 'o2']], \
[['y1', 'sum_1', 'x1', 'o4'], ['y1_1', 'sum_2', 'x1_1', 'o5']]]""",
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()
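A condensed, standalone version of the get_result helper above (using only names already shown in these tests), printing the region groups the tracker finds for two identical inner_fn calls.

import torch
import torch._dynamo
from torch._dynamo.testing import extract_graph_and_tracker
from torch.utils._pytree import tree_map

def inner_fn(x, y):
    x0 = x + 1
    y0 = y + 2
    return x0.sum() + y0.sum()

def fn(x, y):
    return inner_fn(x, y) + inner_fn(x, y)

with torch._dynamo.config.patch("track_nodes_for_deduplication", True):
    graph, tracker = extract_graph_and_tracker(
        fn, torch.rand(10, 10), torch.rand(10, 10)
    )

# One region group with two regions, one per inner_fn call; map nodes to their
# names for readability, as the tests above do.
region_groups = tree_map(lambda n: n.name, tracker.get_identical_regions(graph))
print(region_groups)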

View File

@ -10,7 +10,11 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.distributed as dist
from torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311
from torch._dynamo.testing import (
empty_line_normalizer,
extract_graph_and_tracker,
skipIfNotPy311,
)
from torch._dynamo.trace_rules import _as_posix_path
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_utils import (
@ -731,6 +735,29 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
self.assertGreater(len(records), 0)
self.assertLess(len(records), 3)
@make_logging_test(graph_region_expansion=True)
def test_graph_region_expansion(self, records):
with torch._dynamo.config.patch("track_nodes_for_deduplication", True):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def fn(x, y):
o0 = inner_fn(x, y)
o1 = torch.sin(o0)
o2 = inner_fn(x, o1)
o3 = inner_fn(x, y)
return o2 * o3 * o3
graph, tracker = extract_graph_and_tracker(
fn, torch.randn(10, 10), torch.randn(10, 10)
)
tracker.get_identical_regions(graph)
self.assertGreater(len(records), 0)
@skipIfTorchDynamo("too slow")
@make_logging_test(**torch._logging.DEFAULT_LOGGING)
def test_default_logging(self, records):
@ -864,6 +891,7 @@ exclusions = {
"cudagraph_static_inputs",
"benchmarking",
"loop_ordering",
"graph_region_expansion",
}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:

View File

@ -382,6 +382,14 @@ enable_cpp_guard_manager = True
# Inline inbuilt nn modules
inline_inbuilt_nn_modules = not is_fbcode()
# Whether to automatically find and replace identical graph
# regions with a call to invoke_subgraph
use_graph_deduplication = False
# Whether to track nodes for deduplication (testing only)
# This flag is ignored if use_graph_deduplication is True
track_nodes_for_deduplication = False
# Issues a warning in Python 3.13.0 for possibly slower guard evaluation and
# instructs user to attempt using 3.13.1+, where the CPython bug is fixed.
# Should be disabled in dynamo-wrapped tests since some tests check that no warnings are issued.
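A small sketch of how the two deduplication flags added above are meant to be toggled; the config.patch usage mirrors the new tests in this PR, and both flags default to False.

import torch
import torch._dynamo

# use_graph_deduplication enables the whole pass (see OutputGraph.dedup_pass
# later in this PR); track_nodes_for_deduplication only records candidate
# nodes so the GraphRegionTracker can be inspected in tests.
def fn(x):
    return x.sin() + x.cos()

with torch._dynamo.config.patch("use_graph_deduplication", True):
    torch.compile(fn)(torch.rand(8))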

View File

@ -0,0 +1,202 @@
import logging
import operator
from typing import Any, Dict, Iterable, List, Set, Tuple
import torch.fx
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
from torch.utils._pytree import tree_flatten
from .graph_region_tracker import Node, Region
log = logging.getLogger(__name__)
def apply_graph_deduplication(output_graph) -> Dict[Node, Node]: # type: ignore[no-untyped-def]
"""
This is the main entry point for applying the graph deduplication pass. \
Deduplication occurs in two phases:
1. Subgraph creation:
Subgraph creation works by taking one representative region from each region \
group and creating a subgraph from it, which will then be used to replace all regions \
in the group. This is implemented by first copying all nodes of the region to the new \
subgraph and then finding all inputs which are not within the region and creating placeholders \
for them. For the outputs, all regions in a region group need to be scanned to ensure the \
largest set of outputs is found, and then an output node is created which returns \
a tuple of all outputs.
2. Graph replacement:
To replace each region with the extracted subgraph, the node index in the region \
and argument index within the node's flattened args and kwargs are recorded once during \
subgraph creation. This allows us to determine which (external to the region) nodes and \
in which order these nodes are passed as inputs. For the outputs, getitem nodes are created \
for each output, and all nodes in the region with external outputs are replaced by the proper \
getitem node. Finally, all original nodes are erased (there should be no uses of these \
left in the graph).
The deduplication mutates the output_graph argument in place.
Returns a mapping of nodes to their subgraph output replacement node to remap outputs
when they are created in output_graph.
"""
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
output_graph.graph
)
# Used to track which nodes were replaced with subgraph outputs.
# Today, we have to register the new subgraph submodules before the
# graph outputs have been created, so we pass the replacement mapping
# back to the output graph to do the replacements at the site of output creation.
output_replacements: Dict[Node, Node] = {}
for region_group in duplicated_region_groups:
inds_with_external_users = _get_all_output_indices(region_group)
region = region_group[0]
(
subgraph,
node_ind_arg_inds,
) = _create_subgraph(region, inds_with_external_users)
sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
subgraph_name = output_graph.install_subgraph("subgraph", sub_gm)
with output_graph.graph.inserting_before():
get_subgraph_node = output_graph.graph.create_node(
"get_attr", subgraph_name, (), {}
)
for region in region_group:
_replace_region_with_subgraph(
output_graph.graph,
region,
get_subgraph_node,
node_ind_arg_inds.keys(),
inds_with_external_users,
sub_gm,
subgraph_name,
output_replacements,
)
return output_replacements
def _replace_region_with_subgraph(
graph: torch.fx.Graph,
region: Region,
get_subgraph_node: Node,
node_ind_arg_ind: Iterable[Tuple[int, int]],
inds_with_external_users: List[int],
sub_gm: torch.fx.GraphModule,
subgraph_name: str,
output_replacements: Dict[Node, Node],
) -> None:
sub_args = []
for node_ind, arg_ind in node_ind_arg_ind:
node = region[node_ind]
flattened_args_kwargs, _ = tree_flatten((node.args, node.kwargs))
sub_args.append(flattened_args_kwargs[arg_ind])
invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args))
fake_inputs = [node.meta["example_value"] for node in sub_args]
if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
log.debug(
"NYI: Failed to substitute region %s due to input alias or mutation",
region,
)
return
latest_region_node = region[-1]
with graph.inserting_after(latest_region_node):
invoke_subgraph_node = graph.create_node(
"call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
)
with graph.inserting_after(invoke_subgraph_node):
for ind, external_user_ind in enumerate(inds_with_external_users):
node = region[external_user_ind]
subgraph_output = graph.create_node(
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
)
output_replacements[node] = subgraph_output
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
# Erase in reverse topological order
for node in reversed(region):
graph.erase_node(node)
def _get_external_inputs(
region: Region,
) -> Dict[Node, Tuple[int, int]]:
external_node_to_indices = dict()
region_unique = set(region)
for node_ind, node in enumerate(region):
flattened_args_kwargs, _ = tree_flatten((node.args, node.kwargs))
for arg_ind, in_node in enumerate(flattened_args_kwargs):
if (
in_node not in region_unique
and in_node not in external_node_to_indices
and isinstance(in_node, Node)
):
external_node_to_indices[in_node] = (node_ind, arg_ind)
return external_node_to_indices
def _get_all_output_indices(regions: List[Region]) -> List[int]:
# Scan all regions to get the set of all possible output node indices in the region.
# Perhaps we can record this information during region creation for more efficiency?
inds_with_external_users: Set[int] = set()
for region in regions:
_get_inds_with_external_users(region, inds_with_external_users)
return sorted(inds_with_external_users)
def _get_inds_with_external_users(region: Region, inds_unique: Set[int]) -> None:
for ind, node in enumerate(region):
for user in node.users:
if user not in region:
if ind not in inds_unique:
inds_unique.add(ind)
def _copy_nodes_and_remap_inputs(
subgraph: torch.fx.Graph, region: Region
) -> Dict[Tuple[int, int], Any]:
external_inputs_to_indices = _get_external_inputs(region)
indices_to_placeholder_ind: Dict[Tuple[int, int], Any] = {}
region_to_subgraph_node = {}
for node in external_inputs_to_indices.keys():
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
region_to_subgraph_node[node] = placeholder
arg_indices = external_inputs_to_indices[node]
# Note: insertion order matches the order in which placeholders were created;
# this determines the calling convention of the subgraph
indices_to_placeholder_ind[arg_indices] = None
def map_arg(node: Node) -> Node:
if node in region_to_subgraph_node:
return region_to_subgraph_node[node]
else:
return node
for node in region:
subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
region_to_subgraph_node[node] = subgraph_node
return indices_to_placeholder_ind
def _create_subgraph_outputs(
subgraph: torch.fx.Graph, inds_to_output: List[int]
) -> None:
node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")]
out_tup = tuple(node_list[ind] for ind in inds_to_output)
subgraph.output(out_tup)
def _create_subgraph(
region: Region,
inds_with_external_users: List[int],
) -> Tuple[torch.fx.Graph, Dict[Tuple[int, int], Any]]:
subgraph: torch.fx.Graph = torch.fx.Graph()
node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
_create_subgraph_outputs(subgraph, inds_with_external_users)
return subgraph, node_ind_input_inds
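A toy, self-contained sketch of the node-copying idea behind _copy_nodes_and_remap_inputs and _create_subgraph above (it does not call those helpers; it traces a toy function instead of using a Dynamo region): external inputs become placeholders, and region nodes are copied with node_copy.

import torch
import torch.fx

def f(x, y):
    a = x + 1
    b = y + 2
    return a * b

gm = torch.fx.symbolic_trace(f)
# Treat all call_function nodes as the "region" to extract.
region = [n for n in gm.graph.nodes if n.op == "call_function"]

subgraph = torch.fx.Graph()
env = {}
for node in region:
    # Inputs produced outside the region become placeholders, in first-use order.
    for arg in node.all_input_nodes:
        if arg not in env and arg not in region:
            env[arg] = subgraph.placeholder(f"subgraph_input_{arg.name}")
    env[node] = subgraph.node_copy(node, lambda old: env[old])
subgraph.output((env[region[-1]],))
print(torch.fx.GraphModule({}, subgraph).code)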

View File

@ -0,0 +1,353 @@
import copyreg
import io
import logging
import math
import pickle
from collections import defaultdict, deque
from dataclasses import fields
from typing import (
Any,
Callable,
Deque,
Dict,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
)
import torch._logging
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import tree_flatten
T = TypeVar("T")
if TYPE_CHECKING:
from .symbolic_convert import InstructionTranslatorBase
Node = torch.fx.Node
Region = List[Node]
IdenticalNodes = List[Node]
GlobalStateKey = Tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
log = logging.getLogger(__name__)
graph_expansion_log = torch._logging.getArtifactLogger(
__name__, "graph_region_expansion"
)
def debug_log(msg: str, *args) -> None: # type: ignore[no-untyped-def]
graph_expansion_log.debug(msg, *args)
def _extract_tensor_metadata_for_node_hash(
x: torch.Tensor,
) -> Tuple[Callable[[T], T], Tuple[Any, ...]]:
from torch._inductor.codecache import _ident, extract_tensor_metadata_for_cache_key
out = []
metadata = extract_tensor_metadata_for_cache_key(x)
for field in fields(metadata):
out.append(getattr(metadata, field.name))
return (_ident, tuple(out))
class NodeHashException(Exception):
pass
class InputPickler(pickle.Pickler):
def __init__(self) -> None:
from torch._inductor.codecache import _ident
stream = io.BytesIO()
self._stream = stream
super().__init__(stream)
self.dispatch_table = copyreg.dispatch_table.copy()
self.dispatch_table.update(
{
FakeTensor: _extract_tensor_metadata_for_node_hash,
torch.SymInt: lambda x: (_ident, (str(x),)),
torch.SymBool: lambda x: (_ident, (str(x),)),
torch.SymFloat: lambda x: (_ident, (str(x),)),
}
)
self.fast = True
def dumps(self, obj: Any) -> bytes:
"""
Pickle an object and return a byte string.
"""
try:
self.dump(obj)
return self._stream.getvalue()
except (TypeError, AttributeError) as e:
raise NodeHashException from e
finally:
self._stream.seek(0)
self._stream.truncate(0)
def _extract_tensor_arg(arg: Any) -> Any:
if isinstance(arg, Node):
return arg.meta.get("example_value")
else:
return None
def _normalize_args(
node: Node,
) -> Tuple[Tuple[str, ...], Tuple[Optional[Any], ...]]:
flat_args, _ = tree_flatten(node.args)
sorted_kwargs = sorted(node.kwargs.items(), key=lambda x: x[0])
sorted_keys = tuple(sorted(node.kwargs.keys()))
flat_kwargs, _ = tree_flatten(sorted_kwargs)
all_args = flat_args + flat_kwargs
return (sorted_keys, tuple(_extract_tensor_arg(arg) for arg in all_args))
def get_global_state_key() -> GlobalStateKey:
return (
torch.is_grad_enabled(),
torch.is_inference_mode_enabled(),
torch.get_num_threads(),
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
torch.get_default_dtype(),
torch.are_deterministic_algorithms_enabled(),
torch._C._get_cublas_allow_tf32(),
torch.is_deterministic_algorithms_warn_only_enabled(),
torch._C._autograd._saved_tensors_hooks_is_enabled(), # type: ignore[attr-defined]
)
# This is a typical BFS, with the caveat that a node's children
# need to be explicitly added with the add_children() method.
# The flow is: yield a node and check whether it is valid for all regions;
# if not valid, discard it and continue to the next node.
# Note: this iterates backward through the graph by looking at the
# args/kwargs of a node.
class BackwardBfsArgIter:
def __init__(self, origin: Node) -> None:
self._cur: Optional[Node] = origin
self._queue: Deque[Optional[Node]] = deque()
@staticmethod
def create(origin: Node) -> "BackwardBfsArgIter":
it = BackwardBfsArgIter(origin)
it.add_children(origin)
return it
def next(self) -> Optional[Node]:
ret = self._cur
if not self._queue:
self._cur = None
else:
self._cur = self._queue.popleft()
return ret
def peek(self) -> Optional[Node]:
return self._cur
def add_children(self, node: Node) -> None:
arg: Any
flat_args, _ = tree_flatten(node.args)
for arg in flat_args:
if isinstance(arg, Node):
self._append(arg)
flat_kwargs, _ = tree_flatten(node.kwargs)
for kwarg in flat_kwargs:
if isinstance(kwarg, Node):
self._append(kwarg)
def _append(self, arg: Node) -> None:
if self._cur is None:
self._cur = arg
else:
self._queue.append(arg)
class GraphRegionTracker:
"""
GraphRegionTracker tracks each node added to the output graph and generates a key based on the source location,
instruction pointer, input shapes, and global state at the time the node is inserted into the graph. Nodes with
the same key are grouped together in a list of identical nodes (the value of node_to_duplicates).
hash_to_duplicates: Dict[str, IdenticalNodes] - A dictionary mapping the key to a list of identical nodes
node_to_duplicates: Dict[Node, IdenticalNodes] - A dictionary mapping a node to the list of identical nodes it belongs to
input_pickler: InputPickler - An instance of InputPickler used to generate a node hash
"""
def __init__(self) -> None:
self.hash_to_duplicates: Dict[str, IdenticalNodes] = defaultdict(list)
self.node_to_duplicates: Dict[Node, IdenticalNodes] = {}
self.input_pickler = InputPickler()
def _hash_node(
self, filename: str, lineno: int, instruction_pointer: Optional[int], node: Node
) -> str:
from torch._inductor.codecache import sha256_hash
key = (
get_global_state_key(),
filename,
lineno,
instruction_pointer,
_normalize_args(node),
)
return sha256_hash(self.input_pickler.dumps(key))
def _is_identical(self, n0: Node, n1: Node) -> bool:
return (
n0 in self.node_to_duplicates
and n1 in self.node_to_duplicates
and self.node_to_duplicates[n0] is self.node_to_duplicates[n1]
and n0 is not n1
)
def track_node(self, tx: "InstructionTranslatorBase", node: Node) -> None:
"""
The main entry point for tracking a node. This function will hash the node argument and group
nodes with the same hash together. It updates the hash_to_duplicates and node_to_duplicates dictionaries
to track the new node.
"""
try:
duplicates = self.hash_to_duplicates[
self._hash_node(
tx.f_code.co_filename, tx.lineno, tx.instruction_pointer, node
)
]
duplicates.append(node)
self.node_to_duplicates[node] = duplicates
except NodeHashException as e:
log.debug("Unable to hash node %s with exception %s", node, e)
def get_identical_regions(self, graph: torch.fx.Graph) -> List[List[Region]]:
"""
This function is responsible for extracting the largest regions of identical nodes from the given graph.
**Note**: This function assumes the nodes that have been tracked with track_node are in the provided graph argument.
The algorithm proceeds as follows:
The nodes tracked via track_node above are organized into region groups. The initial region groups look like this:
[[IdenticalNode1], [IdenticalNode2], [IdenticalNode3]] and each sublist is called a region. For each region group
(starting at the topologically latest region group), the inner regions are gradually expanded one node at a time from
the flattened args and kwargs of the node in each region, provided that, for all regions in the group, the nodes being
added are also identical (i.e., have the same key computed by track_node). This is checked by verifying that the two
nodes have the same identical node list in node_to_duplicates.
"""
topological_ranking = {node: i for i, node in enumerate(graph.nodes)}
region_groups_with_rank = []
# Create region groups; a region group is a group
# of regions that are all identical. In this initial state
# each region in the group is a single node, and we discard
# groups that are only a single region.
# We track the topological ranking so that we can start with groups later in the graph;
# starting there, we will necessarily create the largest regions first.
for group in self.hash_to_duplicates.values():
if len(group) > 1:
region_group = []
min_rank = math.inf
for node in group:
min_rank = min(min_rank, topological_ranking[node])
region_group.append([node])
region_groups_with_rank.append((region_group, min_rank))
region_groups_with_rank.sort(key=lambda rg: -rg[1])
region_groups = [rg for rg, _ in region_groups_with_rank]
# We start from regions later in the graph and expand them earlier
# as a result, we will create the largest regions first and they won't
# overlap.
seen_nodes: Set[Node] = set()
for region_group in region_groups:
fully_expand_region_group(region_group, seen_nodes, self._is_identical)
return [
region_group for region_group in region_groups if len(region_group[0]) > 1
]
def __str__(self) -> str:
return f"GraphRegionTracker(hash_to_duplicates={self.hash_to_duplicates}, node_to_duplicates={self.node_to_duplicates})"
def fully_expand_region_group(
regions: List[Region],
seen_nodes: Set[Node],
is_identical_fn: Callable[[Node, Node], bool],
) -> None:
debug_log("--------------------------------------------------")
debug_log("expanding new region group: %s", regions)
# All regions should start with 1 node
assert all(len(region) == 1 for region in regions)
region_iters = []
for region in regions:
(origin,) = region  # Only works for single-element regions
region_iters.append(BackwardBfsArgIter.create(origin))
nodes_to_add: List[Node] = []
# we already have the origin node in each region
for region_it in region_iters:
node = region_it.next()
assert node
region_it.add_children(node)
current_node = region_iters[0].next()
assert current_node is not None
# Loop incrementally adding new nodes to each region
# regions are only expanded if the node to add is valid
# for ALL regions
while current_node:
add_node = True
nodes_to_add.clear()
nodes_to_add.append(current_node)
nodes_to_add_set = set(nodes_to_add)
for region_it in region_iters[1:]:
node = region_it.next()
debug_log("--------------------")
debug_log("considering adding: %s, cur_node: %s", node, current_node)
debug_log("previously claimed nodes: %s", node in seen_nodes)
debug_log("%s", seen_nodes)
if node:
debug_log("is_identical: %s", is_identical_fn(node, current_node))
add_node &= (
node not in seen_nodes
and node not in nodes_to_add_set
and is_identical_fn(node, current_node)
)
nodes_to_add.append(node)
nodes_to_add_set.add(node)
else:
add_node = False
debug_log("--------------------")
if add_node:
for region, region_it, node in zip(regions, region_iters, nodes_to_add):
region.append(node)
debug_log("adding %s's children", node)
debug_log("%s %s", node.args, list(node.kwargs.items()))
region_it.add_children(node)
seen_nodes.add(node)
current_node = region_iters[0].next()
# Ensure regions are sorted in topological order
for region in regions:
region.reverse()
debug_log("end expand new region group: %s", regions)
debug_log("--------------------------------------------------")

View File

@ -72,6 +72,8 @@ from .exc import (
unimplemented,
unimplemented_with_warning,
)
from .graph_deduplication import apply_graph_deduplication
from .graph_region_tracker import GraphRegionTracker
from .guards import GuardBuilder, install_guard
from .mutation_guard import is_dynamic_nn_module
from .side_effects import AttributeMutationExisting, SideEffects
@ -297,6 +299,8 @@ class OutputGraph:
"co_firstlineno": f_code.co_firstlineno,
}
self.region_tracker = GraphRegionTracker()
# tracked_fakes says where any tensor that was wrapped to fake came
# from. It is similar to GraphArg, in that all GraphArgs will get
# will get added to TrackedFakes, but TrackedFakes also contains
@ -1015,6 +1019,8 @@ class OutputGraph:
for value in stack_values:
value.realize()
output_replacements = self.dedup_pass()
# Use nn.Module "proxies" in the constructed GraphModule so that
# the resulting GM does not hold additional strong references to the original modules.
# This prevents a strong ref cycle where Dynamo created code holds on to references
@ -1098,7 +1104,9 @@ class OutputGraph:
append_prefix_insts()
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
self.compile_and_call_fx_graph(
tx, list(reversed(stack_values)), root, output_replacements
)
+ [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
)
# restore all the live local vars
@ -1131,7 +1139,9 @@ class OutputGraph:
output = []
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
output.extend(
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
self.compile_and_call_fx_graph(
tx, pass2.graph_output_vars(), root, output_replacements
)
)
if len(pass2.graph_outputs) != 0:
@ -1292,7 +1302,7 @@ class OutputGraph:
tx.speculation_log.clear()
raise exc.CompileCollectiveRestartAnalysis
def compile_and_call_fx_graph(self, tx, rv, root):
def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs):
"""
Generate code from self.graph and return the Instruction()s to
call that generated code.
@ -1308,12 +1318,17 @@ class OutputGraph:
assert isinstance(rv, list)
assert isinstance(root, FakeRootModule)
output_node = self.create_node(
"output",
"output",
(self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
{},
)
for old_node, new_node in replaced_outputs.items():
old_node.replace_all_uses_with(new_node)
tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node)
if not config.do_not_emit_runtime_asserts:
insert_deferred_runtime_asserts(
@ -1490,6 +1505,29 @@ class OutputGraph:
return compiled_fn
def dedup_pass(self):
if torch._dynamo.config.use_graph_deduplication:
return apply_graph_deduplication(self)
else:
return dict()
def install_subgraph(self, name, sub_gm):
next_name = None
i = 0
while not next_name:
candidate = f"{name}_{i}"
if candidate in self.nn_modules:
i += 1
else:
next_name = candidate
sub_gm.__name__ = next_name
sub_gm.torchdynamo_force_dynamic = False
# This graph module is not present in the user space, so it can't be
# accessed by a source. Set source=None.
self.register_attr_or_module(sub_gm, next_name, source=None)
return next_name
def example_inputs(self) -> List[torch.Tensor]:
result = [arg.example for arg in self.graphargs]
return result
@ -2104,6 +2142,13 @@ class SubgraphTracer(fx.Tracer):
msgs = traceback.StackSummary.from_list(frame_summaries).format()
rv.node.stack_trace = "".join(msgs)
if (
torch._dynamo.config.use_graph_deduplication
or torch._dynamo.config.track_nodes_for_deduplication
):
self.output_graph.region_tracker.track_node(
self.output_graph.current_tx, rv.node
)
return rv
def create_node(

View File

@ -62,6 +62,23 @@ def remove_optimized_module_prefix(name: str) -> str:
return re.sub(r"^_orig_mod[.]", "", name)
def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-def]
from torch._dynamo.symbolic_convert import InstructionTranslator
gm = None
region_tracker = None
def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def]
nonlocal gm
nonlocal region_tracker
gm = _gm
region_tracker = InstructionTranslator.current_tx().output.region_tracker
return _gm
torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs)
return gm.graph, region_tracker # type: ignore[union-attr]
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> List[Any]:

View File

@ -673,24 +673,6 @@ def make_attr(tx: "InstructionTranslator", name):
return node
def add_subgraph(tx: "InstructionTranslator", name, gm):
next_name = None
i = 0
while not next_name:
candidate = f"{name}_{i}"
if candidate in tx.output.nn_modules:
i += 1
else:
next_name = candidate
gm.__name__ = next_name
gm.torchdynamo_force_dynamic = False
# This graph module is not present in the user space, so it can't be
# accessed by a source. Set source=None.
tx.output.register_attr_or_module(gm, next_name, source=None)
return next_name
class TorchHigherOrderOperatorVariable(VariableTracker):
def __init__(
self, value: HigherOrderOperator, source: Optional[Source] = None, **kwargs
@ -928,13 +910,11 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
"false_branch",
)
true_name = add_subgraph(
tx,
true_name = tx.output.install_subgraph(
"cond_true",
torch.fx.GraphModule(true_nn_modules, true_graph),
)
false_name = add_subgraph(
tx,
false_name = tx.output.install_subgraph(
"cond_false",
torch.fx.GraphModule(false_nn_modules, false_graph),
)
@ -1141,13 +1121,11 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):
body_nn_modules = dict(tx.output.nn_modules)
cond_name = add_subgraph(
tx,
cond_name = tx.output.install_subgraph(
"cond_fn",
torch.fx.GraphModule(cond_nn_modules, cond_graph),
)
body_name = add_subgraph(
tx,
body_name = tx.output.install_subgraph(
"body_fn",
torch.fx.GraphModule(body_nn_modules, body_graph),
)
@ -1257,7 +1235,9 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
combine_fn_name = add_subgraph(tx, "associative_scan_combine_fn", combine_gm)
combine_fn_name = tx.output.install_subgraph(
"associative_scan_combine_fn", combine_gm
)
p_args = (
make_attr(tx, combine_fn_name),
@ -1410,7 +1390,7 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
combine_fn_name = add_subgraph(tx, "scan_combine_fn", combine_gm)
combine_fn_name = tx.output.install_subgraph("scan_combine_fn", combine_gm)
p_args = (
make_attr(tx, combine_fn_name),
@ -1522,8 +1502,7 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):
body_nn_modules = dict(tx.output.nn_modules)
body_name = add_subgraph(
tx,
body_name = tx.output.install_subgraph(
"map_body",
torch.fx.GraphModule(body_nn_modules, body_graph),
)
@ -1619,8 +1598,7 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def install_subgraph_in_output_graph(
self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body"
):
return add_subgraph(
tx,
return tx.output.install_subgraph(
f"{attr_name}",
body_gmod,
)
@ -1756,8 +1734,7 @@ class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable
)
body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
body_name = add_subgraph(
tx,
body_name = tx.output.install_subgraph(
"wrap_body",
body_gmod,
)
@ -1837,8 +1814,7 @@ class WrapWithAutocastHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
body_name = add_subgraph(
tx,
body_name = tx.output.install_subgraph(
"wrap_body",
body_gmod,
)
@ -1909,8 +1885,7 @@ class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable):
)
body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
body_name = add_subgraph(
tx,
body_name = tx.output.install_subgraph(
"hints_wrapper_body",
body_gmod,
)
@ -2011,8 +1986,7 @@ class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable):
strict_mode_nn_modules = dict(tx.output.nn_modules)
strict_mode_name = add_subgraph(
tx,
strict_mode_name = tx.output.install_subgraph(
"strict_mode_body",
torch.fx.GraphModule(strict_mode_nn_modules, ret_graph),
)
@ -2260,8 +2234,7 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
set_subgraph_inputs="flatten_manual",
)
body_name = add_subgraph(
tx,
body_name = tx.output.install_subgraph(
fn_name,
torch.fx.GraphModule(tx.output.nn_modules, body_graph),
)
@ -2544,8 +2517,7 @@ class AutogradFunctionApplyVariable(VariableTracker):
# Store fwd_body
fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
fwd_name = add_subgraph(
tx,
fwd_name = tx.output.install_subgraph(
"fwd_body",
torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph),
)
@ -2610,8 +2582,7 @@ class AutogradFunctionApplyVariable(VariableTracker):
# Store bwd_body
bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
bwd_name = add_subgraph(
tx,
bwd_name = tx.output.install_subgraph(
"bwd_body",
torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph),
)

View File

@ -240,6 +240,7 @@ def set_logs(
compiled_autograd_verbose: bool = False,
cudagraph_static_inputs: bool = False,
benchmarking: bool = False,
graph_region_expansion: bool = False,
):
"""
Sets the log level for individual components and toggles individual log
@ -416,6 +417,9 @@ def set_logs(
cudagraph_static_inputs (:class:`bool`):
Whether to emit debug info for cudagraph static input detection. Default: ``False``
graph_region_expansion (:class:`bool`):
Whether to emit the detailed steps of the duplicate graph region tracker expansion algorithm. Default: ``False``
Example::
@ -514,6 +518,7 @@ def set_logs(
compiled_autograd_verbose=compiled_autograd_verbose,
cudagraph_static_inputs=cudagraph_static_inputs,
benchmarking=benchmarking,
graph_region_expansion=graph_region_expansion,
)
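A short usage sketch for the new artifact: torch._logging.set_logs is the programmatic entry point shown above, and the artifact name can typically also be passed via the TORCH_LOGS environment variable.

import torch
import torch._dynamo

def inner_fn(x):
    return x.sin() + 1

def fn(x):
    return inner_fn(x) * inner_fn(x)

torch._logging.set_logs(graph_region_expansion=True)
# The expansion steps are emitted at debug level while the deduplication
# pass searches for identical regions.
with torch._dynamo.config.patch("use_graph_deduplication", True):
    torch.compile(fn)(torch.rand(8))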

View File

@ -191,5 +191,10 @@ register_artifact(
"Detailed Inductor benchmarking information.",
off_by_default=True,
)
register_artifact(
"graph_region_expansion",
"Logs detailed steps of the duplicate graph region tracker expansion algorithm",
off_by_default=True,
)
register_artifact("custom_format_test_artifact", "Testing only", log_format="")