Move control flow export tests to new tracer (#163259)
Differential Revision: [D82732614](https://our.internmc.facebook.com/intern/diff/D82732614)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163259
Approved by: https://github.com/avikchaudhuri
ghstack dependencies: #163136, #163137, #163258
Committed by: PyTorch MergeBot
Parent: cc0332563e
Commit: f6537d9616
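The functional change is small: the control-flow export tests' _check_export helper now exports under the experimental new tracer. A minimal sketch of that pattern, assuming a hypothetical toy module (only the config patch and the export round trip mirror the diff below):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x.sin() + 1

    args = (torch.randn(4, 3),)
    with torch._export.config.patch(use_new_tracer_experimental=True):
        ep = torch.export.export(Toy(), args, strict=False)
    # The exported program should round-trip: same outputs as eager.
    assert torch.allclose(Toy()(*args), ep.module()(*args))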
@@ -5345,7 +5345,10 @@ class TestControlFlowTraced(TestCase):
     def _check_export(self, fn, args, *, strict=False, dynamic_shapes=None):
         eg_out = fn(*args)
-        ep = torch.export.export(fn, args, strict=strict, dynamic_shapes=dynamic_shapes)
+        with torch._export.config.patch(use_new_tracer_experimental=True):
+            ep = torch.export.export(
+                fn, args, strict=strict, dynamic_shapes=dynamic_shapes
+            )
         ep_out = ep.module()(*args)
         self.assertEqual(eg_out, ep_out)
         return ep
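A hypothetical call site for the helper above: exporting a torch.cond-based function and checking it against eager execution (illustrative, not taken from this PR):

    import torch

    def fn(x):
        # pick sin or cos depending on a data-dependent predicate
        return torch.cond(x.sum() > 0, torch.sin, torch.cos, (x,))

    # inside a TestControlFlowTraced test:
    #     self._check_export(fn, (torch.randn(3, 4),))

The remaining hunks only update an expected graph dump for a while_loop test: under the new tracer the input's dynamic dimension is named s6 instead of s77 and the unbacked symbols are u15–u19 instead of u20–u24; the graph structure itself is unchanged.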
@@ -8402,14 +8405,14 @@ class GraphModule(torch.nn.Module):
             """\
 class GraphModule(torch.nn.Module):
     def forward(self, x):
-        x: "f32[s77, 3]";
+        x: "f32[s6, 3]";
 
         x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
         _guards_fn = self._guards_fn(x); _guards_fn = None
 
-        sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
+        sym_size_int_1: "Sym(s6)" = torch.ops.aten.sym_size.int(x, 0)
 
-        sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
+        sin: "f32[s6, 3]" = torch.ops.aten.sin.default(x); x = None
 
         while_loop_cond_graph_0 = self.while_loop_cond_graph_0
         while_loop_body_graph_0 = self.while_loop_body_graph_0
@@ -8421,35 +8424,35 @@ class GraphModule(torch.nn.Module):
         getitem_9: "Sym(u13)" = while_loop[3]
         getitem_10: "Sym(u14)" = while_loop[4]
 
-        getitem_5: "f32[s77, 3]" = while_loop[5]; while_loop = None
+        getitem_5: "f32[s6, 3]" = while_loop[5]; while_loop = None
 
         add: "Sym(u12 + 1)" = getitem_8 + 1
         add_1: "Sym(u13 + 1)" = getitem_9 + 1
         add_2: "Sym(u14 + 1)" = getitem_10 + 1
 
-        add_3: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
-        add_4: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
-        add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
+        add_3: "f32[s6, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_8); getitem_8 = None
+        add_4: "f32[s6, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_9); getitem_9 = None
+        add_5: "f32[s6, 3]" = torch.ops.aten.add.Tensor(getitem_5, getitem_10); getitem_10 = None
         return pytree.tree_unflatten((getitem_6, getitem_7, add, add_1, add_2, add_3, add_4, add_5, getitem_5), self._out_spec)
 
     class while_loop_cond_graph_0(torch.nn.Module):
-        def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"):
-            mul: "Sym(u22*u23)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
-            mul_1: "Sym(u22*u23*u24)" = mul * arg4_1; mul = arg4_1 = None
-            mul_2: "Sym(u20*u21)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
-            lt: "Sym(u22*u23*u24 < u20*u21)" = mul_1 < mul_2; mul_1 = mul_2 = None
+        def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s6, 3]"):
+            mul: "Sym(u17*u18)" = arg2_1 * arg3_1; arg2_1 = arg3_1 = None
+            mul_1: "Sym(u17*u18*u19)" = mul * arg4_1; mul = arg4_1 = None
+            mul_2: "Sym(u15*u16)" = arg0_1 * arg1_1; arg0_1 = arg1_1 = None
+            lt: "Sym(u17*u18*u19 < u15*u16)" = mul_1 < mul_2; mul_1 = mul_2 = None
             return lt
 
     class while_loop_body_graph_0(torch.nn.Module):
-        def forward(self, arg0_1: "Sym(u20)", arg1_1: "Sym(u21)", arg2_1: "Sym(u22)", arg3_1: "Sym(u23)", arg4_1: "Sym(u24)", arg5_1: "f32[s77, 3]"):
-            add: "Sym(u20 + 1)" = arg0_1 + 1; arg0_1 = None
-            add_1: "Sym(u21 + 1)" = arg1_1 + 1; arg1_1 = None
+        def forward(self, arg0_1: "Sym(u15)", arg1_1: "Sym(u16)", arg2_1: "Sym(u17)", arg3_1: "Sym(u18)", arg4_1: "Sym(u19)", arg5_1: "f32[s6, 3]"):
+            add: "Sym(u15 + 1)" = arg0_1 + 1; arg0_1 = None
+            add_1: "Sym(u16 + 1)" = arg1_1 + 1; arg1_1 = None
 
-            add_2: "Sym(u22 + 1)" = arg2_1 + 1; arg2_1 = None
-            add_3: "Sym(u23 + 1)" = arg3_1 + 1; arg3_1 = None
-            add_4: "Sym(u24 + 1)" = arg4_1 + 1; arg4_1 = None
+            add_2: "Sym(u17 + 1)" = arg2_1 + 1; arg2_1 = None
+            add_3: "Sym(u18 + 1)" = arg3_1 + 1; arg3_1 = None
+            add_4: "Sym(u19 + 1)" = arg4_1 + 1; arg4_1 = None
 
-            add_5: "f32[s77, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
+            add_5: "f32[s6, 3]" = torch.ops.aten.add.Tensor(arg5_1, 1); arg5_1 = None
             return (add, add_1, add_2, add_3, add_4, add_5)
         """, # noqa: B950
     )
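For orientation, a hedged reconstruction of the kind of program the expected output corresponds to (not the actual test source; names, initial carries, and carry types are illustrative): a while_loop over five integer carries plus one tensor carry, whose cond compares products of the carries and whose body increments everything by one.

    import torch
    from torch._higher_order_ops.while_loop import while_loop

    def fn(x):
        n = x.size(0)  # the dynamic dimension, Sym(s6) in the dump above

        def cond_fn(a, b, c, d, e, xt):
            # matches while_loop_cond_graph_0: loop while c*d*e < a*b
            return c * d * e < a * b

        def body_fn(a, b, c, d, e, xt):
            # matches while_loop_body_graph_0: every carry incremented by 1
            return a + 1, b + 1, c + 1, d + 1, e + 1, xt + 1

        return while_loop(cond_fn, body_fn, (n, n + 1, 0, 1, 2, x.sin()))

Because the expecttest strings assert the full printed graph, even a pure renaming of tracer-allocated symbols forces expectation updates like the ones above.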