mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-13 21:59:56 +08:00
Compare commits
4 Commits
ciflow/tru
...
mlazos/h-c
| Author | SHA1 | Date | |
|---|---|---|---|
| 525d34447e | |||
| 7da79a64aa | |||
| 2fd48c7565 | |||
| 223d363f12 |
591
test/dynamo/test_graph_deduplication.py
Normal file
591
test/dynamo/test_graph_deduplication.py
Normal file
@ -0,0 +1,591 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
|
||||
|
||||
|
||||
def extract_graph(fn, *args, **kwargs):
|
||||
backend = AotEagerAndRecordGraphs()
|
||||
result = torch.compile(backend=backend)(fn)(*args, **kwargs)
|
||||
return result, backend.graphs, backend.fw_graphs
|
||||
|
||||
|
||||
def graph_str(gm):
|
||||
return normalize_gm(gm.print_readable(print_output=False))
|
||||
|
||||
|
||||
class GraphDededuplicationTests(TestCase):
|
||||
def run_and_return_graphs(self, fn, *args, **kwargs):
|
||||
with torch._dynamo.config.patch("use_graph_deduplication", True):
|
||||
return extract_graph(fn, *args, **kwargs)
|
||||
|
||||
def test_single_subgraph(self):
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
def fn(x, y):
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = torch.sin(y)
|
||||
o2 = inner_fn(x, o1)
|
||||
o3 = inner_fn(x, y)
|
||||
o4 = o3 * o3
|
||||
return o2 * o4
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=True)
|
||||
y = torch.rand(10, 20, requires_grad=True)
|
||||
x_clone = x.clone().requires_grad_(True)
|
||||
y_clone = y.clone().requires_grad_(True)
|
||||
|
||||
ref_result = fn(x, y)
|
||||
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
||||
|
||||
torch.allclose(ref_result, result)
|
||||
ref_result.sum().backward()
|
||||
result.sum().backward()
|
||||
|
||||
self.assertEqual(len(graphs), 1)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertExpectedInline(
|
||||
graph_str(graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"):
|
||||
subgraph_0 = self.subgraph_0
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
|
||||
(l_y_, l_x_)); invoke_subgraph = None
|
||||
|
||||
o1: "f32[10, 20]" = torch.sin(l_y_)
|
||||
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
|
||||
(o1, l_x_)); o1 = None
|
||||
|
||||
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
|
||||
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
|
||||
(l_y_, l_x_)); subgraph_0 = l_y_ = l_x_ = None
|
||||
|
||||
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
|
||||
|
||||
o4: "f32[]" = getitem_2 * getitem_2; getitem_2 = None
|
||||
|
||||
mul_1: "f32[]" = getitem_1 * o4; getitem_1 = o4 = None
|
||||
return (mul_1,)
|
||||
|
||||
class subgraph_0(torch.nn.Module):
|
||||
def forward(self, subgraph_input_l_y_, subgraph_input_l_x_):
|
||||
y0: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None
|
||||
|
||||
x0: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None
|
||||
|
||||
sum_2: "f32[]" = y0.sum(); y0 = None
|
||||
sum_1: "f32[]" = x0.sum(); x0 = None
|
||||
z: "f32[]" = sum_1 + sum_2; sum_1 = sum_2 = None
|
||||
return (z,)
|
||||
""",
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
graph_str(fw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
|
||||
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
|
||||
|
||||
repeated_subgraph0_1 = self.repeated_subgraph0
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
|
||||
'___forward_subgraph_0', (sin, primals_1)); repeated_subgraph0_1 = None
|
||||
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
repeated_subgraph0_2 = self.repeated_subgraph0
|
||||
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, \
|
||||
'___forward_subgraph_0', (primals_2, primals_1)); repeated_subgraph0_2 = None
|
||||
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
|
||||
|
||||
mul: "f32[]" = torch.ops.aten.mul.Tensor(getitem_2, getitem_2)
|
||||
|
||||
mul_1: "f32[]" = torch.ops.aten.mul.Tensor(getitem_1, mul); mul = None
|
||||
return (mul_1, primals_1, primals_2, sin, getitem_1, getitem_2)
|
||||
|
||||
class repeated_subgraph0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
|
||||
add: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg0_1, 2); arg0_1 = None
|
||||
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg1_1, 1); arg1_1 = None
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
|
||||
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
|
||||
add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None
|
||||
return (add_2,)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_single_subgraph2(self):
|
||||
def fn(x):
|
||||
x0 = x + 2
|
||||
o = inner_fn(x0)
|
||||
o = torch.cos(o)
|
||||
o = inner_fn(o)
|
||||
return torch.sin(o)
|
||||
|
||||
def inner_fn(x):
|
||||
o = x * 7
|
||||
o += 1
|
||||
o += 2
|
||||
return o
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=True)
|
||||
x_clone = x.clone().requires_grad_(True)
|
||||
|
||||
ref_result = fn(x)
|
||||
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone)
|
||||
|
||||
torch.allclose(ref_result, result)
|
||||
ref_result.sum().backward()
|
||||
result.sum().backward()
|
||||
self.assertEqual(len(graphs), 1)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertExpectedInline(
|
||||
graph_str(graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[10, 10]"):
|
||||
subgraph_0 = self.subgraph_0
|
||||
l_x_ = L_x_
|
||||
|
||||
x0: "f32[10, 10]" = l_x_ + 2; l_x_ = None
|
||||
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (x0,)); x0 = None
|
||||
|
||||
getitem: "f32[10, 10]" = invoke_subgraph[0]; invoke_subgraph = None
|
||||
|
||||
o_3: "f32[10, 10]" = torch.cos(getitem); getitem = None
|
||||
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (o_3,)); subgraph_0 = o_3 = None
|
||||
|
||||
getitem_1: "f32[10, 10]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
|
||||
sin: "f32[10, 10]" = torch.sin(getitem_1); getitem_1 = None
|
||||
return (sin,)
|
||||
|
||||
class subgraph_0(torch.nn.Module):
|
||||
def forward(self, subgraph_input_x0):
|
||||
o: "f32[10, 10]" = subgraph_input_x0 * 7; subgraph_input_x0 = None
|
||||
|
||||
o += 1; o_1: "f32[10, 10]" = o; o = None
|
||||
|
||||
o_1 += 2; o_2: "f32[10, 10]" = o_1; o_1 = None
|
||||
return (o_2,)
|
||||
""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
graph_str(fw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[10, 10]"):
|
||||
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
|
||||
|
||||
repeated_subgraph0 = self.repeated_subgraph0
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
|
||||
'___forward_subgraph_0', (add,)); repeated_subgraph0 = None
|
||||
getitem: "f32[10, 10]" = invoke_subgraph[0]; invoke_subgraph = None
|
||||
|
||||
cos: "f32[10, 10]" = torch.ops.aten.cos.default(getitem)
|
||||
|
||||
repeated_subgraph0_1 = self.repeated_subgraph0
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
|
||||
'___forward_subgraph_0', (cos,)); repeated_subgraph0_1 = None
|
||||
getitem_1: "f32[10, 10]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
|
||||
sin: "f32[10, 10]" = torch.ops.aten.sin.default(getitem_1)
|
||||
cos_1: "f32[10, 10]" = torch.ops.aten.cos.default(getitem_1); getitem_1 = None
|
||||
|
||||
sin_1: "f32[10, 10]" = torch.ops.aten.sin.default(getitem); getitem = None
|
||||
neg: "f32[10, 10]" = torch.ops.aten.neg.default(sin_1); sin_1 = None
|
||||
return (sin, add, cos, cos_1, neg)
|
||||
|
||||
class repeated_subgraph0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[10, 10]"):
|
||||
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 7); arg0_1 = None
|
||||
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
|
||||
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, 2); add = None
|
||||
return (add_1,)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_multiple_subgraphs(self):
|
||||
def inner_fn(x, y):
|
||||
x1 = x + 1
|
||||
y1 = y + 2
|
||||
z = x1.sum() + y1.sum()
|
||||
return z
|
||||
|
||||
def inner_fn2(a, b):
|
||||
a0 = a + 2
|
||||
b0 = b + 3
|
||||
c = a0 * b0.cos().sum()
|
||||
return c
|
||||
|
||||
def fn(x, y):
|
||||
x0 = torch.cos(x)
|
||||
y0 = torch.sin(y)
|
||||
o1 = inner_fn2(x0, y0)
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = torch.sin(o0)
|
||||
o2 = inner_fn(x, y0)
|
||||
o3 = inner_fn2(x0, y0)
|
||||
o4 = inner_fn(x, y)
|
||||
return o1 * o2 * o3 + o4
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=True)
|
||||
y = torch.rand(10, 20, requires_grad=True)
|
||||
x_clone = x.clone().requires_grad_(True)
|
||||
y_clone = y.clone().requires_grad_(True)
|
||||
|
||||
ref_result = fn(x, y)
|
||||
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
||||
|
||||
torch.allclose(ref_result, result)
|
||||
ref_result.sum().backward()
|
||||
result.sum().backward()
|
||||
self.assertEqual(len(graphs), 1)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
|
||||
self.assertExpectedInline(
|
||||
graph_str(graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"):
|
||||
subgraph_1 = self.subgraph_1
|
||||
subgraph_0 = self.subgraph_0
|
||||
l_x_ = L_x_
|
||||
l_y_ = L_y_
|
||||
|
||||
x0: "f32[10, 10]" = torch.cos(l_x_)
|
||||
|
||||
y0: "f32[10, 20]" = torch.sin(l_y_)
|
||||
|
||||
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, \
|
||||
'subgraph_1', (y0, x0)); invoke_subgraph_3 = None
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, \
|
||||
'subgraph_0', (l_y_, l_x_))
|
||||
|
||||
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
|
||||
|
||||
o1: "f32[]" = torch.sin(getitem); getitem = None
|
||||
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, \
|
||||
'subgraph_0', (y0, l_x_))
|
||||
|
||||
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
|
||||
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, \
|
||||
'subgraph_1', (y0, x0)); subgraph_1 = y0 = x0 = None
|
||||
|
||||
getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
||||
|
||||
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \
|
||||
(l_y_, l_x_)); subgraph_0 = l_y_ = l_x_ = None
|
||||
|
||||
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
|
||||
|
||||
mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None
|
||||
mul_3: "f32[10, 10]" = mul_2 * getitem_4; mul_2 = getitem_4 = None
|
||||
add_13: "f32[10, 10]" = mul_3 + getitem_2; mul_3 = getitem_2 = None
|
||||
return (add_13,)
|
||||
|
||||
class subgraph_1(torch.nn.Module):
|
||||
def forward(self, subgraph_input_y0, subgraph_input_x0):
|
||||
b0: "f32[10, 20]" = subgraph_input_y0 + 3; subgraph_input_y0 = None
|
||||
|
||||
cos_1: "f32[10, 20]" = b0.cos(); b0 = None
|
||||
sum_1: "f32[]" = cos_1.sum(); cos_1 = None
|
||||
|
||||
a0: "f32[10, 10]" = subgraph_input_x0 + 2; subgraph_input_x0 = None
|
||||
|
||||
c: "f32[10, 10]" = a0 * sum_1; a0 = sum_1 = None
|
||||
return (c,)
|
||||
|
||||
class subgraph_0(torch.nn.Module):
|
||||
def forward(self, subgraph_input_l_y_, subgraph_input_l_x_):
|
||||
y1: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None
|
||||
|
||||
x1: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None
|
||||
|
||||
sum_3: "f32[]" = y1.sum(); y1 = None
|
||||
sum_2: "f32[]" = x1.sum(); x1 = None
|
||||
z: "f32[]" = sum_2 + sum_3; sum_2 = sum_3 = None
|
||||
return (z,)
|
||||
""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
graph_str(fw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
|
||||
cos: "f32[10, 10]" = torch.ops.aten.cos.default(primals_1)
|
||||
|
||||
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
|
||||
|
||||
repeated_subgraph1 = self.repeated_subgraph1
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1, \
|
||||
'___forward_subgraph_0', (primals_2, primals_1)); repeated_subgraph1 = None
|
||||
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
|
||||
sin_1: "f32[]" = torch.ops.aten.sin.default(getitem_1)
|
||||
|
||||
repeated_subgraph1_1 = self.repeated_subgraph1
|
||||
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1_1, \
|
||||
'___forward_subgraph_0', (sin, primals_1)); repeated_subgraph1_1 = None
|
||||
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
|
||||
repeated_subgraph0_1 = self.repeated_subgraph0
|
||||
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
|
||||
'___forward_subgraph_1', (sin, cos)); repeated_subgraph0_1 = None
|
||||
getitem_3: "f32[10, 10]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
|
||||
repeated_subgraph1_2 = self.repeated_subgraph1
|
||||
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1_2, \
|
||||
'___forward_subgraph_0', (primals_2, primals_1)); repeated_subgraph1_2 = None
|
||||
getitem_4: "f32[]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
|
||||
|
||||
mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_2); sin_1 = None
|
||||
mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_3); mul = None
|
||||
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_4); mul_1 = getitem_4 = None
|
||||
return (add, primals_1, primals_2, cos, sin, getitem_1, getitem_2, getitem_3)
|
||||
|
||||
class repeated_subgraph1(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
|
||||
add: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg0_1, 2); arg0_1 = None
|
||||
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg1_1, 1); arg1_1 = None
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
|
||||
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
|
||||
add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None
|
||||
return (add_2,)
|
||||
|
||||
class repeated_subgraph0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
|
||||
add: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg0_1, 3); arg0_1 = None
|
||||
cos: "f32[10, 20]" = torch.ops.aten.cos.default(add); add = None
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(cos); cos = None
|
||||
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg1_1, 2); arg1_1 = None
|
||||
mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add_1, sum_1); add_1 = sum_1 = None
|
||||
return (mul,)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_dependent_subgraphs(self):
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
def fn(x, y):
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = inner_fn(x, o0)
|
||||
return o1
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=True)
|
||||
y = torch.rand(10, 20, requires_grad=True)
|
||||
x_clone = x.clone().requires_grad_(True)
|
||||
y_clone = y.clone().requires_grad_(True)
|
||||
|
||||
ref_result = fn(x, y)
|
||||
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
||||
|
||||
torch.allclose(ref_result, result)
|
||||
ref_result.sum().backward()
|
||||
result.sum().backward()
|
||||
self.assertEqual(len(graphs), 1)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertExpectedInline(
|
||||
graph_str(fw_graphs[0]),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
|
||||
add: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_2, 2); primals_2 = None
|
||||
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
|
||||
|
||||
repeated_subgraph0 = self.repeated_subgraph0
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
|
||||
'___forward_subgraph_0', (primals_1, sum_1)); repeated_subgraph0 = None
|
||||
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
|
||||
|
||||
add_1: "f32[]" = torch.ops.aten.add.Tensor(getitem, 2); getitem = None
|
||||
|
||||
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
|
||||
|
||||
repeated_subgraph0_1 = self.repeated_subgraph0
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
|
||||
'___forward_subgraph_0', (primals_1, sum_2)); repeated_subgraph0_1 = None
|
||||
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
return (getitem_1, primals_1, sum_1, sum_2)
|
||||
|
||||
class repeated_subgraph0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[]"):
|
||||
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
|
||||
add_1: "f32[]" = torch.ops.aten.add.Tensor(sum_1, arg1_1); sum_1 = arg1_1 = None
|
||||
return (add_1,)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_input_mutation(self):
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
def inner_fn2(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 1
|
||||
x.add_(x0)
|
||||
y.add_(y0)
|
||||
return x.sum() + y.sum()
|
||||
|
||||
def fn(x, y):
|
||||
x0 = torch.sin(x)
|
||||
y0 = torch.cos(y)
|
||||
# o0 = inner_fn(x0, y0)
|
||||
# o1 = inner_fn(x0, o0)
|
||||
o2 = inner_fn2(x0, y)
|
||||
o3 = inner_fn2(x0.clone(), y.clone())
|
||||
return o2 + o3
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=False)
|
||||
y = torch.rand(10, 20, requires_grad=False)
|
||||
x_clone = x.clone()
|
||||
y_clone = y.clone()
|
||||
|
||||
ref_result = fn(x, y)
|
||||
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
||||
|
||||
torch.allclose(ref_result, result)
|
||||
self.assertEqual(len(graphs), 1)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertExpectedInline(
|
||||
graph_str(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
|
||||
sin: "f32[10, 10]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
|
||||
|
||||
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, 1)
|
||||
|
||||
add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, 1)
|
||||
|
||||
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, add); sin = add = None
|
||||
|
||||
add_3: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, add_1); add_1 = None
|
||||
|
||||
repeated_subgraph0 = self.repeated_subgraph0
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
|
||||
'subgraph_0', (add_3, add_2)); repeated_subgraph0 = None
|
||||
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
|
||||
|
||||
clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2); add_2 = None
|
||||
clone_1: "f32[10, 20]" = torch.ops.aten.clone.default(add_3)
|
||||
|
||||
add_4: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, 1)
|
||||
|
||||
add_5: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, 1)
|
||||
|
||||
add_6: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, add_4); clone = add_4 = None
|
||||
|
||||
add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None
|
||||
|
||||
repeated_subgraph0_1 = self.repeated_subgraph0
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
|
||||
'subgraph_0', (add_7, add_6)); repeated_subgraph0_1 = add_7 = add_6 = None
|
||||
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
|
||||
add_8: "f32[]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
|
||||
|
||||
copy_: "f32[10, 20]" = torch.ops.aten.copy_.default(arg1_1, add_3); arg1_1 = add_3 = copy_ = None
|
||||
return (add_8,)
|
||||
|
||||
class repeated_subgraph0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
|
||||
sum_2: "f32[]" = torch.ops.aten.sum.default(arg1_1); arg1_1 = None
|
||||
add: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None
|
||||
return (add,)
|
||||
""",
|
||||
)
|
||||
|
||||
def test_input_aliasing(self):
|
||||
def inner_fn(x, y):
|
||||
x0 = x.view(x.size())
|
||||
return x0.view(x.size())
|
||||
|
||||
def inner_fn2(x, y):
|
||||
x = x * 2
|
||||
y = y * 2
|
||||
return x.sum() + y.sum()
|
||||
|
||||
def fn(x, y):
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = inner_fn(x, y)
|
||||
o2 = inner_fn2(x, y)
|
||||
o3 = inner_fn2(x, y)
|
||||
return o0 + o1 + o2.sum() + o3.sum()
|
||||
|
||||
x = torch.rand(10, 10, requires_grad=False)
|
||||
y = torch.rand(10, 20, requires_grad=False)
|
||||
x_clone = x.clone()
|
||||
y_clone = y.clone()
|
||||
|
||||
ref_result = fn(x, y)
|
||||
result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
|
||||
|
||||
torch.allclose(ref_result, result)
|
||||
self.assertEqual(len(graphs), 1)
|
||||
self.assertEqual(len(fw_graphs), 1)
|
||||
self.assertExpectedInline(
|
||||
graph_str(fw_graphs[0]),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
|
||||
view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
|
||||
|
||||
view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
|
||||
|
||||
view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
|
||||
|
||||
view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
|
||||
|
||||
repeated_subgraph0 = self.repeated_subgraph0
|
||||
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \
|
||||
'subgraph_0', (arg1_1, arg0_1)); repeated_subgraph0 = None
|
||||
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
|
||||
repeated_subgraph0_1 = self.repeated_subgraph0
|
||||
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \
|
||||
'subgraph_0', (arg1_1, arg0_1)); repeated_subgraph0_1 = arg1_1 = arg0_1 = None
|
||||
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
|
||||
|
||||
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
|
||||
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
|
||||
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
|
||||
add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
|
||||
return (add_2,)
|
||||
|
||||
class repeated_subgraph0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"):
|
||||
mul: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
|
||||
mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
|
||||
sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
|
||||
add: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None
|
||||
return (add,)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
284
test/dynamo/test_graph_region_tracker.py
Normal file
284
test/dynamo/test_graph_region_tracker.py
Normal file
@ -0,0 +1,284 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._dynamo.testing import extract_graph_and_tracker
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
def get_nodes_by_name(graph, names):
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
if node.name in names:
|
||||
nodes.append(node)
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
unique_ind = 0
|
||||
|
||||
|
||||
def track_same_nodes(names, graph, region_tracker):
|
||||
global unique_ind
|
||||
unique_ind += 1
|
||||
# find nodes in graph with names and track them
|
||||
# as if they were at the same code location
|
||||
nodes = get_nodes_by_name(graph, names)
|
||||
for node in nodes:
|
||||
region_tracker.track_node("x", unique_ind, node)
|
||||
|
||||
|
||||
class GraphRegionTrackerTests(TestCase):
|
||||
def setUp(self):
|
||||
self.exit_stack = contextlib.ExitStack()
|
||||
self.exit_stack.enter_context(
|
||||
torch._dynamo.config.patch("track_nodes_for_deduplication", True)
|
||||
)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
self.exit_stack.close()
|
||||
super().tearDown()
|
||||
|
||||
def get_result(self, fn, *args, **kwargs):
|
||||
graph, region_tracker = extract_graph_and_tracker(fn, *args, **kwargs)
|
||||
region_groups = region_tracker.get_identical_regions(graph)
|
||||
region_groups = tree_map(lambda n: n.name, region_groups)
|
||||
return str(region_groups)
|
||||
|
||||
def test_get_regions_single_region_group(self):
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
def fn(x, y):
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = torch.sin(y)
|
||||
o2 = inner_fn(x, o1)
|
||||
o3 = inner_fn(x, y)
|
||||
o4 = o3 * o3
|
||||
return o2 * o4
|
||||
|
||||
self.assertExpectedInline(
|
||||
self.get_result(
|
||||
fn,
|
||||
torch.rand(10, 10),
|
||||
torch.ones(10, 20),
|
||||
),
|
||||
"""[[['y0', 'x0', 'sum_2', 'sum_1', 'z'], \
|
||||
['y0_1', 'x0_1', 'sum_4', 'sum_3', 'z_1'], ['y0_2', 'x0_2', 'sum_6', 'sum_5', 'z_2']]]""",
|
||||
)
|
||||
|
||||
def test_get_regions_multiple_region_groups(self):
|
||||
def inner_fn(x, y):
|
||||
x1 = x + 1
|
||||
y1 = y + 2
|
||||
z = x1.sum() + y1.sum()
|
||||
return z
|
||||
|
||||
def inner_fn2(a, b):
|
||||
a += 2
|
||||
b += 3
|
||||
c = a * b.cos().sum()
|
||||
return c
|
||||
|
||||
def fn(x, y):
|
||||
x0 = torch.cos(x)
|
||||
y0 = torch.sin(y)
|
||||
o1 = inner_fn2(x0, y0)
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = torch.sin(o0)
|
||||
o2 = inner_fn(x, y0)
|
||||
o2 = inner_fn2(x0, y0)
|
||||
o3 = inner_fn(x, y)
|
||||
return o1 * o2 + o3
|
||||
|
||||
self.assertExpectedInline(
|
||||
self.get_result(
|
||||
fn,
|
||||
torch.rand(10, 10),
|
||||
torch.ones(10, 20),
|
||||
),
|
||||
"""[[['y1', 'x1', 'sum_3', 'sum_2', 'z'], ['y1_1', 'x1_1', 'sum_5', 'sum_4', 'z_1'], \
|
||||
['y1_2', 'x1_2', 'sum_8', 'sum_7', 'z_2']], [['b', 'cos_1', 'sum_1', 'a', 'c'], ['b_1', 'cos_2', 'sum_6', 'a_1', 'c_1']]]""",
|
||||
)
|
||||
|
||||
def test_no_single_node_regions(self):
|
||||
def inner_fn(x):
|
||||
return x + 1
|
||||
|
||||
def fn(x):
|
||||
o0 = inner_fn(x)
|
||||
o1 = inner_fn(x)
|
||||
o2 = inner_fn(x)
|
||||
return o0 + o1 + o2
|
||||
|
||||
self.assertExpectedInline(self.get_result(fn, torch.ones(10, 10)), """[]""")
|
||||
|
||||
def test_mismatched_arg_shapes(self):
|
||||
def inner_fn(x, y):
|
||||
x1 = x + 1
|
||||
y1 = y + 2
|
||||
z = x1.sum() + y1.sum()
|
||||
return z
|
||||
|
||||
def inner_fn2(a, b):
|
||||
a += 2
|
||||
b += 3
|
||||
c = a * b.cos().sum()
|
||||
return c
|
||||
|
||||
def fn(x, y):
|
||||
x0 = torch.cos(x)
|
||||
y0 = torch.sin(y)
|
||||
o1 = inner_fn2(x0, y0)
|
||||
o0 = inner_fn(x, o1)
|
||||
o1 = torch.sin(o0)
|
||||
o2 = inner_fn(x, y0)
|
||||
o2 = inner_fn2(o2, y0)
|
||||
o3 = inner_fn(x, y)
|
||||
return o1 * o2 + o3
|
||||
|
||||
self.assertExpectedInline(
|
||||
self.get_result(
|
||||
fn,
|
||||
torch.rand(10, 10),
|
||||
torch.ones(10, 20),
|
||||
),
|
||||
"""[[['y1_1', 'sum_5'], ['y1_2', 'sum_8']], [['x1', 'sum_2', 'z'], ['x1_1', 'sum_4', 'z_1'], \
|
||||
['x1_2', 'sum_7', 'z_2']], [['b', 'cos_1', 'sum_1'], ['b_1', 'cos_2', 'sum_6']]]""",
|
||||
)
|
||||
|
||||
def test_mismatched_dtypes(self):
|
||||
def inner_fn(x, y):
|
||||
x1 = x * 1
|
||||
y1 = y + 1
|
||||
return x1 + y1.sum()
|
||||
|
||||
def fn(x, y):
|
||||
x0 = torch.sin(x)
|
||||
y0 = torch.cos(y)
|
||||
o0 = inner_fn(x0, y0)
|
||||
o2 = inner_fn(x0, y0)
|
||||
o4 = inner_fn(x0, y0)
|
||||
o5 = inner_fn(x0, y0)
|
||||
o1 = inner_fn(x0.to(torch.bfloat16), y0.to(torch.bfloat16))
|
||||
o3 = o1 + o2
|
||||
return o3 * o0 + o4 + o5
|
||||
|
||||
self.assertExpectedInline(
|
||||
self.get_result(
|
||||
fn,
|
||||
torch.rand(10, 10),
|
||||
torch.ones(10, 20),
|
||||
),
|
||||
"""[[['y1', 'sum_1', 'x1', 'o0'], ['y1_1', 'sum_2', 'x1_1', 'o2'], \
|
||||
['y1_2', 'sum_3', 'x1_2', 'o4'], ['y1_3', 'sum_4', 'x1_3', 'o5']]]""",
|
||||
)
|
||||
|
||||
def test_nested_args(self):
|
||||
def inner_fn(xs, ys):
|
||||
out = torch._foreach_add(xs, ys)
|
||||
return out[0] + out[1].sum()
|
||||
|
||||
def fn(x, y, z):
|
||||
x0 = torch.sin(x)
|
||||
y0 = torch.cos(y)
|
||||
z0 = torch.sin(z)
|
||||
o0 = inner_fn([x0, z0], [x0, y0])
|
||||
o2 = inner_fn([x0, z0], [x0, y0])
|
||||
o4 = inner_fn([x0, z0], [x0, y0])
|
||||
o5 = inner_fn([x0, z0], [x0, y0])
|
||||
o1 = inner_fn(
|
||||
[x0.to(torch.bfloat16), z0.to(torch.bfloat16)],
|
||||
[x0.to(torch.bfloat16), y0.to(torch.bfloat16)],
|
||||
)
|
||||
o3 = o1 + o2
|
||||
return o3 * o0 + o4 + o5
|
||||
|
||||
self.assertExpectedInline(
|
||||
self.get_result(
|
||||
fn,
|
||||
torch.rand(10, 10),
|
||||
torch.rand(10, 20),
|
||||
torch.ones(10, 20),
|
||||
),
|
||||
"""[[['getitem_1', '_foreach_add', 'sum_1', 'getitem', 'o0'], ['getitem_3', \
|
||||
'_foreach_add_1', 'sum_2', 'getitem_2', 'o2'], ['getitem_5', '_foreach_add_2',\
|
||||
'sum_3', 'getitem_4', 'o4'], ['getitem_7', '_foreach_add_3', 'sum_4', 'getitem_6', 'o5']]]""",
|
||||
)
|
||||
|
||||
def test_mismatched_global_state(self):
|
||||
def inner_fn(x, y):
|
||||
x1 = x * 1
|
||||
y1 = y + 1
|
||||
return x1 + y1.sum()
|
||||
|
||||
def fn(x, y, c):
|
||||
x0 = torch.sin(x)
|
||||
y0 = torch.cos(y)
|
||||
o4 = inner_fn(x0, y0)
|
||||
o5 = inner_fn(x0, y0)
|
||||
if isinstance(c, tuple):
|
||||
c[0]()
|
||||
o0 = inner_fn(x0, y0)
|
||||
o2 = inner_fn(x0, y0)
|
||||
c[1]()
|
||||
else:
|
||||
with c():
|
||||
o0 = inner_fn(x0, y0)
|
||||
o2 = inner_fn(x0, y0)
|
||||
return o0 + o2 + o4 + o5
|
||||
|
||||
def create_toggle_fns(property):
|
||||
old_value = getattr(torch.backends.cuda.matmul, property)
|
||||
|
||||
def toggle_property():
|
||||
setattr(torch.backends.cuda.matmul, property, not old_value)
|
||||
|
||||
def reset_property():
|
||||
setattr(torch.backends.cuda.matmul, property, old_value)
|
||||
|
||||
return toggle_property, reset_property
|
||||
|
||||
old_dtype = torch.get_default_dtype()
|
||||
|
||||
def set_default_dtype_bfloat16():
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
|
||||
def reset_default_dtype():
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
for ctx in [
|
||||
lambda: torch.set_grad_enabled(False),
|
||||
torch.autograd.grad_mode.inference_mode,
|
||||
lambda: torch.autograd.graph.disable_saved_tensors_hooks(
|
||||
"This is not supported"
|
||||
),
|
||||
# lambda: torch.set_num_threads(2), : Unsupported
|
||||
(set_default_dtype_bfloat16, reset_default_dtype),
|
||||
(
|
||||
lambda: torch.use_deterministic_algorithms(True),
|
||||
lambda: torch.use_deterministic_algorithms(False),
|
||||
),
|
||||
# (lambda: torch.use_deterministic_algorithms(True, warn_only=True),
|
||||
# lambda: torch.use_deterministic_algorithms(False)), : Unsupported
|
||||
create_toggle_fns("allow_bf16_reduced_precision_reduction"),
|
||||
create_toggle_fns("allow_fp16_reduced_precision_reduction"),
|
||||
create_toggle_fns("allow_tf32"),
|
||||
]:
|
||||
self.assertExpectedInline(
|
||||
self.get_result(fn, torch.rand(10, 10), torch.ones(10, 20), ctx),
|
||||
"""[[['y1_2', 'sum_3', 'x1_2', 'o0'], ['y1_3', 'sum_4', 'x1_3', 'o2']], \
|
||||
[['y1', 'sum_1', 'x1', 'o4'], ['y1_1', 'sum_2', 'x1_1', 'o5']]]""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
@ -10,7 +10,11 @@ import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
import torch.distributed as dist
|
||||
from torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311
|
||||
from torch._dynamo.testing import (
|
||||
empty_line_normalizer,
|
||||
extract_graph_and_tracker,
|
||||
skipIfNotPy311,
|
||||
)
|
||||
from torch._dynamo.trace_rules import _as_posix_path
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -731,6 +735,29 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
|
||||
self.assertGreater(len(records), 0)
|
||||
self.assertLess(len(records), 3)
|
||||
|
||||
@make_logging_test(graph_region_expansion=True)
|
||||
def test_graph_region_expansion(self, records):
|
||||
with torch._dynamo.config.patch("track_nodes_for_deduplication", True):
|
||||
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
def fn(x, y):
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = torch.sin(o0)
|
||||
o2 = inner_fn(x, o1)
|
||||
o3 = inner_fn(x, y)
|
||||
return o2 * o3 * o3
|
||||
|
||||
graph, tracker = extract_graph_and_tracker(
|
||||
fn, torch.randn(10, 10), torch.randn(10, 10)
|
||||
)
|
||||
tracker.get_identical_regions(graph)
|
||||
self.assertGreater(len(records), 0)
|
||||
|
||||
@skipIfTorchDynamo("too slow")
|
||||
@make_logging_test(**torch._logging.DEFAULT_LOGGING)
|
||||
def test_default_logging(self, records):
|
||||
@ -864,6 +891,7 @@ exclusions = {
|
||||
"cudagraph_static_inputs",
|
||||
"benchmarking",
|
||||
"loop_ordering",
|
||||
"graph_region_expansion",
|
||||
}
|
||||
for name in torch._logging._internal.log_registry.artifact_names:
|
||||
if name not in exclusions:
|
||||
|
||||
@ -382,6 +382,14 @@ enable_cpp_guard_manager = True
|
||||
# Inline inbuilt nn modules
|
||||
inline_inbuilt_nn_modules = not is_fbcode()
|
||||
|
||||
# Whether to automatically find and replace identical graph
|
||||
# regions with a call to invoke_subgraph
|
||||
use_graph_deduplication = False
|
||||
|
||||
# Whether to track nodes for deduplication (testing only)
|
||||
# This flag is ignored if use_graph_deduplication is True
|
||||
track_nodes_for_deduplication = False
|
||||
|
||||
# Issues a warning in Python 3.13.0 for possibly slower guard evaluation and
|
||||
# instructs user to attempt using 3.13.1+, where the CPython bug is fixed.
|
||||
# Should be disabled in dynamo-wrapped tests since some tests check that no warnings are issued.
|
||||
|
||||
202
torch/_dynamo/graph_deduplication.py
Normal file
202
torch/_dynamo/graph_deduplication.py
Normal file
@ -0,0 +1,202 @@
|
||||
import logging
|
||||
import operator
|
||||
from typing import Any, Dict, Iterable, List, Set, Tuple
|
||||
|
||||
import torch.fx
|
||||
from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
||||
from .graph_region_tracker import Node, Region
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def apply_graph_deduplication(output_graph) -> Dict[Node, Node]: # type: ignore[no-untyped-def]
|
||||
"""
|
||||
This is the main entry point for applying the graph deduplication pass. \
|
||||
Deduplication occurs in two phases:
|
||||
1. Subgraph creation:
|
||||
Subgraph creation works by taking one representative region from each region \
|
||||
group and creating a subgraph from it, which will then be used to replace all regions \
|
||||
in the group. This is implemented by first copying all nodes of the region to the new \
|
||||
subgraph and then finding all inputs which are not within the region and creating placeholders \
|
||||
for them. For the outputs, all regions in a region group need to be scanned to ensure the \
|
||||
largest set of outputs is found, and then an output node is created which returns \
|
||||
a tuple of all outputs.
|
||||
|
||||
2. Graph replacement:
|
||||
To replace each region with the extracted subgraph, the node index in the region \
|
||||
and argument index within the node's flattened args and kwargs are recorded once during \
|
||||
subgraph creation. This allows us to determine which (external to the region) nodes and \
|
||||
in which order these nodes are passed as inputs. For the outputs, getitem nodes are created \
|
||||
for each output, and all nodes in the region with external outputs are replaced by the proper \
|
||||
getitem node. Finally, all original nodes are erased (there should be no uses of these \
|
||||
left in the graph).
|
||||
|
||||
The deduplication mutates the output_graph argument in place.
|
||||
|
||||
Returns a mapping of nodes to their subgraph output replacement node to remap outputs
|
||||
when they are created in output_graph.
|
||||
"""
|
||||
duplicated_region_groups = output_graph.region_tracker.get_identical_regions(
|
||||
output_graph.graph
|
||||
)
|
||||
|
||||
# Used to track which nodes were replaced with subgraph outputs
|
||||
# today, we have to register the new subgraph submodules before the
|
||||
# graph outputs have been created, so we pass the replacement mapping
|
||||
# back to output graph to do the replacements at the site of output creation
|
||||
output_replacements: Dict[Node, Node] = {}
|
||||
for region_group in duplicated_region_groups:
|
||||
inds_with_external_users = _get_all_output_indices(region_group)
|
||||
region = region_group[0]
|
||||
(
|
||||
subgraph,
|
||||
node_ind_arg_inds,
|
||||
) = _create_subgraph(region, inds_with_external_users)
|
||||
sub_gm = torch.fx.GraphModule(output_graph.nn_modules, subgraph)
|
||||
subgraph_name = output_graph.install_subgraph("subgraph", sub_gm)
|
||||
with output_graph.graph.inserting_before():
|
||||
get_subgraph_node = output_graph.graph.create_node(
|
||||
"get_attr", subgraph_name, (), {}
|
||||
)
|
||||
for region in region_group:
|
||||
_replace_region_with_subgraph(
|
||||
output_graph.graph,
|
||||
region,
|
||||
get_subgraph_node,
|
||||
node_ind_arg_inds.keys(),
|
||||
inds_with_external_users,
|
||||
sub_gm,
|
||||
subgraph_name,
|
||||
output_replacements,
|
||||
)
|
||||
|
||||
return output_replacements
|
||||
|
||||
|
||||
def _replace_region_with_subgraph(
|
||||
graph: torch.fx.Graph,
|
||||
region: Region,
|
||||
get_subgraph_node: Node,
|
||||
node_ind_arg_ind: Iterable[Tuple[int, int]],
|
||||
inds_with_external_users: List[int],
|
||||
sub_gm: torch.fx.GraphModule,
|
||||
subgraph_name: str,
|
||||
output_replacements: Dict[Node, Node],
|
||||
) -> None:
|
||||
sub_args = []
|
||||
for node_ind, arg_ind in node_ind_arg_ind:
|
||||
node = region[node_ind]
|
||||
flattened_args_kwargs, _ = tree_flatten((node.args, node.kwargs))
|
||||
sub_args.append(flattened_args_kwargs[arg_ind])
|
||||
|
||||
invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args))
|
||||
fake_inputs = [node.meta["example_value"] for node in sub_args]
|
||||
|
||||
if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
|
||||
log.debug(
|
||||
"NYI: Failed to substitute region %s due to input alias or mutation",
|
||||
region,
|
||||
)
|
||||
return
|
||||
|
||||
latest_region_node = region[-1]
|
||||
with graph.inserting_after(latest_region_node):
|
||||
invoke_subgraph_node = graph.create_node(
|
||||
"call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
|
||||
)
|
||||
with graph.inserting_after(invoke_subgraph_node):
|
||||
for ind, external_user_ind in enumerate(inds_with_external_users):
|
||||
node = region[external_user_ind]
|
||||
subgraph_output = graph.create_node(
|
||||
"call_function", operator.getitem, (invoke_subgraph_node, ind), {}
|
||||
)
|
||||
output_replacements[node] = subgraph_output
|
||||
node.replace_all_uses_with(subgraph_output, propagate_meta=True)
|
||||
|
||||
# Erase in reverse topological order
|
||||
for node in reversed(region):
|
||||
graph.erase_node(node)
|
||||
|
||||
|
||||
def _get_external_inputs(
|
||||
region: Region,
|
||||
) -> Dict[Node, Tuple[int, int]]:
|
||||
external_node_to_indices = dict()
|
||||
region_unique = set(region)
|
||||
for node_ind, node in enumerate(region):
|
||||
flattened_args_kwargs, _ = tree_flatten((node.args, node.kwargs))
|
||||
for arg_ind, in_node in enumerate(flattened_args_kwargs):
|
||||
if (
|
||||
in_node not in region_unique
|
||||
and in_node not in external_node_to_indices
|
||||
and isinstance(in_node, Node)
|
||||
):
|
||||
external_node_to_indices[in_node] = (node_ind, arg_ind)
|
||||
|
||||
return external_node_to_indices
|
||||
|
||||
|
||||
def _get_all_output_indices(regions: List[Region]) -> List[int]:
|
||||
# Scan all regions to get the set of all possible output nodes indices in the region
|
||||
# perhaps we can record this information during region creation for more efficiency?
|
||||
inds_with_external_users: Set[int] = set()
|
||||
for region in regions:
|
||||
_get_inds_with_external_users(region, inds_with_external_users)
|
||||
|
||||
return sorted(inds_with_external_users)
|
||||
|
||||
|
||||
def _get_inds_with_external_users(region: Region, inds_unique: Set[int]) -> None:
|
||||
for ind, node in enumerate(region):
|
||||
for user in node.users:
|
||||
if user not in region:
|
||||
if ind not in inds_unique:
|
||||
inds_unique.add(ind)
|
||||
|
||||
|
||||
def _copy_nodes_and_remap_inputs(
|
||||
subgraph: torch.fx.Graph, region: Region
|
||||
) -> Dict[Tuple[int, int], Any]:
|
||||
external_inputs_to_indices = _get_external_inputs(region)
|
||||
indices_to_placeholder_ind: Dict[Tuple[int, int], Any] = {}
|
||||
region_to_subgraph_node = {}
|
||||
for node in external_inputs_to_indices.keys():
|
||||
placeholder = subgraph.placeholder(f"subgraph_input_{node.name}")
|
||||
region_to_subgraph_node[node] = placeholder
|
||||
arg_indices = external_inputs_to_indices[node]
|
||||
# Note: insertion order matches the order in which placeholders were created
|
||||
# for the calling convention of the subgraph
|
||||
indices_to_placeholder_ind[arg_indices] = None
|
||||
|
||||
def map_arg(node: Node) -> Node:
|
||||
if node in region_to_subgraph_node:
|
||||
return region_to_subgraph_node[node]
|
||||
else:
|
||||
return node
|
||||
|
||||
for node in region:
|
||||
subgraph_node = subgraph.node_copy(node, lambda old: map_arg(old))
|
||||
region_to_subgraph_node[node] = subgraph_node
|
||||
|
||||
return indices_to_placeholder_ind
|
||||
|
||||
|
||||
def _create_subgraph_outputs(
|
||||
subgraph: torch.fx.Graph, inds_to_output: List[int]
|
||||
) -> None:
|
||||
node_list = [n for n in subgraph.nodes if n.op not in ("placeholder", "output")]
|
||||
out_tup = tuple(node_list[ind] for ind in inds_to_output)
|
||||
subgraph.output(out_tup)
|
||||
|
||||
|
||||
def _create_subgraph(
|
||||
region: Region,
|
||||
inds_with_external_users: List[int],
|
||||
) -> Tuple[torch.fx.Graph, Dict[Tuple[int, int], Any]]:
|
||||
subgraph: torch.fx.Graph = torch.fx.Graph()
|
||||
node_ind_input_inds = _copy_nodes_and_remap_inputs(subgraph, region)
|
||||
_create_subgraph_outputs(subgraph, inds_with_external_users)
|
||||
return subgraph, node_ind_input_inds
|
||||
353
torch/_dynamo/graph_region_tracker.py
Normal file
353
torch/_dynamo/graph_region_tracker.py
Normal file
@ -0,0 +1,353 @@
|
||||
import copyreg
|
||||
import io
|
||||
import logging
|
||||
import math
|
||||
import pickle
|
||||
from collections import defaultdict, deque
|
||||
from dataclasses import fields
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Deque,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import torch._logging
|
||||
import torch.fx
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .symbolic_convert import InstructionTranslatorBase
|
||||
|
||||
|
||||
Node = torch.fx.Node
|
||||
Region = List[Node]
|
||||
IdenticalNodes = List[Node]
|
||||
GlobalStateKey = Tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
graph_expansion_log = torch._logging.getArtifactLogger(
|
||||
__name__, "graph_region_expansion"
|
||||
)
|
||||
|
||||
|
||||
def debug_log(msg: str, *args) -> None: # type: ignore[no-untyped-def]
|
||||
graph_expansion_log.debug(msg, *args)
|
||||
|
||||
|
||||
def _extract_tensor_metadata_for_node_hash(
|
||||
x: torch.Tensor,
|
||||
) -> Tuple[Callable[[T], T], Tuple[Any, ...]]:
|
||||
from torch._inductor.codecache import _ident, extract_tensor_metadata_for_cache_key
|
||||
|
||||
out = []
|
||||
metadata = extract_tensor_metadata_for_cache_key(x)
|
||||
for field in fields(metadata):
|
||||
out.append(getattr(metadata, field.name))
|
||||
|
||||
return (_ident, tuple(out))
|
||||
|
||||
|
||||
class NodeHashException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InputPickler(pickle.Pickler):
|
||||
def __init__(self) -> None:
|
||||
from torch._inductor.codecache import _ident
|
||||
|
||||
stream = io.BytesIO()
|
||||
self._stream = stream
|
||||
super().__init__(stream)
|
||||
self.dispatch_table = copyreg.dispatch_table.copy()
|
||||
self.dispatch_table.update(
|
||||
{
|
||||
FakeTensor: _extract_tensor_metadata_for_node_hash,
|
||||
torch.SymInt: lambda x: (_ident, (str(x),)),
|
||||
torch.SymBool: lambda x: (_ident, (str(x),)),
|
||||
torch.SymFloat: lambda x: (_ident, (str(x),)),
|
||||
}
|
||||
)
|
||||
self.fast = True
|
||||
|
||||
def dumps(self, obj: Any) -> bytes:
|
||||
"""
|
||||
Pickle an object and return a byte string.
|
||||
"""
|
||||
try:
|
||||
self.dump(obj)
|
||||
return self._stream.getvalue()
|
||||
except (TypeError, AttributeError) as e:
|
||||
raise NodeHashException from e
|
||||
finally:
|
||||
self._stream.seek(0)
|
||||
self._stream.truncate(0)
|
||||
|
||||
|
||||
def _extract_tensor_arg(arg: Any) -> Any:
|
||||
if isinstance(arg, Node):
|
||||
return arg.meta.get("example_value")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_args(
|
||||
node: Node,
|
||||
) -> Tuple[Tuple[str, ...], Tuple[Optional[Any], ...]]:
|
||||
flat_args, _ = tree_flatten(node.args)
|
||||
sorted_kwargs = sorted(node.kwargs.items(), key=lambda x: x[0])
|
||||
sorted_keys = tuple(sorted(node.kwargs.keys()))
|
||||
flat_kwargs, _ = tree_flatten(sorted_kwargs)
|
||||
all_args = flat_args + flat_kwargs
|
||||
return (sorted_keys, tuple(_extract_tensor_arg(arg) for arg in all_args))
|
||||
|
||||
|
||||
def get_global_state_key() -> GlobalStateKey:
|
||||
return (
|
||||
torch.is_grad_enabled(),
|
||||
torch.is_inference_mode_enabled(),
|
||||
torch.get_num_threads(),
|
||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction(),
|
||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction(),
|
||||
torch.get_default_dtype(),
|
||||
torch.are_deterministic_algorithms_enabled(),
|
||||
torch._C._get_cublas_allow_tf32(),
|
||||
torch.is_deterministic_algorithms_warn_only_enabled(),
|
||||
torch._C._autograd._saved_tensors_hooks_is_enabled(), # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
|
||||
# This is typical BFS with the caveat
|
||||
# that a node's children need to be explicitly
|
||||
# added with the add_children() method
|
||||
# The flow is yield a node and check if it's valid for all regions
|
||||
# if not valid, discard and continue onto the next node
|
||||
# Note: this iterates backward through the graph by looking at args/kwargs
|
||||
# of a node
|
||||
class BackwardBfsArgIter:
|
||||
def __init__(self, origin: Node) -> None:
|
||||
self._cur: Optional[Node] = origin
|
||||
self._queue: Deque[Optional[Node]] = deque()
|
||||
|
||||
@staticmethod
|
||||
def create(origin: Node) -> "BackwardBfsArgIter":
|
||||
it = BackwardBfsArgIter(origin)
|
||||
it.add_children(origin)
|
||||
return it
|
||||
|
||||
def next(self) -> Optional[Node]:
|
||||
ret = self._cur
|
||||
if not self._queue:
|
||||
self._cur = None
|
||||
else:
|
||||
self._cur = self._queue.popleft()
|
||||
return ret
|
||||
|
||||
def peek(self) -> Optional[Node]:
|
||||
return self._cur
|
||||
|
||||
def add_children(self, node: Node) -> None:
|
||||
arg: Any
|
||||
flat_args, _ = tree_flatten(node.args)
|
||||
for arg in flat_args:
|
||||
if isinstance(arg, Node):
|
||||
self._append(arg)
|
||||
|
||||
flat_kwargs, _ = tree_flatten(node.kwargs)
|
||||
for kwarg in flat_kwargs:
|
||||
if isinstance(kwarg, Node):
|
||||
self._append(kwarg)
|
||||
|
||||
def _append(self, arg: Node) -> None:
|
||||
if self._cur is None:
|
||||
self._cur = arg
|
||||
else:
|
||||
self._queue.append(arg)
|
||||
|
||||
|
||||
class GraphRegionTracker:
|
||||
"""
|
||||
GraphRegionTracker tracks each node added to the output graph and generates a key based on the source location,
|
||||
instruction pointer, input shapes, and global state at the time the node is inserted into the graph. Nodes with
|
||||
the same key are grouped together in a list of identical nodes (the value of node_to_duplicates).
|
||||
|
||||
hash_to_duplicates: Dict[str, IdenticalNodes] - A dictionary mapping the key to a list of identical nodes
|
||||
node_to_duplicates: Dict[Node, IdenticalNodes] - A dictionary mapping a node to the list of identical nodes it belongs to
|
||||
input_pickler: InputPickler - An instance of InputPickler used to generate a node hash
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.hash_to_duplicates: Dict[str, IdenticalNodes] = defaultdict(list)
|
||||
self.node_to_duplicates: Dict[Node, IdenticalNodes] = {}
|
||||
self.input_pickler = InputPickler()
|
||||
|
||||
def _hash_node(
|
||||
self, filename: str, lineno: int, instruction_pointer: Optional[int], node: Node
|
||||
) -> str:
|
||||
from torch._inductor.codecache import sha256_hash
|
||||
|
||||
key = (
|
||||
get_global_state_key(),
|
||||
filename,
|
||||
lineno,
|
||||
instruction_pointer,
|
||||
_normalize_args(node),
|
||||
)
|
||||
return sha256_hash(self.input_pickler.dumps(key))
|
||||
|
||||
def _is_identical(self, n0: Node, n1: Node) -> bool:
|
||||
return (
|
||||
n0 in self.node_to_duplicates
|
||||
and n1 in self.node_to_duplicates
|
||||
and self.node_to_duplicates[n0] is self.node_to_duplicates[n1]
|
||||
and n0 is not n1
|
||||
)
|
||||
|
||||
def track_node(self, tx: "InstructionTranslatorBase", node: Node) -> None:
|
||||
"""
|
||||
The main entry point for tracking a node. This function will hash the node argument and group
|
||||
nodes with the same hash together. It updates the hash_to_duplicates and node_to_duplicates dictionaries
|
||||
to track the new node.
|
||||
"""
|
||||
try:
|
||||
duplicates = self.hash_to_duplicates[
|
||||
self._hash_node(
|
||||
tx.f_code.co_filename, tx.lineno, tx.instruction_pointer, node
|
||||
)
|
||||
]
|
||||
duplicates.append(node)
|
||||
self.node_to_duplicates[node] = duplicates
|
||||
except NodeHashException as e:
|
||||
log.debug("Unable to hash node %s with exception %s", node, e)
|
||||
|
||||
def get_identical_regions(self, graph: torch.fx.Graph) -> List[List[Region]]:
|
||||
"""
|
||||
This function is responsible for extracting the largest regions of identical nodes from the given graph.
|
||||
**Note**: This function assumes the nodes that have been tracked with track_node are in the provided graph argument.
|
||||
|
||||
The algorithm proceeds as follows:
|
||||
The nodes tracked via track_node above are organized into region groups. The initial region groups look like this:
|
||||
[[IdenticalNode1], [IdenticalNode2], [IdenticalNode3]] and each sublist is called a region. For each region group
|
||||
(starting at the topologically latest region group), the inner regions are gradually expanded one node at time from
|
||||
the flattened args and kwargs of the node in each region provided that for all regions in the group, the nodes being
|
||||
added are also identical (ie have the same key computed by track_node). This is checked by verifying that the two
|
||||
nodes have the same identical node list in node_to_duplicates.
|
||||
"""
|
||||
topological_ranking = {node: i for i, node in enumerate(graph.nodes)}
|
||||
region_groups_with_rank = []
|
||||
|
||||
# Create region groups; a region group is a group
|
||||
# of regions that are all identical. In this initial state
|
||||
# each region in the group is a single node, and we discard
|
||||
# groups that are only a single region.
|
||||
# We track the topological ranking to start with groups later in the graph
|
||||
# the reason for this is that we will necessarily create the largest groups first.
|
||||
for group in self.hash_to_duplicates.values():
|
||||
if len(group) > 1:
|
||||
region_group = []
|
||||
min_rank = math.inf
|
||||
for node in group:
|
||||
min_rank = min(min_rank, topological_ranking[node])
|
||||
region_group.append([node])
|
||||
|
||||
region_groups_with_rank.append((region_group, min_rank))
|
||||
|
||||
region_groups_with_rank.sort(key=lambda rg: -rg[1])
|
||||
region_groups = [rg for rg, _ in region_groups_with_rank]
|
||||
|
||||
# We start from regions later in the graph and expand them earlier
|
||||
# as a result, we will create the largest regions first and they won't
|
||||
# overlap.
|
||||
seen_nodes: Set[Node] = set()
|
||||
for region_group in region_groups:
|
||||
fully_expand_region_group(region_group, seen_nodes, self._is_identical)
|
||||
|
||||
return [
|
||||
region_group for region_group in region_groups if len(region_group[0]) > 1
|
||||
]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"GraphRegionTracker(hash_to_duplicates={self.hash_to_duplicates}, node_to_duplicates={self.node_to_duplicates})"
|
||||
|
||||
|
||||
def fully_expand_region_group(
    regions: List[Region],
    seen_nodes: Set[Node],
    is_identical_fn: Callable[[Node, Node], bool],
) -> None:
    debug_log("--------------------------------------------------")
    debug_log("expanding new region group: %s", regions)

    # All regions should start with 1 node
    assert all(len(region) == 1 for region in regions)
    region_iters = []
    for region in regions:
        (origin,) = region  # Only works for 1 element sets
        region_iters.append(BackwardBfsArgIter.create(origin))

    nodes_to_add: List[Node] = []

    # We already have the origin node in each region
    for region_it in region_iters:
        node = region_it.next()
        assert node
        region_it.add_children(node)

    current_node = region_iters[0].next()
    assert current_node is not None
    # Loop incrementally adding new nodes to each region;
    # regions are only expanded if the node to add is valid
    # for ALL regions.
    while current_node:
        add_node = True
        nodes_to_add.clear()
        nodes_to_add.append(current_node)
        nodes_to_add_set = set(nodes_to_add)
        for region_it in region_iters[1:]:
            node = region_it.next()

            debug_log("--------------------")
            debug_log("considering adding: %s, cur_node: %s", node, current_node)
            debug_log("previously claimed nodes: %s", node in seen_nodes)
            debug_log("%s", seen_nodes)
            if node:
                debug_log("is_identical: %s", is_identical_fn(node, current_node))
                add_node &= (
                    node not in seen_nodes
                    and node not in nodes_to_add_set
                    and is_identical_fn(node, current_node)
                )
                nodes_to_add.append(node)
                nodes_to_add_set.add(node)
            else:
                add_node = False

            debug_log("--------------------")

        if add_node:
            for region, region_it, node in zip(regions, region_iters, nodes_to_add):
                region.append(node)
                debug_log("adding %s's children", node)
                debug_log("%s %s", node.args, list(node.kwargs.items()))
                region_it.add_children(node)
                seen_nodes.add(node)

        current_node = region_iters[0].next()

    # Ensure regions are sorted in topological order
    for region in regions:
        region.reverse()

    debug_log("end expand new region group: %s", regions)
    debug_log("--------------------------------------------------")

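For illustration, here is a minimal self-contained sketch of the lockstep rule above, using plain strings in place of fx nodes (the names and the identity test are made up): a candidate is appended to every region only when it is unclaimed and identical across all regions; otherwise no region grows in that step.

def expand_in_lockstep(regions, candidates_per_region, seen, is_identical):
    # Walk the candidate streams of all regions in parallel; stop at the first
    # step where any region's candidate is already claimed or not identical.
    for step in zip(*candidates_per_region):
        first = step[0]
        if not all(c not in seen and is_identical(c, first) for c in step):
            break
        for region, candidate in zip(regions, step):
            region.append(candidate)
            seen.add(candidate)
    return regions


def same_kind(u, v):
    return u[1:] == v[1:]  # toy notion of "identical": same suffix


print(expand_in_lockstep([["a0"], ["b0"]], [["a1", "a2"], ["b1", "x9"]], set(), same_kind))
# [['a0', 'a1'], ['b0', 'b1']]  -- expansion stops once 'a2'/'x9' disagree
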
@@ -72,6 +72,8 @@ from .exc import (
    unimplemented,
    unimplemented_with_warning,
)
from .graph_deduplication import apply_graph_deduplication
from .graph_region_tracker import GraphRegionTracker
from .guards import GuardBuilder, install_guard
from .mutation_guard import is_dynamic_nn_module
from .side_effects import AttributeMutationExisting, SideEffects

@@ -297,6 +299,8 @@ class OutputGraph:
            "co_firstlineno": f_code.co_firstlineno,
        }

        self.region_tracker = GraphRegionTracker()

        # tracked_fakes says where any tensor that was wrapped to fake came
        # from. It is similar to GraphArg, in that all GraphArgs will get
        # will get added to TrackedFakes, but TrackedFakes also contains

@@ -1015,6 +1019,8 @@ class OutputGraph:
        for value in stack_values:
            value.realize()

        output_replacements = self.dedup_pass()

        # Use nn.Module "proxies" in the constructed GraphModule so that
        # the resulting GM does not hold additional strong references to the original modules.
        # This prevents a strong ref cycle where Dynamo created code holds on to references

@@ -1098,7 +1104,9 @@ class OutputGraph:
            append_prefix_insts()
            # optimization to generate better code in a common case
            self.add_output_instructions(
                self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
                self.compile_and_call_fx_graph(
                    tx, list(reversed(stack_values)), root, output_replacements
                )
                + [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))]
            )
            # restore all the live local vars

@@ -1131,7 +1139,9 @@ class OutputGraph:
            output = []
            if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
                output.extend(
                    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
                    self.compile_and_call_fx_graph(
                        tx, pass2.graph_output_vars(), root, output_replacements
                    )
                )

                if len(pass2.graph_outputs) != 0:

@@ -1292,7 +1302,7 @@ class OutputGraph:
        tx.speculation_log.clear()
        raise exc.CompileCollectiveRestartAnalysis

    def compile_and_call_fx_graph(self, tx, rv, root):
    def compile_and_call_fx_graph(self, tx, rv, root, replaced_outputs):
        """
        Generate code from self.graph and return the Instruction()s to
        call that generated code.

@@ -1308,12 +1318,17 @@

        assert isinstance(rv, list)
        assert isinstance(root, FakeRootModule)

        output_node = self.create_node(
            "output",
            "output",
            (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
            {},
        )

        for old_node, new_node in replaced_outputs.items():
            old_node.replace_all_uses_with(new_node)

        tx.output.current_tracer._maybe_preserve_original_meta(tx, output_node)
        if not config.do_not_emit_runtime_asserts:
            insert_deferred_runtime_asserts(

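The replaced_outputs loop above uses the standard torch.fx node API. As a generic illustration only (a toy traced function, not the Dynamo pass itself), rerouting every consumer of one node to another looks like this:

import operator

import torch
import torch.fx


def f(x):
    a = x + 1
    b = x + 1  # structurally identical to `a`
    return a * b


gm = torch.fx.symbolic_trace(f)
adds = [n for n in gm.graph.nodes if n.op == "call_function" and n.target is operator.add]
keep, drop = adds
drop.replace_all_uses_with(keep)  # every consumer of `drop` now reads `keep`
gm.graph.erase_node(drop)
gm.recompile()
print(gm.code)
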
@@ -1490,6 +1505,29 @@ class OutputGraph:

        return compiled_fn

    def dedup_pass(self):
        if torch._dynamo.config.use_graph_deduplication:
            return apply_graph_deduplication(self)
        else:
            return dict()

    def install_subgraph(self, name, sub_gm):
        next_name = None
        i = 0
        while not next_name:
            candidate = f"{name}_{i}"
            if candidate in self.nn_modules:
                i += 1
            else:
                next_name = candidate

        sub_gm.__name__ = next_name
        sub_gm.torchdynamo_force_dynamic = False
        # This graph module is not present in the user space, so it can't be
        # accessed by a source. Set source=None.
        self.register_attr_or_module(sub_gm, next_name, source=None)
        return next_name

    def example_inputs(self) -> List[torch.Tensor]:
        result = [arg.example for arg in self.graphargs]
        return result

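A standalone sketch of the naming scheme install_subgraph uses above, with a plain dict standing in for self.nn_modules (the names here are made up): the base name is suffixed with _0, _1, ... until an unused key is found.

def next_unique_name(registry, base):
    i = 0
    while True:
        candidate = f"{base}_{i}"
        if candidate not in registry:
            return candidate
        i += 1


registry = {"subgraph_0": object(), "subgraph_1": object()}
print(next_unique_name(registry, "subgraph"))   # subgraph_2
print(next_unique_name(registry, "cond_true"))  # cond_true_0
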
@@ -2104,6 +2142,13 @@ class SubgraphTracer(fx.Tracer):
        msgs = traceback.StackSummary.from_list(frame_summaries).format()
        rv.node.stack_trace = "".join(msgs)

        if (
            torch._dynamo.config.use_graph_deduplication
            or torch._dynamo.config.track_nodes_for_deduplication
        ):
            self.output_graph.region_tracker.track_node(
                self.output_graph.current_tx, rv.node
            )
        return rv

    def create_node(

@@ -62,6 +62,23 @@ def remove_optimized_module_prefix(name: str) -> str:
    return re.sub(r"^_orig_mod[.]", "", name)


def extract_graph_and_tracker(fn, *args, **kwargs):  # type: ignore[no-untyped-def]
    from torch._dynamo.symbolic_convert import InstructionTranslator

    gm = None
    region_tracker = None

    def extract_graph_backend(_gm, *args, **kwargs):  # type: ignore[no-untyped-def]
        nonlocal gm
        nonlocal region_tracker
        gm = _gm
        region_tracker = InstructionTranslator.current_tx().output.region_tracker
        return _gm

    torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs)
    return gm.graph, region_tracker  # type: ignore[union-attr]


def collect_results(
    model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> List[Any]:

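A hypothetical usage sketch of the helper above; the import path and the config toggle are assumptions (node tracking is gated by the flags shown in the SubgraphTracer change earlier), so adjust them to where the helper actually lives.

import torch
from torch._dynamo.testing import extract_graph_and_tracker  # assumed import path

torch._dynamo.config.track_nodes_for_deduplication = True  # track duplicates without rewriting the graph


def repeated(x):
    a = torch.sin(x) + 1
    b = torch.sin(x) + 1
    return a * b


graph, tracker = extract_graph_and_tracker(repeated, torch.randn(8))
print(graph)    # the captured fx.Graph
print(tracker)  # GraphRegionTracker(hash_to_duplicates=..., node_to_duplicates=...)
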
@@ -673,24 +673,6 @@ def make_attr(tx: "InstructionTranslator", name):
    return node


def add_subgraph(tx: "InstructionTranslator", name, gm):
    next_name = None
    i = 0
    while not next_name:
        candidate = f"{name}_{i}"
        if candidate in tx.output.nn_modules:
            i += 1
        else:
            next_name = candidate

    gm.__name__ = next_name
    gm.torchdynamo_force_dynamic = False
    # This graph module is not present in the user space, so it can't be
    # accessed by a source. Set source=None.
    tx.output.register_attr_or_module(gm, next_name, source=None)
    return next_name


class TorchHigherOrderOperatorVariable(VariableTracker):
    def __init__(
        self, value: HigherOrderOperator, source: Optional[Source] = None, **kwargs

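The hunks that follow are a mechanical migration: with the free function add_subgraph removed above, every call site switches to the OutputGraph method introduced earlier, roughly as follows.

# before
body_name = add_subgraph(tx, "wrap_body", body_gmod)
# after
body_name = tx.output.install_subgraph("wrap_body", body_gmod)
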
@@ -928,13 +910,11 @@ class CondHigherOrderVariable(TorchHigherOrderOperatorVariable):
            "false_branch",
        )

        true_name = add_subgraph(
            tx,
        true_name = tx.output.install_subgraph(
            "cond_true",
            torch.fx.GraphModule(true_nn_modules, true_graph),
        )
        false_name = add_subgraph(
            tx,
        false_name = tx.output.install_subgraph(
            "cond_false",
            torch.fx.GraphModule(false_nn_modules, false_graph),
        )

@@ -1141,13 +1121,11 @@ class WhileLoopHigherOrderVariable(TorchHigherOrderOperatorVariable):

        body_nn_modules = dict(tx.output.nn_modules)

        cond_name = add_subgraph(
            tx,
        cond_name = tx.output.install_subgraph(
            "cond_fn",
            torch.fx.GraphModule(cond_nn_modules, cond_graph),
        )
        body_name = add_subgraph(
            tx,
        body_name = tx.output.install_subgraph(
            "body_fn",
            torch.fx.GraphModule(body_nn_modules, body_graph),
        )

@@ -1257,7 +1235,9 @@ class AssociativeScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )

        combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
        combine_fn_name = add_subgraph(tx, "associative_scan_combine_fn", combine_gm)
        combine_fn_name = tx.output.install_subgraph(
            "associative_scan_combine_fn", combine_gm
        )

        p_args = (
            make_attr(tx, combine_fn_name),

@@ -1410,7 +1390,7 @@ class ScanHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )

        combine_gm = torch.fx.GraphModule(dict(tx.output.nn_modules), combine_graph)
        combine_fn_name = add_subgraph(tx, "scan_combine_fn", combine_gm)
        combine_fn_name = tx.output.install_subgraph("scan_combine_fn", combine_gm)

        p_args = (
            make_attr(tx, combine_fn_name),

@@ -1522,8 +1502,7 @@ class MapHigherOrderVariable(TorchHigherOrderOperatorVariable):

        body_nn_modules = dict(tx.output.nn_modules)

        body_name = add_subgraph(
            tx,
        body_name = tx.output.install_subgraph(
            "map_body",
            torch.fx.GraphModule(body_nn_modules, body_graph),
        )

@@ -1619,8 +1598,7 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def install_subgraph_in_output_graph(
        self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body"
    ):
        return add_subgraph(
            tx,
        return tx.output.install_subgraph(
            f"{attr_name}",
            body_gmod,
        )

@@ -1756,8 +1734,7 @@ class WrapWithSetGradEnabledHigherOrderVariable(TorchHigherOrderOperatorVariable
        )

        body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
        body_name = add_subgraph(
            tx,
        body_name = tx.output.install_subgraph(
            "wrap_body",
            body_gmod,
        )

@@ -1837,8 +1814,7 @@ class WrapWithAutocastHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )

        body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
        body_name = add_subgraph(
            tx,
        body_name = tx.output.install_subgraph(
            "wrap_body",
            body_gmod,
        )

@@ -1909,8 +1885,7 @@ class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )

        body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph)
        body_name = add_subgraph(
            tx,
        body_name = tx.output.install_subgraph(
            "hints_wrapper_body",
            body_gmod,
        )

@@ -2011,8 +1986,7 @@ class StrictModeHigherOrderVariable(TorchHigherOrderOperatorVariable):

        strict_mode_nn_modules = dict(tx.output.nn_modules)

        strict_mode_name = add_subgraph(
            tx,
        strict_mode_name = tx.output.install_subgraph(
            "strict_mode_body",
            torch.fx.GraphModule(strict_mode_nn_modules, ret_graph),
        )

@@ -2260,8 +2234,7 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
            set_subgraph_inputs="flatten_manual",
        )

        body_name = add_subgraph(
            tx,
        body_name = tx.output.install_subgraph(
            fn_name,
            torch.fx.GraphModule(tx.output.nn_modules, body_graph),
        )

@@ -2544,8 +2517,7 @@ class AutogradFunctionApplyVariable(VariableTracker):

        # Store fwd_body
        fwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
        fwd_name = add_subgraph(
            tx,
        fwd_name = tx.output.install_subgraph(
            "fwd_body",
            torch.fx.GraphModule(fwd_nn_modules.nn_modules, fwd_graph),
        )

@@ -2610,8 +2582,7 @@ class AutogradFunctionApplyVariable(VariableTracker):

        # Store bwd_body
        bwd_nn_modules = tx.output.tracing_context.module_context.copy_graphstate()
        bwd_name = add_subgraph(
            tx,
        bwd_name = tx.output.install_subgraph(
            "bwd_body",
            torch.fx.GraphModule(bwd_nn_modules.nn_modules, bwd_graph),
        )

@@ -240,6 +240,7 @@ def set_logs(
    compiled_autograd_verbose: bool = False,
    cudagraph_static_inputs: bool = False,
    benchmarking: bool = False,
    graph_region_expansion: bool = False,
):
    """
    Sets the log level for individual components and toggles individual log

@@ -416,6 +417,9 @@ def set_logs(
        cudagraph_static_inputs (:class:`bool`):
            Whether to emit debug info for cudagraph static input detection. Default: ``False``

        graph_region_expansion (:class:`bool`):
            Whether to emit the detailed steps of the duplicate graph region tracker expansion algorithm. Default: ``False``


    Example::

@@ -514,6 +518,7 @@ def set_logs(
        compiled_autograd_verbose=compiled_autograd_verbose,
        cudagraph_static_inputs=cudagraph_static_inputs,
        benchmarking=benchmarking,
        graph_region_expansion=graph_region_expansion,
    )

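With the new keyword in place, the expansion traces can be enabled like any other logging artifact; a minimal sketch:

import torch._logging

# Emit the duplicate-region expansion steps (off by default).
torch._logging.set_logs(graph_region_expansion=True)
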
@@ -191,5 +191,10 @@ register_artifact(
    "Detailed Inductor benchmarking information.",
    off_by_default=True,
)
register_artifact(
    "graph_region_expansion",
    "Logs detailed steps of the duplicate graph region tracker expansion algorithm",
    off_by_default=True,
)

register_artifact("custom_format_test_artifact", "Testing only", log_format="")