Move control flow export tests to new tracer (#163259)

Differential Revision: [D82732614](https://our.internmc.facebook.com/intern/diff/D82732614)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163259
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #163136, #163137, #163258
This commit is contained in:
Tugsbayasgalan Manlaibaatar
2025-09-25 20:59:26 -07:00
committed by PyTorch MergeBot
parent cc0332563e
commit f6537d9616

View File

@ -5345,7 +5345,10 @@ class TestControlFlowTraced(TestCase):
def _check_export(self, fn, args, *, strict=False, dynamic_shapes=None):
eg_out = fn(*args)
ep = torch.export.export(fn, args, strict=strict, dynamic_shapes=dynamic_shapes)
with torch._export.config.patch(use_new_tracer_experimental=True):
ep = torch.export.export(
fn, args, strict=strict, dynamic_shapes=dynamic_shapes
)
ep_out = ep.module()(*args)
self.assertEqual(eg_out, ep_out)
return ep
@ -8402,14 +8405,14 @@ class GraphModule(torch.nn.Module):
"""\
class GraphModule(torch.nn.Module):
def forward(self, x):
x: "f32[s77, 3]";
x: "f32[s6, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
sym_size_int_1: "Sym(s6)" = torch.ops.aten.sym_size.int(x, 0)
sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
sin: "f32[s6, 3]" = torch.ops.aten.sin.default(x); x = None
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
while_loop_body_graph_0 = self.while_loop_body_graph_0
@ -8421,35 +8424,35 @@ class GraphModule(torch.nn.Module):
getitem_9: "Sym(u13)" = while_loop[3]
getitem_10: "Sym(u14)" = while_loop[4]
getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None
getitem_5: "f32[s6, 3]" = while_loop[5]; while_loop = None
add: "Sym(u12 + 1)" = getitem_8 + 1
add_1: "Sym(u13 + 1)" = getitem_9 + 1
add_2: "Sym(u14 + 1)" = getitem_10 + 1
add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
add_3: "f32[s6, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
add_4: "f32[s6, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
add_5: "f32[s6, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec)
class while_loop_cond_graph_0(torch.nn.Module):
def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"):
mul: "Sym(u22*u23)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
mul_1: "Sym(u22*u23*u24)" = mul * arg4_1; mul = arg4_1 = None
mul_2: "Sym(u20*u21)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
lt: "Sym(u22*u23*u24 < u20*u21)" = mul_1 < mul_2; mul_1 = mul_2 = None
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s6, 3]"):
mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None
mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
lt: "Sym(u17*u18*u19 < u15*u16)" = mul_1 < mul_2; mul_1 = mul_2 = None
return lt
class while_loop_body_graph_0(torch.nn.Module):
def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"):
add: "Sym(u20 + 1)" = arg0_1 + 1; arg0_1 = None
add_1: "Sym(u21 + 1)" = arg1_1 + 1; arg1_1 = None
def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s6, 3]"):
add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None
add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None
add_2: "Sym(u22 + 1)" = arg2_1 + 1; arg2_1 = None
add_3: "Sym(u23 + 1)" = arg3_1 + 1; arg3_1 = None
add_4: "Sym(u24 + 1)" = arg4_1 + 1; arg4_1 = None
add_2: "Sym(u17 + 1)" = arg2_1 + 1; arg2_1 = None
add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None
add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None
add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
add_5: "f32[s6, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
return (add, add_1, add_2, add_3, add_4, add_5)
""", # noqa: B950
)