Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00

[invoke_subgraph] Unpacked operands (#152547)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152547
Approved by: https://github.com/ydwu4, https://github.com/zou3519

Committed by: PyTorch MergeBot
Parent: e6989ceea9
Commit: 4649fd17b0
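The single change propagated through every hunk below is the calling convention of the invoke_subgraph higher-order op: operands are now passed unpacked as trailing positional arguments instead of being bundled into one tuple/list. A minimal before/after sketch (subgraph_mod, x and y are illustrative names, not taken from the diff):

# Before this PR: operands packed into a single tuple argument
#     out = torch.ops.higher_order.invoke_subgraph(subgraph_mod, 'subgraph_0', (x, y))
# After this PR: operands splatted as positional varargs
#     out = torch.ops.higher_order.invoke_subgraph(subgraph_mod, 'subgraph_0', x, y)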
@@ -63,12 +63,12 @@ class GraphModule(torch.nn.Module):
o1: "f32[10, 20]" = torch.sin(l_y_)
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, o1)); o1 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); invoke_subgraph = None
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, o1); o1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
@@ -98,10 +98,10 @@ class GraphModule(torch.nn.Module):
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
___forward_subgraph_0_0_post_graph = self.___forward_subgraph_0_0_post_graph
- invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', (primals_1, sin)); ___forward_subgraph_0_0_post_graph = sin = None
+ invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', primals_1, sin); ___forward_subgraph_0_0_post_graph = sin = None
getitem_1: "f32[]" = invoke_subgraph_5[0]; invoke_subgraph_5 = None
___forward_subgraph_0_0_post_graph_1 = self.___forward_subgraph_0_0_post_graph
- invoke_subgraph_7 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_0_post_graph_1 = primals_1 = None
+ invoke_subgraph_7 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_subgraph_0_0_post_graph_1 = primals_1 = None
getitem_2: "f32[]" = invoke_subgraph_7[0]; invoke_subgraph_7 = None
mul: "f32[]" = torch.ops.aten.mul.Tensor(getitem_2, getitem_2)
@@ -157,13 +157,13 @@ class GraphModule(torch.nn.Module):
x0: "f32[10, 10]" = l_x_ + 2; l_x_ = None
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (x0,)); x0 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', x0); x0 = None
getitem: "f32[10, 10]" = invoke_subgraph[0]; invoke_subgraph = None
o_3: "f32[10, 10]" = torch.cos(getitem); getitem = None
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (o_3,)); subgraph_0 = o_3 = None
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', o_3); subgraph_0 = o_3 = None
getitem_1: "f32[10, 10]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
@@ -188,13 +188,13 @@ class GraphModule(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
___forward_subgraph_0_0_post_graph = self.___forward_subgraph_0_0_post_graph
- invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', (add,)); ___forward_subgraph_0_0_post_graph = add = None
+ invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', add); ___forward_subgraph_0_0_post_graph = add = None
getitem: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
cos: "f32[10, 10]" = torch.ops.aten.cos.default(getitem)
___forward_subgraph_0_0_post_graph_1 = self.___forward_subgraph_0_0_post_graph
- invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', (cos,)); ___forward_subgraph_0_0_post_graph_1 = cos = None
+ invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', cos); ___forward_subgraph_0_0_post_graph_1 = cos = None
getitem_1: "f32[10, 10]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
sin: "f32[10, 10]" = torch.ops.aten.sin.default(getitem_1)
@@ -267,24 +267,24 @@ class GraphModule(torch.nn.Module):
y0: "f32[10, 20]" = torch.sin(l_y_)
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
o1: "f32[]" = torch.sin(getitem); getitem = None
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, y0))
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, y0)
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
- invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); invoke_subgraph_3 = None
- invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); subgraph_1 = x0 = y0 = None
+ invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', x0, y0); invoke_subgraph_3 = None
+ invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', x0, y0); subgraph_1 = x0 = y0 = None
getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
@@ -325,22 +325,22 @@ class GraphModule(torch.nn.Module):
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
___forward_subgraph_0_0_post_graph = self.___forward_subgraph_0_0_post_graph
- invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_0_post_graph = None
+ invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_subgraph_0_0_post_graph = None
getitem: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None
sin_1: "f32[]" = torch.ops.aten.sin.default(getitem)
___forward_subgraph_0_0_post_graph_1 = self.___forward_subgraph_0_0_post_graph
- invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', (primals_1, sin)); ___forward_subgraph_0_0_post_graph_1 = None
+ invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', primals_1, sin); ___forward_subgraph_0_0_post_graph_1 = None
getitem_1: "f32[]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None
mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_1); sin_1 = None
___forward_subgraph_0_0_post_graph_2 = self.___forward_subgraph_0_0_post_graph
- invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_2, '___forward_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_0_post_graph_2 = None
+ invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_2, '___forward_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_subgraph_0_0_post_graph_2 = None
getitem_2: "f32[]" = invoke_subgraph_13[0]; invoke_subgraph_13 = None
___forward_subgraph_1_0_post_graph = self.___forward_subgraph_1_0_post_graph
- invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_0_post_graph, '___forward_subgraph_1_0_post_graph', (cos, sin)); ___forward_subgraph_1_0_post_graph = cos = sin = None
+ invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_0_post_graph, '___forward_subgraph_1_0_post_graph', cos, sin); ___forward_subgraph_1_0_post_graph = cos = sin = None
getitem_19: "f32[]" = invoke_subgraph_15[3]
getitem_18: "f32[10, 20]" = invoke_subgraph_15[2]
getitem_17: "f32[10, 10]" = invoke_subgraph_15[1]
@@ -409,7 +409,7 @@ class GraphModule(torch.nn.Module):
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
___forward_subgraph_0_0_post_graph = self.___forward_subgraph_0_0_post_graph
- invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', (primals_1, sum_1)); ___forward_subgraph_0_0_post_graph = sum_1 = None
+ invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', primals_1, sum_1); ___forward_subgraph_0_0_post_graph = sum_1 = None
getitem: "f32[]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
add_1: "f32[]" = torch.ops.aten.add.Tensor(getitem, 2); getitem = None
@@ -417,7 +417,7 @@ class GraphModule(torch.nn.Module):
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
___forward_subgraph_0_0_post_graph_1 = self.___forward_subgraph_0_0_post_graph
- invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', (primals_1, sum_2)); ___forward_subgraph_0_0_post_graph_1 = primals_1 = sum_2 = None
+ invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', primals_1, sum_2); ___forward_subgraph_0_0_post_graph_1 = primals_1 = sum_2 = None
getitem_1: "f32[]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
return (getitem_1,)
@@ -492,10 +492,10 @@ class <lambda>(torch.nn.Module):
add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None
repeated_subgraph0 = self.repeated_subgraph0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (add_2, add_3)); repeated_subgraph0 = add_2 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', add_2, add_3); repeated_subgraph0 = add_2 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', add_6, add_7); repeated_subgraph0_1 = add_6 = add_7 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
add_8: "f32[]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
@@ -556,14 +556,14 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
@@ -720,7 +720,7 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
_enter_autocast = torch.amp.autocast_mode._enter_autocast(); _enter_autocast = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
@@ -728,7 +728,7 @@ class <lambda>(torch.nn.Module):
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
_exit_autocast = torch.amp.autocast_mode._exit_autocast(); _exit_autocast = None
@@ -765,7 +765,7 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
@@ -775,7 +775,7 @@ class <lambda>(torch.nn.Module):
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
@@ -844,14 +844,14 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
@@ -887,14 +887,14 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
@@ -381,11 +381,11 @@ class GraphModule(torch.nn.Module):
l_y_ = L_y_
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_, l_y_); invoke_subgraph_0 = l_x_ = None
a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', a, l_y_); invoke_subgraph_1 = a = l_y_ = None
getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem_1,)
@@ -403,14 +403,14 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8]", primals_2: "f32[8]"):
___forward_invoke_subgraph_0_0_post_graph = self.___forward_invoke_subgraph_0_0_post_graph
- invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = None
+ invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = None
getitem_9: "f32[8]" = invoke_subgraph_4[2]
getitem_8: "f32[8]" = invoke_subgraph_4[1]
getitem: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
___forward_invoke_subgraph_0_0_post_graph_1 = self.___forward_invoke_subgraph_0_0_post_graph
- invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph_1, '___forward_invoke_subgraph_0_0_post_graph', (getitem, primals_2)); ___forward_invoke_subgraph_0_0_post_graph_1 = getitem = primals_2 = None
+ invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph_1, '___forward_invoke_subgraph_0_0_post_graph', getitem, primals_2); ___forward_invoke_subgraph_0_0_post_graph_1 = getitem = primals_2 = None
getitem_11: "f32[8]" = invoke_subgraph_6[2]
getitem_10: "f32[8]" = invoke_subgraph_6[1]
getitem_1: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
@@ -468,11 +468,11 @@ class GraphModule(torch.nn.Module):
l_y_ = L_y_
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_, l_y_); invoke_subgraph_0 = l_x_ = None
a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_1
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', a, l_y_); invoke_subgraph_1 = a = l_y_ = None
getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem_1,)
@@ -521,19 +521,19 @@ class GraphModule(torch.nn.Module):
l_y_ = L_y_
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_, l_y_); invoke_subgraph_0 = l_x_ = None
x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (x, l_y_)); invoke_subgraph_1 = x = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', x, l_y_); invoke_subgraph_1 = x = None
x_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
invoke_subgraph_3 = self.invoke_subgraph_0
- invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_3, 'invoke_subgraph_0', (x_1, l_y_)); invoke_subgraph_3 = x_1 = None
+ invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_3, 'invoke_subgraph_0', x_1, l_y_); invoke_subgraph_3 = x_1 = None
x_2: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
invoke_subgraph_5 = self.invoke_subgraph_0
- invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_5, 'invoke_subgraph_0', (x_2, l_y_)); invoke_subgraph_5 = x_2 = None
+ invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_5, 'invoke_subgraph_0', x_2, l_y_); invoke_subgraph_5 = x_2 = None
x_3: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
invoke_subgraph_7 = self.invoke_subgraph_0
- invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_7, 'invoke_subgraph_0', (x_3, l_y_)); invoke_subgraph_7 = x_3 = l_y_ = None
+ invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_7, 'invoke_subgraph_0', x_3, l_y_); invoke_subgraph_7 = x_3 = l_y_ = None
x_4: "f32[8]" = invoke_subgraph_8[0]; invoke_subgraph_8 = None
return (x_4,)
@@ -765,10 +765,10 @@ class GraphModule(torch.nn.Module):
l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_)); invoke_subgraph_0 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_); invoke_subgraph_0 = None
getitem: "f32[8, 8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_)); invoke_subgraph_1 = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_); invoke_subgraph_1 = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = None
getitem_1: "f32[8, 8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
mul: "f32[8, 8]" = getitem * getitem_1; getitem = getitem_1 = None
@@ -818,10 +818,10 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_); invoke_subgraph_0 = None
getitem: "f32[8, 8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_1 = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', l_x_); invoke_subgraph_1 = None
getitem_1: "f32[8, 8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
mul: "f32[8, 8]" = getitem * getitem_1; getitem = getitem_1 = None
@@ -903,7 +903,7 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = l_x_ = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_); invoke_subgraph_0 = l_x_ = None
getitem: "f32[8, 8]" = invoke_subgraph[0]
getitem_1: "f32[8, 8]" = invoke_subgraph[2]; invoke_subgraph = None
@@ -925,7 +925,7 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8, 8]"):
___forward_invoke_subgraph_0_0_post_graph = self.___forward_invoke_subgraph_0_0_post_graph
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', (primals_1,)); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', primals_1); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = None
getitem: "f32[8, 8]" = invoke_subgraph_2[0]
getitem_2: "f32[8, 8]" = invoke_subgraph_2[2]; invoke_subgraph_2 = None
@@ -947,7 +947,7 @@ class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[8, 8]"):
___backward_invoke_subgraph_0_0_post_graph = self.___backward_invoke_subgraph_0_0_post_graph
- invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', (tangents_1, tangents_1)); ___backward_invoke_subgraph_0_0_post_graph = tangents_1 = None
+ invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', tangents_1, tangents_1); ___backward_invoke_subgraph_0_0_post_graph = tangents_1 = None
getitem_3: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
return (getitem_3,)
@@ -1052,7 +1052,7 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8, 8]", primals_2: "f32[8, 8]"):
___forward_invoke_subgraph_0_0_post_graph = self.___forward_invoke_subgraph_0_0_post_graph
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = primals_2 = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = primals_2 = None
getitem_6: "f32[8, 8]" = invoke_subgraph_2[3]
getitem_5: "f32[8, 8]" = invoke_subgraph_2[2]
getitem_4: "f32[8, 8]" = invoke_subgraph_2[1]
@@ -1083,7 +1083,7 @@ class GraphModule(torch.nn.Module):
___backward_invoke_subgraph_0_0_post_graph = self.___backward_invoke_subgraph_0_0_post_graph
- invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', (getitem_4, getitem_5, getitem_6, mul)); ___backward_invoke_subgraph_0_0_post_graph = getitem_4 = getitem_5 = getitem_6 = mul = None
+ invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', getitem_4, getitem_5, getitem_6, mul); ___backward_invoke_subgraph_0_0_post_graph = getitem_4 = getitem_5 = getitem_6 = mul = None
getitem_1: "f32[8, 8]" = invoke_subgraph_3[0]
getitem_2: "f32[8, 8]" = invoke_subgraph_3[1]; invoke_subgraph_3 = None
return (getitem_1, getitem_2)
@@ -1195,10 +1195,10 @@ class GraphModule(torch.nn.Module):
l_y_ = L_y_
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = l_x_ = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_); invoke_subgraph_0 = l_x_ = None
getitem: "f32[8, 8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_1
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (l_y_,)); invoke_subgraph_1 = l_y_ = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', l_y_); invoke_subgraph_1 = l_y_ = None
getitem_1: "f32[16, 16]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem, getitem_1)
@@ -1259,14 +1259,14 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (s77, l_x_)); invoke_subgraph_0 = l_x_ = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', s77, l_x_); invoke_subgraph_0 = l_x_ = None
a: "f32[s77, 8]" = invoke_subgraph[0]; invoke_subgraph = None
floordiv: "Sym((s77//2))" = s77 // 2
b: "f32[(s77//2), 8]" = torch.narrow(a, 0, 0, floordiv); a = floordiv = None
invoke_subgraph_1 = self.invoke_subgraph_1
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (s77, b)); invoke_subgraph_1 = s77 = b = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', s77, b); invoke_subgraph_1 = s77 = b = None
getitem_3: "f32[(s77//2), 8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem_3,)
@@ -1326,10 +1326,10 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_); invoke_subgraph_0 = None
getitem: "f32[8, 8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_1 = l_x_ = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', l_x_); invoke_subgraph_1 = l_x_ = None
getitem_1: "f32[8, 8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None
@@ -1435,11 +1435,11 @@ class GraphModule(torch.nn.Module):
y: "f32[5]" = l_y_.sin(); l_y_ = None
invoke_subgraph_0 = self.invoke_subgraph_0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (x, y)); invoke_subgraph_0 = x = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', x, y); invoke_subgraph_0 = x = None
z: "f32[5]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_1
- invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (z, y)); invoke_subgraph_1 = z = y = None
+ invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', z, y); invoke_subgraph_1 = z = y = None
getitem_1: "f32[5]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem_1,)
@@ -1534,14 +1534,14 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s77)", primals_2: "f32[s77, 16]"):
___forward_invoke_subgraph_0_1_post_graph = self.___forward_invoke_subgraph_0_1_post_graph
- invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph, '___forward_invoke_subgraph_0_1_post_graph', (primals_1, primals_2)); ___forward_invoke_subgraph_0_1_post_graph = primals_2 = None
+ invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph, '___forward_invoke_subgraph_0_1_post_graph', primals_1, primals_2); ___forward_invoke_subgraph_0_1_post_graph = primals_2 = None
getitem_17: "Sym(s77)" = invoke_subgraph_8[2]
getitem_16: "f32[s77, 16]" = invoke_subgraph_8[1]
getitem: "f32[s77, 16]" = invoke_subgraph_8[0]; invoke_subgraph_8 = None
___forward_invoke_subgraph_0_1_post_graph_1 = self.___forward_invoke_subgraph_0_1_post_graph
- invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph_1, '___forward_invoke_subgraph_0_1_post_graph', (primals_1, getitem)); ___forward_invoke_subgraph_0_1_post_graph_1 = getitem = None
+ invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph_1, '___forward_invoke_subgraph_0_1_post_graph', primals_1, getitem); ___forward_invoke_subgraph_0_1_post_graph_1 = getitem = None
getitem_19: "Sym(s77)" = invoke_subgraph_10[2]
getitem_18: "f32[s77, 16]" = invoke_subgraph_10[1]
getitem_1: "f32[s77, 16]" = invoke_subgraph_10[0]; invoke_subgraph_10 = None
@@ -1550,14 +1550,14 @@ class GraphModule(torch.nn.Module):
___forward_invoke_subgraph_0_1_post_graph_2 = self.___forward_invoke_subgraph_0_1_post_graph
- invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph_2, '___forward_invoke_subgraph_0_1_post_graph', (primals_1, sin)); ___forward_invoke_subgraph_0_1_post_graph_2 = sin = None
+ invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph_2, '___forward_invoke_subgraph_0_1_post_graph', primals_1, sin); ___forward_invoke_subgraph_0_1_post_graph_2 = sin = None
getitem_21: "Sym(s77)" = invoke_subgraph_12[2]
getitem_20: "f32[s77, 16]" = invoke_subgraph_12[1]
getitem_2: "f32[s77, 16]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None
___forward_invoke_subgraph_0_0_post_graph = self.___forward_invoke_subgraph_0_0_post_graph
- invoke_subgraph_14 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', (primals_1, getitem_2)); ___forward_invoke_subgraph_0_0_post_graph = None
+ invoke_subgraph_14 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', primals_1, getitem_2); ___forward_invoke_subgraph_0_0_post_graph = None
getitem_23: "Sym(s77)" = invoke_subgraph_14[2]
getitem_22: "f32[s77, 16]" = invoke_subgraph_14[1]
getitem_3: "f32[s77, 16]" = invoke_subgraph_14[0]; invoke_subgraph_14 = None
@@ -1589,26 +1589,26 @@ class GraphModule(torch.nn.Module):
___backward_invoke_subgraph_0_0_post_graph = self.___backward_invoke_subgraph_0_0_post_graph
- invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', (getitem_23, getitem_22, expand)); ___backward_invoke_subgraph_0_0_post_graph = getitem_23 = getitem_22 = None
+ invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', getitem_23, getitem_22, expand); ___backward_invoke_subgraph_0_0_post_graph = getitem_23 = getitem_22 = None
getitem_5: "f32[s77, 16]" = invoke_subgraph_15[1]; invoke_subgraph_15 = None
add_16: "f32[s77, 16]" = torch.ops.aten.add.Tensor(expand, getitem_5); expand = getitem_5 = None
___backward_invoke_subgraph_0_1_post_graph_2 = self.___backward_invoke_subgraph_0_1_post_graph
- invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph_2, '___backward_invoke_subgraph_0_1_post_graph', (getitem_21, getitem_20, add_16)); ___backward_invoke_subgraph_0_1_post_graph_2 = getitem_21 = getitem_20 = add_16 = None
+ invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph_2, '___backward_invoke_subgraph_0_1_post_graph', getitem_21, getitem_20, add_16); ___backward_invoke_subgraph_0_1_post_graph_2 = getitem_21 = getitem_20 = add_16 = None
getitem_8: "f32[s77, 16]" = invoke_subgraph_13[1]; invoke_subgraph_13 = None
mul_10: "f32[s77, 16]" = torch.ops.aten.mul.Tensor(getitem_8, cos); getitem_8 = cos = None
___backward_invoke_subgraph_0_1_post_graph_1 = self.___backward_invoke_subgraph_0_1_post_graph
- invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph_1, '___backward_invoke_subgraph_0_1_post_graph', (getitem_19, getitem_18, mul_10)); ___backward_invoke_subgraph_0_1_post_graph_1 = getitem_19 = getitem_18 = mul_10 = None
+ invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph_1, '___backward_invoke_subgraph_0_1_post_graph', getitem_19, getitem_18, mul_10); ___backward_invoke_subgraph_0_1_post_graph_1 = getitem_19 = getitem_18 = mul_10 = None
getitem_11: "f32[s77, 16]" = invoke_subgraph_11[1]; invoke_subgraph_11 = None
___backward_invoke_subgraph_0_1_post_graph = self.___backward_invoke_subgraph_0_1_post_graph
- invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph, '___backward_invoke_subgraph_0_1_post_graph', (getitem_17, getitem_16, getitem_11)); ___backward_invoke_subgraph_0_1_post_graph = getitem_17 = getitem_16 = getitem_11 = None
+ invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph, '___backward_invoke_subgraph_0_1_post_graph', getitem_17, getitem_16, getitem_11); ___backward_invoke_subgraph_0_1_post_graph = getitem_17 = getitem_16 = getitem_11 = None
getitem_14: "f32[s77, 16]" = invoke_subgraph_9[1]; invoke_subgraph_9 = None
return (None, getitem_14)
@@ -1679,11 +1679,11 @@ class TestInvokeSubgraphExport(TestCase):
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[8]", y: "f32[8]"):
repeated_subgraph0 = self.repeated_subgraph0
- invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'invoke_subgraph_0', (x, y)); repeated_subgraph0 = x = None
+ invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'invoke_subgraph_0', x, y); repeated_subgraph0 = x = None
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
- invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'invoke_subgraph_0', (getitem, y)); repeated_subgraph0_1 = getitem = y = None
+ invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'invoke_subgraph_0', getitem, y); repeated_subgraph0_1 = getitem = y = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
return (getitem_1,)
@@ -2216,16 +2216,16 @@ class FakeTensorDispatchCache(TestCase):
FakeTensorMode.cache_clear()
self.assertHitsMisses(0, 0)
- ref = invoke_subgraph(fn, "subgraph", (x, y))
+ ref = invoke_subgraph(fn, "subgraph", x, y)
self.assertHitsMisses(0, 2)
self.assertBypasses("function argument", 1)
- res = invoke_subgraph(fn, "subgraph", (x, y))
+ res = invoke_subgraph(fn, "subgraph", x, y)
# The hits are from the ops inside fn
self.assertHitsMisses(2, 2)
self.assertBypasses("function argument", 2)
- res = invoke_subgraph(fn, "subgraph", (x, y))
+ res = invoke_subgraph(fn, "subgraph", x, y)
# The hits are from the ops inside fn
self.assertHitsMisses(4, 2)
self.assertBypasses("function argument", 3)
@@ -2246,14 +2246,14 @@ class FakeTensorDispatchCache(TestCase):
FakeTensorMode.cache_clear()
self.assertHitsMisses(0, 0)
- ref = invoke_subgraph(mod, "subgraph", (x, y))
+ ref = invoke_subgraph(mod, "subgraph", x, y)
self.assertHitsMisses(0, 3)
- res = invoke_subgraph(mod, "subgraph", (x, y))
+ res = invoke_subgraph(mod, "subgraph", x, y)
# The hits are from re-running the subgraph
self.assertHitsMisses(1, 3)
- res = invoke_subgraph(mod, "subgraph", (x, y))
+ res = invoke_subgraph(mod, "subgraph", x, y)
# The hits are from re-running the subgraph
self.assertHitsMisses(2, 3)
@@ -2312,14 +2312,14 @@ class FakeTensorDispatchCache(TestCase):
FakeTensorMode.cache_clear()
self.assertHitsMisses(0, 0)
- ref = invoke_subgraph(mod, "subgraph", (x, y))
+ ref = invoke_subgraph(mod, "subgraph", x, y)
self.assertHitsMisses(0, 3)
- res = invoke_subgraph(mod, "subgraph", (x, y))
+ res = invoke_subgraph(mod, "subgraph", x, y)
# The hits are from the ops inside fn and not the subgraph
self.assertHitsMisses(1, 3)
- res = invoke_subgraph(mod, "subgraph", (x, y))
+ res = invoke_subgraph(mod, "subgraph", x, y)
# The hits are from the ops inside fn and not the subgraph
self.assertHitsMisses(2, 3)
@@ -112,7 +112,7 @@ def _replace_region_with_subgraph(
flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
sub_args.append(flattened_args_kwargs[arg_ind])
- invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args))
+ invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
fake_inputs = [node.meta["example_value"] for node in sub_args]
if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
@@ -125,7 +125,10 @@ def _replace_region_with_subgraph(
from torch._inductor.pattern_matcher import stable_topological_sort
invoke_subgraph_node = graph.create_node(
- "call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
+ "call_function",
+ torch.ops.higher_order.invoke_subgraph,
+ invoke_args, # type: ignore[arg-type]
+ {},
)
for ind, external_user_ind in enumerate(inds_with_external_users):
node = region[external_user_ind]
@@ -3221,7 +3221,7 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
p_args = (
p_args[0],
body_name,
- p_args[1:],
+ *p_args[1:],
)
return _call_function_and_unflatten_output(
tx,
@@ -707,7 +707,7 @@ def run_joint_graph_passes_on_hops(
args=(
new_fw_mod_attr,
new_fw_mod_attr_name,
- fw_node.args[2],
+ *fw_node.args[2:],
),
)
propagate_meta_info(new_fw_hop_gm, new_fw_node, fw_node)
@@ -744,7 +744,7 @@ def run_joint_graph_passes_on_hops(
num_primals = new_hop_graphs[identifier].old_num_fw_inputs
assert num_primals is not None
- tangents = list(bw_node.args[2][num_primals:])
+ tangents = list(bw_node.args[2 + num_primals :])
operands = sym_nodes + saved_tensor_nodes + tangents
# Insert the new_bw_hop_gm into the joint_gm
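The indexing change above follows from the new node layout: with operands stored unpacked on the FX node, the backward node's args no longer contain a nested operand tuple. A comment-only sketch of the assumed layouts (names are illustrative):

# before: bw_node.args == (bw_mod, identifier, (primal_0, ..., primal_{n-1}, tangent_0, ...))
#         tangents = list(bw_node.args[2][num_primals:])
# after:  bw_node.args == (bw_mod, identifier, primal_0, ..., primal_{n-1}, tangent_0, ...)
#         tangents = list(bw_node.args[2 + num_primals:])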
@@ -758,7 +758,7 @@ def run_joint_graph_passes_on_hops(
args=(
new_bw_mod_attr,
new_bw_mod_attr_name,
- tuple(operands),
+ *operands,
),
)
propagate_meta_info(new_bw_hop_gm, new_bw_node, bw_node)
@@ -65,23 +65,17 @@ class InvokeSubgraphHOP(HigherOrderOperator):
self,
subgraph: Union[GraphModule, FunctionalizeCtxWrapper],
identifier: Optional[str],
- operands: Union[
- list[Union[torch.Tensor, int, torch.SymInt]],
- tuple[Union[torch.Tensor, int, torch.SymInt]],
- ],
+ *operands,
):
assert identifier is None or isinstance(
identifier, str
), "identifier must be a None or a string"
- assert isinstance(
- operands, (list, tuple)
- ), f"invoke_subgraph operands must be a list or tuple of tensors/ints/SymInts {operands}"
assert all(
isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands
), f"invoke_subgraph operands must be a list of tensors/ints/SymInts {operands}"
- return super().__call__(subgraph, identifier, operands)
+ return super().__call__(subgraph, identifier, *operands)
invoke_subgraph = InvokeSubgraphHOP()
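With the varargs signature above, callers pass each operand as its own positional argument; each one is checked to be a Tensor, int, or SymInt before being forwarded unpacked to super().__call__. A hedged comment sketch (gm, x, y are placeholder names, not from this diff):

# out = invoke_subgraph(gm, "subgraph_0", x, y)   # instead of invoke_subgraph(gm, "subgraph_0", (x, y))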
@@ -261,7 +255,7 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
return fw_graph, bw_graph, output_metadata
- def get_output_metadata(subgraph, operands):
+ def get_output_metadata(subgraph, *operands):
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
# args are functional tensors, generate some example tensors
@@ -384,7 +378,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
out = invoke_subgraph(
subgraph,
f"___forward_{identifier}",
- operands,
+ *operands,
)
# Check that None is at expected indexes.
@@ -472,13 +466,13 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
)
grads = invoke_subgraph(
- bw_graph, f"___backward_{identifier}_{suffix}", primals_and_tangents
+ bw_graph, f"___backward_{identifier}_{suffix}", *primals_and_tangents
)[: -output_metadata.num_fw_outs]
return None, None, None, *grads
@invoke_subgraph.py_autograd_impl
- def _(subgraph, identifier, operands):
+ def _(subgraph, identifier, *operands):
# Check if we have already traced the subgraph.
invoke_subgraph_cache = get_invoke_subgraph_cache()
if invoke_subgraph_cache:
@@ -487,7 +481,7 @@ def _(subgraph, identifier, operands):
):
return saved_autograd_fn(*operands)
- output_metadata = get_output_metadata(subgraph, operands)
+ output_metadata = get_output_metadata(subgraph, *operands)
def autograd_fn_callable(*args):
return InvokeSubgraphAutogradOp.apply(
@@ -502,7 +496,7 @@ def _(subgraph, identifier, operands):
@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
- def _(subgraph, identifier, operands):
+ def _(subgraph, identifier, *operands):
from torch.utils._python_dispatch import _get_current_dispatch_mode
mode = _get_current_dispatch_mode()
@@ -511,20 +505,20 @@ def _(subgraph, identifier, operands):
@invoke_subgraph.py_functionalize_impl
- def _(ctx, subgraph, identifier, operands):
+ def _(ctx, subgraph, identifier, *operands):
unwrapped_operands = ctx.unwrap_tensors(operands)
with ctx.redispatch_to_next():
# NB: There is an assumption that subgraph does not mutate inputs and
# there is no aliasing. Its Dynamo responsibility to prevent formation
# of invoke_subgraph ops if input aliasing/mutation is detected.
functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph)
- out = invoke_subgraph(functionalized_subgraph, identifier, unwrapped_operands)
+ out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands)
return ctx.wrap_tensors(out)
# Register the hop fake fn. This will be called in the fake_tensor _dispatch_impl.
@register_fake(invoke_subgraph)
- def _(subgraph, identifier, operands):
+ def _(subgraph, identifier, *operands):
from torch._dynamo.utils import dynamo_timed
with dynamo_timed("invoke_subgraph_fake_tensor", log_pt2_compile_event=True):
@@ -532,7 +526,7 @@ def _(subgraph, identifier, operands):
@invoke_subgraph.py_impl(ProxyTorchDispatchMode)
- def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands):
+ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands):
# Check if we have already traced the subgraph.
graph = None
invoke_subgraph_cache = get_invoke_subgraph_cache()
@@ -562,13 +556,13 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands):
if invoke_subgraph_cache:
invoke_subgraph_cache.add_proxy_dispatch_entry(identifier, graph)
- node_args = (graph, identifier, operands)
+ node_args = (graph, identifier, *operands)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) # type: ignore[union-attr]
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", invoke_subgraph, proxy_args, {}
)
- example_out = invoke_subgraph(graph, identifier, operands)
+ example_out = invoke_subgraph(graph, identifier, *operands)
return track_tensor_tree(
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
)
@@ -7445,9 +7445,9 @@ class InvokeSubgraph(ExternKernel):
V.graph.register_operation(self)
@classmethod
- def create(cls, subgraph: Subgraph, operands): # type: ignore[no-untyped-def]
+ def create(cls, subgraph: Subgraph, *operands): # type: ignore[no-untyped-def]
# TODO(anijain2305) - Support sym expr as operands in future.
- fx_operands = V.graph.current_node.args[-1]
+ fx_operands = V.graph.current_node.args[2:]
fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr]
# Realize the inputs. Also intermediates can have different strides than
@@ -6904,8 +6904,8 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None)
- def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands):
- result = ir.InvokeSubgraph.create(subgraph_fn, operands)
+ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
+ result = ir.InvokeSubgraph.create(subgraph_fn, *operands)
return list(map(TensorBox.create, result))