From e50dc40d28ba409930023c77a031ec0dd20fd73b Mon Sep 17 00:00:00 2001
From: PyTorch MergeBot
Date: Fri, 17 Oct 2025 22:35:50 +0000
Subject: [PATCH] Revert "Update gm.print_readable to include Annotation (#165397)"

This reverts commit 7a657700131f31577544e93587eb339618677e97.

Reverted https://github.com/pytorch/pytorch/pull/165397 on behalf of https://github.com/malfet due to I don't know how/why, but it breaks windows tests, see https://hud.pytorch.org/hud/pytorch/pytorch/2e22b1a61ea20a54448edf34a5d22fbe8391d626/1?per_page=50&name_filter=win&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/165397#issuecomment-3417428128))
---
 test/dynamo/test_higher_order_ops.py          | 30 +++++++++++++++++
 test/dynamo/test_subclasses.py                |  1 +
 test/export/test_export.py                    |  2 --
 test/functorch/test_control_flow.py           |  5 +++
 test/higher_order_ops/test_invoke_subgraph.py | 22 ++++++-------
 test/inductor/test_compiled_autograd.py       |  1 +
 torch/fx/graph.py                             | 32 +++++++++++++++----------
 7 files changed, 63 insertions(+), 30 deletions(-)

diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py
index 693c90a10b3a..8b71fe398263 100644
--- a/test/dynamo/test_higher_order_ops.py
+++ b/test/dynamo/test_higher_order_ops.py
@@ -3802,6 +3802,7 @@ class GraphModule(torch.nn.Module): dual: "f32[4, 3, 4, 3]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None + tangents_out_unflatten: "f32[4, 3, 4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -3932,6 +3933,7 @@ class GraphModule(torch.nn.Module): tangent: "f32[4, 3, 3, 4]" = torch.zeros_like(primal) child_8: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_8 = None + child_9: "f32[4, 3, 3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None
@@ -4144,6 +4146,7 @@ class GraphModule(torch.nn.Module): primals_out: "f32[3, 4]" = diff_primals.sin() aux_1: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None + results: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primals_out, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4378,6 +4381,7 @@ class GraphModule(torch.nn.Module): primals_out: "f32[]" = sin.sum(); sin = None aux: "f32[5]" = torch._C._functorch._unwrap_for_grad(child, 1); child = aux = None + results: "f32[]" = torch._C._functorch._unwrap_for_grad(primals_out, 1); primals_out = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4567,6 +4571,7 @@ class GraphModule(torch.nn.Module): grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
@@ -4634,6 +4639,7 @@ class GraphModule(torch.nn.Module): grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output =
output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -4690,6 +4696,7 @@ class GraphModule(torch.nn.Module): grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -4746,6 +4753,7 @@ class GraphModule(torch.nn.Module): grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -4800,7 +4808,9 @@ class GraphModule(torch.nn.Module): grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None + aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -4856,7 +4866,9 @@ class GraphModule(torch.nn.Module): grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None + aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -4930,7 +4942,9 @@ class GraphModule(torch.nn.Module): _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None + aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -4974,7 +4988,9 @@ class GraphModule(torch.nn.Module): _unwrap_for_grad: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_2, 1); child_2 = None _unwrap_for_grad_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(child_3, 1); child_3 = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None + aux_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(aux, 1); aux = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -5034,6 +5050,7 @@ class GraphModule(torch.nn.Module): grad_input: "f32[]" = _autograd_grad[0]; _autograd_grad = None grad_input_1: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input, 2); grad_input = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 2); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -5043,6 +5060,7 @@ class 
GraphModule(torch.nn.Module): grad_input_2: "f32[]" = _autograd_grad_1[0]; _autograd_grad_1 = None grad_input_3: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_2, 1); grad_input_2 = None + output_2: "f32[]" = torch._C._functorch._unwrap_for_grad(grad_input_1, 1); grad_input_1 = output_2 = None _grad_decrement_nesting_1 = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting_1 = None @@ -5148,6 +5166,7 @@ class GraphModule(torch.nn.Module): grad_input: "f32[3, 3, 3]" = _autograd_grad[0]; _autograd_grad = None grad_input_1: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(grad_input, 1); grad_input = None + output_1: "f32[]" = torch._C._functorch._unwrap_for_grad(output, 1); output = output_1 = None _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None @@ -5226,6 +5245,7 @@ class GraphModule(torch.nn.Module): dual: "f32[4, 3]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None + tangents_out_unflatten: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None @@ -5307,6 +5327,7 @@ class GraphModule(torch.nn.Module): dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None + tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None @@ -5390,6 +5411,7 @@ class GraphModule(torch.nn.Module): dual: "f32[3, 4]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = primals_out_unflatten = None + tangents_out_unflatten: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None @@ -5480,6 +5502,7 @@ class GraphModule(torch.nn.Module): child_4: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = child_4 = None child_5: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 2); primal_1 = child_5 = None + child_6: "f32[3, 4]" = torch._C._functorch._unwrap_for_grad(tangent, 2); tangent = None child_7: "f32[4, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None @@ -5549,6 +5572,7 @@ class GraphModule(torch.nn.Module): dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None @@ -5602,6 +5626,7 @@ class GraphModule(torch.nn.Module): dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None @@ -5663,6 +5688,7 @@ class GraphModule(torch.nn.Module): dual: "f32[3, 3]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 1); 
dual = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None @@ -5716,6 +5742,7 @@ class GraphModule(torch.nn.Module): dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None @@ -5783,6 +5810,7 @@ class GraphModule(torch.nn.Module): dual: "f32[]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(primal, 1); primal = None + tangents_out_unflatten: "f32[]" = torch._C._functorch._unwrap_for_grad(dual, 1); dual = None _exit_dual_level = torch._C._exit_dual_level(0); _exit_dual_level = None @@ -5859,6 +5887,7 @@ class GraphModule(torch.nn.Module): dual: "f32[3, 3, 3]" = _unpack_dual[1]; _unpack_dual = None primals_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal, 2); primal = None + tangents_out_unflatten: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual, 2); dual = None _set_fwd_grad_enabled_2 = torch._C._set_fwd_grad_enabled(True); _set_fwd_grad_enabled_2 = None @@ -5873,6 +5902,7 @@ class GraphModule(torch.nn.Module): _unwrap_for_grad_2: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_1, 1); primal_1 = None _unwrap_for_grad_3: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(primal_2, 1); primal_2 = None + _unwrap_for_grad_4: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_1, 1); dual_1 = None _unwrap_for_grad_5: "f32[3, 3, 3]" = torch._C._functorch._unwrap_for_grad(dual_2, 1); dual_2 = None diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 39a0dc628bae..c590abe63788 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -3166,6 +3166,7 @@ class GraphModule(torch.nn.Module): ): slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10) slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None + add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None return ( None, # None diff --git a/test/export/test_export.py b/test/export/test_export.py index 2842723ea25b..23a7ad9bff1e 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -16061,7 +16061,6 @@ class GraphModule(torch.nn.Module): add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None return (add,) """, - ignore_empty_lines=True, ) ep = export(M(), (x, y), strict=strict).run_decompositions({}) @@ -16094,7 +16093,6 @@ class GraphModule(torch.nn.Module): add: "f32[2, 4]" = torch.ops.aten.add.Tensor(relu, arg1_1); relu = arg1_1 = None return (add,) """, - ignore_empty_lines=True, ) @testing.expectedFailureStrict # test_hop doesn't have a dynamo implementation diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index cac6ae1ba36a..e47aaa9e9e2b 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -8104,6 +8104,7 @@ class GraphModule(torch.nn.Module): x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) _guards_fn = self._guards_fn(x); _guards_fn = None + sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0) while_loop_cond_graph_0 = self.while_loop_cond_graph_0 @@ -8403,6 +8404,7 @@ class GraphModule(torch.nn.Module): x, = 
fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) _guards_fn = self._guards_fn(x); _guards_fn = None + sym_size_int_1: "Sym(s6)" = torch.ops.aten.sym_size.int(x, 0) sin: "f32[s6, 3]" = torch.ops.aten.sin.default(x); x = None @@ -8689,8 +8691,10 @@ class GraphModule(torch.nn.Module): t_4: "f32[3, 3]" = torch.ops.aten.t.default(t_3); t_3 = None mul_4: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg1_1, select) mul_5: "f32[3, 3]" = torch.ops.aten.mul.Tensor(arg1_1, select); arg1_1 = select = None + add_7: "f32[3, 3]" = torch.ops.aten.add.Tensor(mm, mul_5); mm = mul_5 = None add_8: "f32[3, 3]" = torch.ops.aten.add.Tensor(add_7, mul_4); add_7 = mul_4 = None + add_9: "i64[]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None add_10: "f32[3]" = torch.ops.aten.add.Tensor(view, arg2_1); view = arg2_1 = None add_11: "f32[3, 3]" = torch.ops.aten.add.Tensor(t_4, arg3_1); t_4 = arg3_1 = None @@ -8905,6 +8909,7 @@ class GraphModule(torch.nn.Module): x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec) _guards_fn = self._guards_fn(x, y, z); _guards_fn = None + sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0); y = None sym_size_int_5: "Sym(s68)" = torch.ops.aten.sym_size.int(z, 0) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 700751942ba1..ffbefe5cd9b4 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -17,7 +17,6 @@ from functorch.compile import aot_function, nop from torch._dynamo.testing import ( AotEagerAndRecordGraphs, EagerAndRecordGraphs, - empty_line_normalizer, InductorAndRecordGraphs, normalize_gm, ) @@ -352,8 +351,10 @@ class GraphModule(torch.nn.Module): getitem_14: "f32[8]" = invoke_subgraph_6[2] getitem_13: "f32[8]" = invoke_subgraph_6[1] getitem_1: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None + add: "f32[8]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None return (add, getitem_12, getitem_11, getitem_10, getitem_15, getitem_14, getitem_13) + class partitioned_fw_subgraph_0_0(torch.nn.Module): def forward(self, primals_0: "f32[8]", primals_1: "f32[8]", primals_2: "f32[8]"): mul: "f32[8]" = torch.ops.aten.mul.Tensor(primals_0, primals_1) @@ -362,7 +363,6 @@ class GraphModule(torch.nn.Module): mul_2: "f32[8]" = torch.ops.aten.mul.Tensor(mul_1, primals_2); mul_1 = None return (mul_2, primals_0, primals_1, primals_2) """, - ignore_empty_lines=True, ) self.assertExpectedInline( normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)), @@ -377,6 +377,7 @@ class GraphModule(torch.nn.Module): invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_10, getitem_11, getitem_12, tangents_1); partitioned_bw_subgraph_0_0 = getitem_10 = getitem_11 = getitem_12 = tangents_1 = None getitem_6: "f32[8]" = invoke_subgraph_5[0] getitem_7: "f32[8]" = invoke_subgraph_5[1]; invoke_subgraph_5 = None + add_1: "f32[8]" = torch.ops.aten.add.Tensor(getitem_2, getitem_6); getitem_2 = getitem_6 = None add_2: "f32[8]" = torch.ops.aten.add.Tensor(getitem_3, getitem_7); getitem_3 = getitem_7 = None return (add_1, add_2, None) @@ -392,7 +393,6 @@ class GraphModule(torch.nn.Module): mul_7: "f32[8]" = torch.ops.aten.mul.Tensor(mul_5, primals_1); mul_5 = primals_1 = None return (mul_7, mul_6, None) """, - ignore_empty_lines=True, ) def test_buffer_mutation_works_under_no_grad(self): @@ -681,7 +681,6 @@ class GraphModule(torch.nn.Module): sin: 
"f32[8]" = torch.ops.aten.sin.default(primals_0) return (sin, primals_0) """, - ignore_empty_lines=True, ) @inductor_config.patch("fx_graph_cache", False) @@ -723,7 +722,6 @@ class (torch.nn.Module): mul_1: "f32[8]" = torch.ops.aten.mul.Tensor(mul, 2.0); mul = None return (mul_1,) """, - ignore_empty_lines=True, ) def test_dedupe(self): @@ -772,6 +770,7 @@ class GraphModule(torch.nn.Module): subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + subgraph_1 = self.subgraph_0 invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_0', a, l_y_); subgraph_1 = a = l_y_ = None getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None @@ -807,7 +806,6 @@ class GraphModule(torch.nn.Module): mul: "f32[8]" = torch.ops.aten.mul.Tensor(primals_0, primals_1) return (mul, primals_0, primals_1) """, - ignore_empty_lines=True, ) def test_dce(self): @@ -891,6 +889,7 @@ class GraphModule(torch.nn.Module): subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = None a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + subgraph_1 = self.subgraph_1 invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', a, l_y_); subgraph_1 = a = l_y_ = None getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None @@ -1536,6 +1535,7 @@ class GraphModule(torch.nn.Module): def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"): mul_2: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 3) mul_3: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + add: "f32[8, 8]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None return (add,) """, @@ -2145,6 +2145,7 @@ class GraphModule(torch.nn.Module): subgraph_0 = self.subgraph_0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', x, y); subgraph_0 = x = None z: "f32[5]" = invoke_subgraph[0]; invoke_subgraph = None + subgraph_1 = self.subgraph_1 invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', z, y); subgraph_1 = z = y = None getitem_1: "f32[5]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None @@ -2282,7 +2283,6 @@ class GraphModule(torch.nn.Module): cos: "f32[s77, 16]" = torch.ops.aten.cos.default(primals_1) return (cos, primals_1, primals_0) """, - ignore_empty_lines=True, ) self.assertExpectedInline( normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)), @@ -2294,6 +2294,7 @@ class GraphModule(torch.nn.Module): partitioned_bw_subgraph_0_0 = self.partitioned_bw_subgraph_0_0 invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(partitioned_bw_subgraph_0_0, 'partitioned_bw_subgraph_0_0', getitem_23, getitem_22, expand); partitioned_bw_subgraph_0_0 = getitem_23 = getitem_22 = None getitem_5: "f32[s77, 16]" = invoke_subgraph_15[1]; invoke_subgraph_15 = None + add_16: "f32[s77, 16]" = torch.ops.aten.add.Tensor(expand, getitem_5); expand = getitem_5 = None partitioned_bw_subgraph_0_3 = self.partitioned_bw_subgraph_0_1 @@ -2325,7 +2326,6 @@ class GraphModule(torch.nn.Module): mul_10: "f32[s77, 16]" = torch.ops.aten.mul.Tensor(tangents_0, neg); tangents_0 = neg = None return (None, mul_10) """, - ignore_empty_lines=True, ) def test_div(self): @@ -2535,19 +2535,19 @@ class TestInvokeSubgraphExport(TestCase): 
self.assertEqual(len(list(ep.graph_module.named_modules())), 2) self.assertExpectedInline( - empty_line_normalizer( - normalize_gm(ep.graph_module.print_readable(print_output=False)) - ), + normalize_gm(ep.graph_module.print_readable(print_output=False)), """\ class GraphModule(torch.nn.Module): def forward(self, x: "f32[8]", y: "f32[8]"): repeated_subgraph0 = self.repeated_subgraph0 invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', x, y); repeated_subgraph0 = x = None getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + repeated_subgraph0_1 = self.repeated_subgraph0 invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', getitem, y); repeated_subgraph0_1 = getitem = y = None getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None return (getitem_1,) + class repeated_subgraph0(torch.nn.Module): def forward(self, arg0_1: "f32[8]", arg1_1: "f32[8]"): mul: "f32[8]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index fee2b289db90..2612af01f6ff 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -3621,6 +3621,7 @@ class CompiledAutograd0(torch.nn.Module): aot0_mul_2 = torch.ops.aten.mul.Tensor(aot0_tangents_1, aot0_primals_1); aot0_tangents_1 = aot0_primals_1 = None aot0_mul_3 = torch.ops.aten.mul.Tensor(aot0_tangents_2, aot0_primals_2); aot0_tangents_2 = aot0_primals_2 = None + aot0_add_2 = torch.ops.aten.add.Tensor(aot0_mul_2, aot0_mul_2); aot0_mul_2 = None aot0_add_3 = torch.ops.aten.add.Tensor(aot0_mul_3, aot0_mul_3); aot0_mul_3 = None
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 7577b6bc6148..940737e7e3a6 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -606,31 +606,29 @@ class CodeGen:
             else:
                 body.append("\n")
 
-        prev_summary_str = None
+        prev_stacktrace = None
 
         def append_stacktrace_summary(node: Node):
             """
             Append a summary of the stacktrace to the generated code. This is
             useful for debugging.
             """
-            nonlocal prev_summary_str
+            nonlocal prev_stacktrace
 
             if node.op not in {"placeholder", "output"}:
-                annotation_str = ""
-                annotation = node.meta.get("custom", {})
-                if annotation:
-                    annotation_str = f" Annotation: {annotation}"
-
-                stack_trace_str = "No stacktrace found for following nodes"
-                if stack_trace := node.stack_trace:
-                    if parsed_stack_trace := _parse_stack_trace(stack_trace):
-                        stack_trace_str = parsed_stack_trace.get_summary_str()
-
-                summary_str = f"\n{dim(f'#{annotation_str} {stack_trace_str}')}\n"
-
-                if summary_str != prev_summary_str:
-                    prev_summary_str = summary_str
-                    body.append(summary_str)
+                stack_trace = node.stack_trace
+                if stack_trace:
+                    if stack_trace != prev_stacktrace:
+                        prev_stacktrace = stack_trace
+                        if parsed_stack_trace := _parse_stack_trace(stack_trace):
+                            summary_str = parsed_stack_trace.get_summary_str()
+                        else:
+                            summary_str = ""
+                        body.append(f"\n {dim(f'# {summary_str}')}\n")
+                elif prev_stacktrace != "":
+                    prev_stacktrace = ""
+                    no_stacktrace_msg = "# No stacktrace found for following nodes"
+                    body.append(f"\n{dim(no_stacktrace_msg)}\n")
 
         def stringify_shape(shape: Iterable) -> str:
             return f"[{', '.join([str(x) for x in shape])}]"