[invoke_subgraph] Unpacked operands (#152547)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152547
Approved by: https://github.com/ydwu4, https://github.com/zou3519
This commit is contained in:
Animesh Jain
2025-05-01 18:01:13 -07:00
committed by PyTorch MergeBot
parent e6989ceea9
commit 4649fd17b0
9 changed files with 106 additions and 109 deletions

View File

@ -63,12 +63,12 @@ class GraphModule(torch.nn.Module):
o1: "f32[10, 20]" = torch.sin(l_y_)
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, o1)); o1 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); invoke_subgraph = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, o1); o1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
@ -98,10 +98,10 @@ class GraphModule(torch.nn.Module):
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
___forward_subgraph_0_0_post_graph = self.___forward_subgraph_0_0_post_graph
invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', (primals_1, sin)); ___forward_subgraph_0_0_post_graph = sin = None
invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', primals_1, sin); ___forward_subgraph_0_0_post_graph = sin = None
getitem_1: "f32[]" = invoke_subgraph_5[0]; invoke_subgraph_5 = None
___forward_subgraph_0_0_post_graph_1 = self.___forward_subgraph_0_0_post_graph
invoke_subgraph_7 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_0_post_graph_1 = primals_1 = None
invoke_subgraph_7 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_subgraph_0_0_post_graph_1 = primals_1 = None
getitem_2: "f32[]" = invoke_subgraph_7[0]; invoke_subgraph_7 = None
mul: "f32[]" = torch.ops.aten.mul.Tensor(getitem_2, getitem_2)
@ -157,13 +157,13 @@ class GraphModule(torch.nn.Module):
x0: "f32[10, 10]" = l_x_ + 2; l_x_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (x0,)); x0 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', x0); x0 = None
getitem: "f32[10, 10]" = invoke_subgraph[0]; invoke_subgraph = None
o_3: "f32[10, 10]" = torch.cos(getitem); getitem = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (o_3,)); subgraph_0 = o_3 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', o_3); subgraph_0 = o_3 = None
getitem_1: "f32[10, 10]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
@ -188,13 +188,13 @@ class GraphModule(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
___forward_subgraph_0_0_post_graph = self.___forward_subgraph_0_0_post_graph
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', (add,)); ___forward_subgraph_0_0_post_graph = add = None
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', add); ___forward_subgraph_0_0_post_graph = add = None
getitem: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
cos: "f32[10, 10]" = torch.ops.aten.cos.default(getitem)
___forward_subgraph_0_0_post_graph_1 = self.___forward_subgraph_0_0_post_graph
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', (cos,)); ___forward_subgraph_0_0_post_graph_1 = cos = None
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', cos); ___forward_subgraph_0_0_post_graph_1 = cos = None
getitem_1: "f32[10, 10]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
sin: "f32[10, 10]" = torch.ops.aten.sin.default(getitem_1)
@ -267,24 +267,24 @@ class GraphModule(torch.nn.Module):
y0: "f32[10, 20]" = torch.sin(l_y_)
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_))
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
o1: "f32[]" = torch.sin(getitem); getitem = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, y0))
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, y0)
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); invoke_subgraph_3 = None
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); subgraph_1 = x0 = y0 = None
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', x0, y0); invoke_subgraph_3 = None
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', x0, y0); subgraph_1 = x0 = y0 = None
getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
@ -325,22 +325,22 @@ class GraphModule(torch.nn.Module):
sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
___forward_subgraph_0_0_post_graph = self.___forward_subgraph_0_0_post_graph
invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_0_post_graph = None
invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_subgraph_0_0_post_graph = None
getitem: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None
sin_1: "f32[]" = torch.ops.aten.sin.default(getitem)
___forward_subgraph_0_0_post_graph_1 = self.___forward_subgraph_0_0_post_graph
invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', (primals_1, sin)); ___forward_subgraph_0_0_post_graph_1 = None
invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', primals_1, sin); ___forward_subgraph_0_0_post_graph_1 = None
getitem_1: "f32[]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None
mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_1); sin_1 = None
___forward_subgraph_0_0_post_graph_2 = self.___forward_subgraph_0_0_post_graph
invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_2, '___forward_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_subgraph_0_0_post_graph_2 = None
invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_2, '___forward_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_subgraph_0_0_post_graph_2 = None
getitem_2: "f32[]" = invoke_subgraph_13[0]; invoke_subgraph_13 = None
___forward_subgraph_1_0_post_graph = self.___forward_subgraph_1_0_post_graph
invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_0_post_graph, '___forward_subgraph_1_0_post_graph', (cos, sin)); ___forward_subgraph_1_0_post_graph = cos = sin = None
invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_1_0_post_graph, '___forward_subgraph_1_0_post_graph', cos, sin); ___forward_subgraph_1_0_post_graph = cos = sin = None
getitem_19: "f32[]" = invoke_subgraph_15[3]
getitem_18: "f32[10, 20]" = invoke_subgraph_15[2]
getitem_17: "f32[10, 10]" = invoke_subgraph_15[1]
@ -409,7 +409,7 @@ class GraphModule(torch.nn.Module):
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
___forward_subgraph_0_0_post_graph = self.___forward_subgraph_0_0_post_graph
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', (primals_1, sum_1)); ___forward_subgraph_0_0_post_graph = sum_1 = None
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph, '___forward_subgraph_0_0_post_graph', primals_1, sum_1); ___forward_subgraph_0_0_post_graph = sum_1 = None
getitem: "f32[]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
add_1: "f32[]" = torch.ops.aten.add.Tensor(getitem, 2); getitem = None
@ -417,7 +417,7 @@ class GraphModule(torch.nn.Module):
sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
___forward_subgraph_0_0_post_graph_1 = self.___forward_subgraph_0_0_post_graph
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', (primals_1, sum_2)); ___forward_subgraph_0_0_post_graph_1 = primals_1 = sum_2 = None
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_subgraph_0_0_post_graph_1, '___forward_subgraph_0_0_post_graph', primals_1, sum_2); ___forward_subgraph_0_0_post_graph_1 = primals_1 = sum_2 = None
getitem_1: "f32[]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
return (getitem_1,)
@ -492,10 +492,10 @@ class <lambda>(torch.nn.Module):
add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (add_2, add_3)); repeated_subgraph0 = add_2 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', add_2, add_3); repeated_subgraph0 = add_2 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', add_6, add_7); repeated_subgraph0_1 = add_6 = add_7 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
add_8: "f32[]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
@ -556,14 +556,14 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
@ -720,7 +720,7 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
_enter_autocast = torch.amp.autocast_mode._enter_autocast(); _enter_autocast = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
@ -728,7 +728,7 @@ class <lambda>(torch.nn.Module):
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
_exit_autocast = torch.amp.autocast_mode._exit_autocast(); _exit_autocast = None
@ -765,7 +765,7 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
@ -775,7 +775,7 @@ class <lambda>(torch.nn.Module):
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
@ -844,14 +844,14 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
@ -887,14 +887,14 @@ class <lambda>(torch.nn.Module):
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None

View File

@ -381,11 +381,11 @@ class GraphModule(torch.nn.Module):
l_y_ = L_y_
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_, l_y_); invoke_subgraph_0 = l_x_ = None
a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', a, l_y_); invoke_subgraph_1 = a = l_y_ = None
getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem_1,)
@ -403,14 +403,14 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8]", primals_2: "f32[8]"):
___forward_invoke_subgraph_0_0_post_graph = self.___forward_invoke_subgraph_0_0_post_graph
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = None
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = None
getitem_9: "f32[8]" = invoke_subgraph_4[2]
getitem_8: "f32[8]" = invoke_subgraph_4[1]
getitem: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
___forward_invoke_subgraph_0_0_post_graph_1 = self.___forward_invoke_subgraph_0_0_post_graph
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph_1, '___forward_invoke_subgraph_0_0_post_graph', (getitem, primals_2)); ___forward_invoke_subgraph_0_0_post_graph_1 = getitem = primals_2 = None
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph_1, '___forward_invoke_subgraph_0_0_post_graph', getitem, primals_2); ___forward_invoke_subgraph_0_0_post_graph_1 = getitem = primals_2 = None
getitem_11: "f32[8]" = invoke_subgraph_6[2]
getitem_10: "f32[8]" = invoke_subgraph_6[1]
getitem_1: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
@ -468,11 +468,11 @@ class GraphModule(torch.nn.Module):
l_y_ = L_y_
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_, l_y_); invoke_subgraph_0 = l_x_ = None
a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_1
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', a, l_y_); invoke_subgraph_1 = a = l_y_ = None
getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem_1,)
@ -521,19 +521,19 @@ class GraphModule(torch.nn.Module):
l_y_ = L_y_
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_, l_y_); invoke_subgraph_0 = l_x_ = None
x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (x, l_y_)); invoke_subgraph_1 = x = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', x, l_y_); invoke_subgraph_1 = x = None
x_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
invoke_subgraph_3 = self.invoke_subgraph_0
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_3, 'invoke_subgraph_0', (x_1, l_y_)); invoke_subgraph_3 = x_1 = None
invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_3, 'invoke_subgraph_0', x_1, l_y_); invoke_subgraph_3 = x_1 = None
x_2: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
invoke_subgraph_5 = self.invoke_subgraph_0
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_5, 'invoke_subgraph_0', (x_2, l_y_)); invoke_subgraph_5 = x_2 = None
invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_5, 'invoke_subgraph_0', x_2, l_y_); invoke_subgraph_5 = x_2 = None
x_3: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
invoke_subgraph_7 = self.invoke_subgraph_0
invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_7, 'invoke_subgraph_0', (x_3, l_y_)); invoke_subgraph_7 = x_3 = l_y_ = None
invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_7, 'invoke_subgraph_0', x_3, l_y_); invoke_subgraph_7 = x_3 = l_y_ = None
x_4: "f32[8]" = invoke_subgraph_8[0]; invoke_subgraph_8 = None
return (x_4,)
@ -765,10 +765,10 @@ class GraphModule(torch.nn.Module):
l_self_modules_linear_parameters_bias_ = L_self_modules_linear_parameters_bias_
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_)); invoke_subgraph_0 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_); invoke_subgraph_0 = None
getitem: "f32[8, 8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_)); invoke_subgraph_1 = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', l_x_, l_self_modules_linear_parameters_weight_, l_self_modules_linear_parameters_bias_); invoke_subgraph_1 = l_self_modules_linear_parameters_weight_ = l_self_modules_linear_parameters_bias_ = None
getitem_1: "f32[8, 8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
mul: "f32[8, 8]" = getitem * getitem_1; getitem = getitem_1 = None
@ -818,10 +818,10 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_); invoke_subgraph_0 = None
getitem: "f32[8, 8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_1 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', l_x_); invoke_subgraph_1 = None
getitem_1: "f32[8, 8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
mul: "f32[8, 8]" = getitem * getitem_1; getitem = getitem_1 = None
@ -903,7 +903,7 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = l_x_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_); invoke_subgraph_0 = l_x_ = None
getitem: "f32[8, 8]" = invoke_subgraph[0]
getitem_1: "f32[8, 8]" = invoke_subgraph[2]; invoke_subgraph = None
@ -925,7 +925,7 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8, 8]"):
___forward_invoke_subgraph_0_0_post_graph = self.___forward_invoke_subgraph_0_0_post_graph
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', (primals_1,)); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', primals_1); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = None
getitem: "f32[8, 8]" = invoke_subgraph_2[0]
getitem_2: "f32[8, 8]" = invoke_subgraph_2[2]; invoke_subgraph_2 = None
@ -947,7 +947,7 @@ class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[8, 8]"):
___backward_invoke_subgraph_0_0_post_graph = self.___backward_invoke_subgraph_0_0_post_graph
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', (tangents_1, tangents_1)); ___backward_invoke_subgraph_0_0_post_graph = tangents_1 = None
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', tangents_1, tangents_1); ___backward_invoke_subgraph_0_0_post_graph = tangents_1 = None
getitem_3: "f32[8, 8]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None
return (getitem_3,)
@ -1052,7 +1052,7 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8, 8]", primals_2: "f32[8, 8]"):
___forward_invoke_subgraph_0_0_post_graph = self.___forward_invoke_subgraph_0_0_post_graph
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', (primals_1, primals_2)); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = primals_2 = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', primals_1, primals_2); ___forward_invoke_subgraph_0_0_post_graph = primals_1 = primals_2 = None
getitem_6: "f32[8, 8]" = invoke_subgraph_2[3]
getitem_5: "f32[8, 8]" = invoke_subgraph_2[2]
getitem_4: "f32[8, 8]" = invoke_subgraph_2[1]
@ -1083,7 +1083,7 @@ class GraphModule(torch.nn.Module):
___backward_invoke_subgraph_0_0_post_graph = self.___backward_invoke_subgraph_0_0_post_graph
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', (getitem_4, getitem_5, getitem_6, mul)); ___backward_invoke_subgraph_0_0_post_graph = getitem_4 = getitem_5 = getitem_6 = mul = None
invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', getitem_4, getitem_5, getitem_6, mul); ___backward_invoke_subgraph_0_0_post_graph = getitem_4 = getitem_5 = getitem_6 = mul = None
getitem_1: "f32[8, 8]" = invoke_subgraph_3[0]
getitem_2: "f32[8, 8]" = invoke_subgraph_3[1]; invoke_subgraph_3 = None
return (getitem_1, getitem_2)
@ -1195,10 +1195,10 @@ class GraphModule(torch.nn.Module):
l_y_ = L_y_
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = l_x_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_); invoke_subgraph_0 = l_x_ = None
getitem: "f32[8, 8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_1
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (l_y_,)); invoke_subgraph_1 = l_y_ = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', l_y_); invoke_subgraph_1 = l_y_ = None
getitem_1: "f32[16, 16]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem, getitem_1)
@ -1259,14 +1259,14 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (s77, l_x_)); invoke_subgraph_0 = l_x_ = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', s77, l_x_); invoke_subgraph_0 = l_x_ = None
a: "f32[s77, 8]" = invoke_subgraph[0]; invoke_subgraph = None
floordiv: "Sym((s77//2))" = s77 // 2
b: "f32[(s77//2), 8]" = torch.narrow(a, 0, 0, floordiv); a = floordiv = None
invoke_subgraph_1 = self.invoke_subgraph_1
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (s77, b)); invoke_subgraph_1 = s77 = b = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', s77, b); invoke_subgraph_1 = s77 = b = None
getitem_3: "f32[(s77//2), 8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem_3,)
@ -1326,10 +1326,10 @@ class GraphModule(torch.nn.Module):
l_x_ = L_x_
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_0 = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', l_x_); invoke_subgraph_0 = None
getitem: "f32[8, 8]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_0
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (l_x_,)); invoke_subgraph_1 = l_x_ = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', l_x_); invoke_subgraph_1 = l_x_ = None
getitem_1: "f32[8, 8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
add: "f32[8, 8]" = getitem + getitem_1; getitem = getitem_1 = None
@ -1435,11 +1435,11 @@ class GraphModule(torch.nn.Module):
y: "f32[5]" = l_y_.sin(); l_y_ = None
invoke_subgraph_0 = self.invoke_subgraph_0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (x, y)); invoke_subgraph_0 = x = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', x, y); invoke_subgraph_0 = x = None
z: "f32[5]" = invoke_subgraph[0]; invoke_subgraph = None
invoke_subgraph_1 = self.invoke_subgraph_1
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (z, y)); invoke_subgraph_1 = z = y = None
invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', z, y); invoke_subgraph_1 = z = y = None
getitem_1: "f32[5]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
return (getitem_1,)
@ -1534,14 +1534,14 @@ class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s77)", primals_2: "f32[s77, 16]"):
___forward_invoke_subgraph_0_1_post_graph = self.___forward_invoke_subgraph_0_1_post_graph
invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph, '___forward_invoke_subgraph_0_1_post_graph', (primals_1, primals_2)); ___forward_invoke_subgraph_0_1_post_graph = primals_2 = None
invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph, '___forward_invoke_subgraph_0_1_post_graph', primals_1, primals_2); ___forward_invoke_subgraph_0_1_post_graph = primals_2 = None
getitem_17: "Sym(s77)" = invoke_subgraph_8[2]
getitem_16: "f32[s77, 16]" = invoke_subgraph_8[1]
getitem: "f32[s77, 16]" = invoke_subgraph_8[0]; invoke_subgraph_8 = None
___forward_invoke_subgraph_0_1_post_graph_1 = self.___forward_invoke_subgraph_0_1_post_graph
invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph_1, '___forward_invoke_subgraph_0_1_post_graph', (primals_1, getitem)); ___forward_invoke_subgraph_0_1_post_graph_1 = getitem = None
invoke_subgraph_10 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph_1, '___forward_invoke_subgraph_0_1_post_graph', primals_1, getitem); ___forward_invoke_subgraph_0_1_post_graph_1 = getitem = None
getitem_19: "Sym(s77)" = invoke_subgraph_10[2]
getitem_18: "f32[s77, 16]" = invoke_subgraph_10[1]
getitem_1: "f32[s77, 16]" = invoke_subgraph_10[0]; invoke_subgraph_10 = None
@ -1550,14 +1550,14 @@ class GraphModule(torch.nn.Module):
___forward_invoke_subgraph_0_1_post_graph_2 = self.___forward_invoke_subgraph_0_1_post_graph
invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph_2, '___forward_invoke_subgraph_0_1_post_graph', (primals_1, sin)); ___forward_invoke_subgraph_0_1_post_graph_2 = sin = None
invoke_subgraph_12 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_1_post_graph_2, '___forward_invoke_subgraph_0_1_post_graph', primals_1, sin); ___forward_invoke_subgraph_0_1_post_graph_2 = sin = None
getitem_21: "Sym(s77)" = invoke_subgraph_12[2]
getitem_20: "f32[s77, 16]" = invoke_subgraph_12[1]
getitem_2: "f32[s77, 16]" = invoke_subgraph_12[0]; invoke_subgraph_12 = None
___forward_invoke_subgraph_0_0_post_graph = self.___forward_invoke_subgraph_0_0_post_graph
invoke_subgraph_14 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', (primals_1, getitem_2)); ___forward_invoke_subgraph_0_0_post_graph = None
invoke_subgraph_14 = torch.ops.higher_order.invoke_subgraph(___forward_invoke_subgraph_0_0_post_graph, '___forward_invoke_subgraph_0_0_post_graph', primals_1, getitem_2); ___forward_invoke_subgraph_0_0_post_graph = None
getitem_23: "Sym(s77)" = invoke_subgraph_14[2]
getitem_22: "f32[s77, 16]" = invoke_subgraph_14[1]
getitem_3: "f32[s77, 16]" = invoke_subgraph_14[0]; invoke_subgraph_14 = None
@ -1589,26 +1589,26 @@ class GraphModule(torch.nn.Module):
___backward_invoke_subgraph_0_0_post_graph = self.___backward_invoke_subgraph_0_0_post_graph
invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', (getitem_23, getitem_22, expand)); ___backward_invoke_subgraph_0_0_post_graph = getitem_23 = getitem_22 = None
invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_0_post_graph, '___backward_invoke_subgraph_0_0_post_graph', getitem_23, getitem_22, expand); ___backward_invoke_subgraph_0_0_post_graph = getitem_23 = getitem_22 = None
getitem_5: "f32[s77, 16]" = invoke_subgraph_15[1]; invoke_subgraph_15 = None
add_16: "f32[s77, 16]" = torch.ops.aten.add.Tensor(expand, getitem_5); expand = getitem_5 = None
___backward_invoke_subgraph_0_1_post_graph_2 = self.___backward_invoke_subgraph_0_1_post_graph
invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph_2, '___backward_invoke_subgraph_0_1_post_graph', (getitem_21, getitem_20, add_16)); ___backward_invoke_subgraph_0_1_post_graph_2 = getitem_21 = getitem_20 = add_16 = None
invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph_2, '___backward_invoke_subgraph_0_1_post_graph', getitem_21, getitem_20, add_16); ___backward_invoke_subgraph_0_1_post_graph_2 = getitem_21 = getitem_20 = add_16 = None
getitem_8: "f32[s77, 16]" = invoke_subgraph_13[1]; invoke_subgraph_13 = None
mul_10: "f32[s77, 16]" = torch.ops.aten.mul.Tensor(getitem_8, cos); getitem_8 = cos = None
___backward_invoke_subgraph_0_1_post_graph_1 = self.___backward_invoke_subgraph_0_1_post_graph
invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph_1, '___backward_invoke_subgraph_0_1_post_graph', (getitem_19, getitem_18, mul_10)); ___backward_invoke_subgraph_0_1_post_graph_1 = getitem_19 = getitem_18 = mul_10 = None
invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph_1, '___backward_invoke_subgraph_0_1_post_graph', getitem_19, getitem_18, mul_10); ___backward_invoke_subgraph_0_1_post_graph_1 = getitem_19 = getitem_18 = mul_10 = None
getitem_11: "f32[s77, 16]" = invoke_subgraph_11[1]; invoke_subgraph_11 = None
___backward_invoke_subgraph_0_1_post_graph = self.___backward_invoke_subgraph_0_1_post_graph
invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph, '___backward_invoke_subgraph_0_1_post_graph', (getitem_17, getitem_16, getitem_11)); ___backward_invoke_subgraph_0_1_post_graph = getitem_17 = getitem_16 = getitem_11 = None
invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(___backward_invoke_subgraph_0_1_post_graph, '___backward_invoke_subgraph_0_1_post_graph', getitem_17, getitem_16, getitem_11); ___backward_invoke_subgraph_0_1_post_graph = getitem_17 = getitem_16 = getitem_11 = None
getitem_14: "f32[s77, 16]" = invoke_subgraph_9[1]; invoke_subgraph_9 = None
return (None, getitem_14)
@ -1679,11 +1679,11 @@ class TestInvokeSubgraphExport(TestCase):
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[8]", y: "f32[8]"):
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'invoke_subgraph_0', (x, y)); repeated_subgraph0 = x = None
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'invoke_subgraph_0', x, y); repeated_subgraph0 = x = None
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'invoke_subgraph_0', (getitem, y)); repeated_subgraph0_1 = getitem = y = None
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'invoke_subgraph_0', getitem, y); repeated_subgraph0_1 = getitem = y = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
return (getitem_1,)

View File

@ -2216,16 +2216,16 @@ class FakeTensorDispatchCache(TestCase):
FakeTensorMode.cache_clear()
self.assertHitsMisses(0, 0)
ref = invoke_subgraph(fn, "subgraph", (x, y))
ref = invoke_subgraph(fn, "subgraph", x, y)
self.assertHitsMisses(0, 2)
self.assertBypasses("function argument", 1)
res = invoke_subgraph(fn, "subgraph", (x, y))
res = invoke_subgraph(fn, "subgraph", x, y)
# The hits are from the ops inside fn
self.assertHitsMisses(2, 2)
self.assertBypasses("function argument", 2)
res = invoke_subgraph(fn, "subgraph", (x, y))
res = invoke_subgraph(fn, "subgraph", x, y)
# The hits are from the ops inside fn
self.assertHitsMisses(4, 2)
self.assertBypasses("function argument", 3)
@ -2246,14 +2246,14 @@ class FakeTensorDispatchCache(TestCase):
FakeTensorMode.cache_clear()
self.assertHitsMisses(0, 0)
ref = invoke_subgraph(mod, "subgraph", (x, y))
ref = invoke_subgraph(mod, "subgraph", x, y)
self.assertHitsMisses(0, 3)
res = invoke_subgraph(mod, "subgraph", (x, y))
res = invoke_subgraph(mod, "subgraph", x, y)
# The hits are from re-running the subgraph
self.assertHitsMisses(1, 3)
res = invoke_subgraph(mod, "subgraph", (x, y))
res = invoke_subgraph(mod, "subgraph", x, y)
# The hits are from re-running the subgraph
self.assertHitsMisses(2, 3)
@ -2312,14 +2312,14 @@ class FakeTensorDispatchCache(TestCase):
FakeTensorMode.cache_clear()
self.assertHitsMisses(0, 0)
ref = invoke_subgraph(mod, "subgraph", (x, y))
ref = invoke_subgraph(mod, "subgraph", x, y)
self.assertHitsMisses(0, 3)
res = invoke_subgraph(mod, "subgraph", (x, y))
res = invoke_subgraph(mod, "subgraph", x, y)
# The hits are from the ops inside fn and not the subgraph
self.assertHitsMisses(1, 3)
res = invoke_subgraph(mod, "subgraph", (x, y))
res = invoke_subgraph(mod, "subgraph", x, y)
# The hits are from the ops inside fn and not the subgraph
self.assertHitsMisses(2, 3)

View File

@ -112,7 +112,7 @@ def _replace_region_with_subgraph(
flattened_args_kwargs = _flatten_args_kwargs((node.args, node.kwargs))
sub_args.append(flattened_args_kwargs[arg_ind])
invoke_args = (get_subgraph_node, subgraph_name, tuple(sub_args))
invoke_args = (get_subgraph_node, subgraph_name, *sub_args)
fake_inputs = [node.meta["example_value"] for node in sub_args]
if has_potential_input_alias_or_mutation(sub_gm, fake_inputs):
@ -125,7 +125,10 @@ def _replace_region_with_subgraph(
from torch._inductor.pattern_matcher import stable_topological_sort
invoke_subgraph_node = graph.create_node(
"call_function", torch.ops.higher_order.invoke_subgraph, invoke_args, {}
"call_function",
torch.ops.higher_order.invoke_subgraph,
invoke_args, # type: ignore[arg-type]
{},
)
for ind, external_user_ind in enumerate(inds_with_external_users):
node = region[external_user_ind]

View File

@ -3221,7 +3221,7 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
p_args = (
p_args[0],
body_name,
p_args[1:],
*p_args[1:],
)
return _call_function_and_unflatten_output(
tx,

View File

@ -707,7 +707,7 @@ def run_joint_graph_passes_on_hops(
args=(
new_fw_mod_attr,
new_fw_mod_attr_name,
fw_node.args[2],
*fw_node.args[2:],
),
)
propagate_meta_info(new_fw_hop_gm, new_fw_node, fw_node)
@ -744,7 +744,7 @@ def run_joint_graph_passes_on_hops(
num_primals = new_hop_graphs[identifier].old_num_fw_inputs
assert num_primals is not None
tangents = list(bw_node.args[2][num_primals:])
tangents = list(bw_node.args[2 + num_primals :])
operands = sym_nodes + saved_tensor_nodes + tangents
# Insert the new_bw_hop_gm into the joint_gm
@ -758,7 +758,7 @@ def run_joint_graph_passes_on_hops(
args=(
new_bw_mod_attr,
new_bw_mod_attr_name,
tuple(operands),
*operands,
),
)
propagate_meta_info(new_bw_hop_gm, new_bw_node, bw_node)

View File

@ -65,23 +65,17 @@ class InvokeSubgraphHOP(HigherOrderOperator):
self,
subgraph: Union[GraphModule, FunctionalizeCtxWrapper],
identifier: Optional[str],
operands: Union[
list[Union[torch.Tensor, int, torch.SymInt]],
tuple[Union[torch.Tensor, int, torch.SymInt]],
],
*operands,
):
assert identifier is None or isinstance(
identifier, str
), "identifier must be a None or a string"
assert isinstance(
operands, (list, tuple)
), f"invoke_subgraph operands must be a list or tuple of tensors/ints/SymInts {operands}"
assert all(
isinstance(o, (torch.Tensor, int, torch.SymInt)) for o in operands
), f"invoke_subgraph operands must be a list of tensors/ints/SymInts {operands}"
return super().__call__(subgraph, identifier, operands)
return super().__call__(subgraph, identifier, *operands)
invoke_subgraph = InvokeSubgraphHOP()
@ -261,7 +255,7 @@ def create_fw_bw_graph(subgraph, operands, grad_outputs=None):
return fw_graph, bw_graph, output_metadata
def get_output_metadata(subgraph, operands):
def get_output_metadata(subgraph, *operands):
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
# args are functional tensors, generate some example tensors
@ -384,7 +378,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
out = invoke_subgraph(
subgraph,
f"___forward_{identifier}",
operands,
*operands,
)
# Check that None is at expected indexes.
@ -472,13 +466,13 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
)
grads = invoke_subgraph(
bw_graph, f"___backward_{identifier}_{suffix}", primals_and_tangents
bw_graph, f"___backward_{identifier}_{suffix}", *primals_and_tangents
)[: -output_metadata.num_fw_outs]
return None, None, None, *grads
@invoke_subgraph.py_autograd_impl
def _(subgraph, identifier, operands):
def _(subgraph, identifier, *operands):
# Check if we have already traced the subgraph.
invoke_subgraph_cache = get_invoke_subgraph_cache()
if invoke_subgraph_cache:
@ -487,7 +481,7 @@ def _(subgraph, identifier, operands):
):
return saved_autograd_fn(*operands)
output_metadata = get_output_metadata(subgraph, operands)
output_metadata = get_output_metadata(subgraph, *operands)
def autograd_fn_callable(*args):
return InvokeSubgraphAutogradOp.apply(
@ -502,7 +496,7 @@ def _(subgraph, identifier, operands):
@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd)
def _(subgraph, identifier, operands):
def _(subgraph, identifier, *operands):
from torch.utils._python_dispatch import _get_current_dispatch_mode
mode = _get_current_dispatch_mode()
@ -511,20 +505,20 @@ def _(subgraph, identifier, operands):
@invoke_subgraph.py_functionalize_impl
def _(ctx, subgraph, identifier, operands):
def _(ctx, subgraph, identifier, *operands):
unwrapped_operands = ctx.unwrap_tensors(operands)
with ctx.redispatch_to_next():
# NB: There is an assumption that subgraph does not mutate inputs and
# there is no aliasing. Its Dynamo responsibility to prevent formation
# of invoke_subgraph ops if input aliasing/mutation is detected.
functionalized_subgraph = FunctionalizeCtxWrapper(ctx, subgraph)
out = invoke_subgraph(functionalized_subgraph, identifier, unwrapped_operands)
out = invoke_subgraph(functionalized_subgraph, identifier, *unwrapped_operands)
return ctx.wrap_tensors(out)
# Register the hop fake fn. This will be called in the fake_tensor _dispatch_impl.
@register_fake(invoke_subgraph)
def _(subgraph, identifier, operands):
def _(subgraph, identifier, *operands):
from torch._dynamo.utils import dynamo_timed
with dynamo_timed("invoke_subgraph_fake_tensor", log_pt2_compile_event=True):
@ -532,7 +526,7 @@ def _(subgraph, identifier, operands):
@invoke_subgraph.py_impl(ProxyTorchDispatchMode)
def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands):
def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands):
# Check if we have already traced the subgraph.
graph = None
invoke_subgraph_cache = get_invoke_subgraph_cache()
@ -562,13 +556,13 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands):
if invoke_subgraph_cache:
invoke_subgraph_cache.add_proxy_dispatch_entry(identifier, graph)
node_args = (graph, identifier, operands)
node_args = (graph, identifier, *operands)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) # type: ignore[union-attr]
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", invoke_subgraph, proxy_args, {}
)
example_out = invoke_subgraph(graph, identifier, operands)
example_out = invoke_subgraph(graph, identifier, *operands)
return track_tensor_tree(
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
)

View File

@ -7445,9 +7445,9 @@ class InvokeSubgraph(ExternKernel):
V.graph.register_operation(self)
@classmethod
def create(cls, subgraph: Subgraph, operands): # type: ignore[no-untyped-def]
def create(cls, subgraph: Subgraph, *operands): # type: ignore[no-untyped-def]
# TODO(anijain2305) - Support sym expr as operands in future.
fx_operands = V.graph.current_node.args[-1]
fx_operands = V.graph.current_node.args[2:]
fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr]
# Realize the inputs. Also intermediates can have different strides than

View File

@ -6904,8 +6904,8 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs):
@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None)
def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands):
result = ir.InvokeSubgraph.create(subgraph_fn, operands)
def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
result = ir.InvokeSubgraph.create(subgraph_fn, *operands)
return list(map(TensorBox.create, result))